Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tsingmicro-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
- name: FlagTree Build on Tsingmicro
shell: bash
run: |
pip uninstall -y triton
source ~/env.sh
export FLAGTREE_BACKEND=tsingmicro
cd python
Expand All @@ -59,4 +60,3 @@ jobs:
source ~/env.sh
python3.11 -c 'import triton; print(triton.__path__)'
/usr/local/lib/python3.11/dist-packages/triton/backends/tsingmicro/bin/tsingmicro-opt --version
/usr/local/lib/python3.11/dist-packages/triton/backends/tsingmicro/bin/tsingmicro-llvm-opt --version
4 changes: 2 additions & 2 deletions python/setup_tools/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,10 @@ def check_env(env_val):

# tsingmicro
cache.store(
file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64",
file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-python3.11-x64",
condition=("tsingmicro" == flagtree_backend),
url=
"https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64.tar.gz",
"https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-python3.11-x64.tar.gz",
pre_hock=lambda: check_env('LLVM_SYSPATH'),
post_hock=set_llvm_env,
)
15 changes: 8 additions & 7 deletions third_party/tsingmicro/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
if(NOT DEFINED TX8_HOME)
if(DEFINED ENV{TX8_HOME})
set(TX8_HOME $ENV{TX8_HOME})
if(NOT DEFINED TX8_DEPS_ROOT)
if(DEFINED ENV{TX8_DEPS_ROOT})
set(TX8_DEPS_ROOT $ENV{TX8_DEPS_ROOT})
else()
message(FATAL_ERROR "TX8_HOME environment variable is not defined")
message(FATAL_ERROR "TX8_DEPS_ROOT environment variable is not defined")
endif()
endif()

set(TSM_BACKEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/backend)
set(XUANTIE_NAME Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2)
set(INSTALL_TSINGMICRO_DIR ${CMAKE_INSTALL_PREFIX}/triton/backends/tsingmicro/)
install(CODE "file(MAKE_DIRECTORY \"${INSTALL_TSINGMICRO_DIR}\")")

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/crt/include)
include_directories(${TX8_HOME}/include)
include_directories(${TX8_DEPS_ROOT}/include)
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(bin)
Expand All @@ -23,7 +24,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
LINK_LIBS ZTCAnalysis ZTCAnalysisStructured MagicKernelIR
Tx81IR TritonTilingExtIR TritonStructuredIR TritonToCoreDialects
TritonToLinalg TritonToStructured StructuredToMemref LinalgToMagicKernel
TritonArithToLinalg CoreDialectsToMK Tx81ToLLVM Tx81MemrefToLLVM MKToTx81)
TritonArithToLinalg CoreDialectsToMK Tx81ToLLVM Tx81MemrefToLLVM MKToTx81 LLVMRISCVCodeGen LLVMRISCVAsmParser)
target_link_libraries(TritonTsingMicro PRIVATE Python3::Module pybind11::headers)
endif()
#if(TRITON_BUILD_UT)
Expand Down
109 changes: 68 additions & 41 deletions third_party/tsingmicro/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def _get_llvm_bin_path(bin_name: str) -> str:
return os.path.join(path, "bin", bin_name)


def _get_tx8_path(sub_name: str) -> str:
path = os.getenv("TX8_HOME", "")
def _get_tx8_deps_path(sub_name: str) -> str:
path = os.getenv("TX8_DEPS_ROOT", "")
if path == "":
raise Exception("TX8_HOME is not set.")
raise Exception("TX8_DEPS_ROOT is not set.")
return os.path.join(path, sub_name)


Expand All @@ -55,22 +55,36 @@ def compile_accelerator():
# FIXME: Hardcoded path
#dst_path = os.path.join(tmpdir, f"{name}.so")
dst_path = "/tmp/kernel.so"
libc_lib = os.path.join(_get_tx8_path("Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2"), "riscv64-unknown-elf",
libc_lib = os.path.join(_get_tx8_deps_path("Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2"), "riscv64-unknown-elf",
"lib", "rv64imafdc", "lp64d")
# libvr_path = os.path.join(os.path.dirname(__file__), "lib")
libvr_path = os.path.join(os.path.dirname(__file__), "lib")
clang_path = _get_llvm_bin_path("clang")
lld_path = _get_llvm_bin_path("ld.lld")
tx8_lib = _get_tx8_path("lib")
subprocess.check_call([
clang_path, "-shared", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-O2",
f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles", "-Wl,--allow-shlib-undefined", "-mabi=lp64d",
"-Wl,--no-dynamic-linker",
# FIXME: Hardcoded path
"/tmp/kernel.o", f"-L{libvr_path}", f"-L{libc_lib}", f"-L{tx8_lib}", "-Wl,--whole-archive",
"-linstr_tx81", # Tx81 intrinsic API
"-lvr", # Wrapper API of Tx81 intrinsic
"-Wl,--no-whole-archive", "-lm", "-o", dst_path
])

tx8_lib = _get_tx8_deps_path("lib")
# Build shared library for simulator or hardware
if (os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")):
subprocess.check_call([
clang_path, "-shared", "-O2", f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles",
"-Wl,--allow-shlib-undefined", "-Wl,--no-dynamic-linker",
# FIXME: Hardcoded path
"/tmp/kernel.o", f"-L{libvr_path}", f"-L{tx8_lib}", "-Wl,--whole-archive",
"-lvr", # Wrapper API of Tx81 intrinsic
"-ltriton_cmodel", "-ltx8be_op_cmodel", "-Wl,--no-whole-archive", "-lm", "-o", dst_path
])
else:
# Link wrapper, kernel with Tx81 crt and intrinsics(libinstr_tx81.a)
subprocess.check_call([
clang_path, "-shared", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-O2",
f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles", "-Wl,--allow-shlib-undefined", "-mabi=lp64d",
"-Wl,--no-dynamic-linker",
# FIXME: Hardcoded path
"/tmp/kernel.o", f"-L{libvr_path}", f"-L{libc_lib}", f"-L{tx8_lib}", "-Wl,--whole-archive",
"-linstr_tx81", # Tx81 intrinsic API
"-lvr", # Wrapper API of Tx81 intrinsic
"-Wl,--no-whole-archive", "-lm", "-o", dst_path
])

_dump_ir_if_needed([dst_path])
with open(dst_path, 'rb') as f:
Expand All @@ -85,10 +99,11 @@ def _ttir_to_coreir(mod):
src_path = os.path.join(tmpdir, "tt.mlir")
dst_path = os.path.join(tmpdir, "core.mlir")
Path(src_path).write_text(ttir_code)
tsm_opt_path = _get_tsm_opt_path()
triton_opt_path = _get_tsm_opt_path()
_dump_ir_if_needed([src_path])
subprocess.check_call([
tsm_opt_path, src_path, "--triton-to-core-dialects", "--one-shot-bufferize=allow-return-allocs-from-loops",
triton_opt_path, src_path, "--triton-to-core-dialects", "--core-dialects-to-mk",
"--one-shot-bufferize=allow-return-allocs-from-loops",
#"--mlir-print-debuginfo",
"-o", dst_path
])
Expand All @@ -107,10 +122,10 @@ def _coreir_to_mkir(mod):
src_path = os.path.join(tmpdir, "core.mlir")
dst_path = os.path.join(tmpdir, "mk.mlir")
Path(src_path).write_text(coreir_code)
tsm_opt_path = _get_tsm_opt_path()
triton_opt_path = _get_tsm_opt_path()
_dump_ir_if_needed([src_path])
subprocess.check_call([
tsm_opt_path, src_path, "--core-dialects-to-mk",
triton_opt_path, src_path, "--core-dialects-to-mk",
#"--mlir-print-debuginfo",
"-o", dst_path
])
Expand All @@ -129,10 +144,10 @@ def _coreir_to_txir(mod):
src_path = os.path.join(tmpdir, "core.mlir")
dst_path = os.path.join(tmpdir, "tx.mlir")
Path(src_path).write_text(coreir_code)
tsm_opt_path = _get_tsm_opt_path()
triton_opt_path = _get_tsm_opt_path()
_dump_ir_if_needed([src_path])
subprocess.check_call([
tsm_opt_path, src_path, "--expand-strided-metadata",
triton_opt_path, src_path, "--expand-strided-metadata",
"--lower-affine", # convert affine.load to memref.load, need exec before tx81-to-llvm since we will support spm offset to memref.load
"--mk-to-tx81", "--cse", # unused memref.subview/memref.reinterpret
#"--mlir-print-debuginfo",
Expand All @@ -153,20 +168,23 @@ def _txir_to_llir(mod, metadata):
llvmir_path = os.path.join(tmpdir, "ll.mlir")
llir_path = os.path.join(tmpdir, "ll.ir")
Path(src_path).write_text(txir_code)
tsm_opt_path = _get_tsm_opt_path()
triton_opt_path = _get_tsm_opt_path()
_dump_ir_if_needed([src_path])
# Tx81 and core dialects to LLVM-MLIR
args = [
tsm_opt_path, src_path,
triton_opt_path, src_path,
# Use tx81-memref-to-llvm to replace "--finalize-memref-to-llvm".
"--tx81-memref-to-llvm", "--convert-scf-to-cf", "--convert-math-to-llvm",
"--convert-cf-to-llvm", # need exec before "convert-func-to-llvm"
"--tx81-memref-to-llvm", "--convert-scf-to-cf", "--test-math-polynomial-approximation",
"--convert-math-to-llvm", "--convert-cf-to-llvm", # need exec before "convert-func-to-llvm"
"--convert-func-to-llvm", # need exec before "kernel-arg-buffer", otherwise un-rank memref will translate to int(rank) + ptr
# Other unconverted memref ops, eg: memref.global from scan op conversion
"--finalize-memref-to-llvm"
]

args.append(
"--kernel-arg-buffer"
) # need exec before "tx81-to-llvm" which will declare other func. We want only replace the triton kernel
# WORKAROUND: To replace function signature to "kernel(ptr)"
if os.getenv("VENDOR_VERSION", "") != "":
args.append(
"--kernel-arg-buffer"
) # need exec before "tx81-to-llvm" which will declare other func. We want only replace the triton kernel

# other pass
args += [
Expand All @@ -181,14 +199,12 @@ def _txir_to_llir(mod, metadata):

_dump_ir_if_needed([llvmir_path])

llvm_file = os.getenv("CUSTOMIZED_IR", "")
if (llvm_file != ""):
llvmir_path = os.getenv("TRITON_DUMP_PATH", "")

if not llvmir_path:
return

llvmir_path = os.path.join(llvmir_path, llvm_file)
# Get spm memory use metadata
from mlir.ir import Context, Module
with Context() as ctx:
llvmir_str = Path(llvmir_path).read_text()
llvmir_module = Module.parse(llvmir_str)
metadata["shared"] = llvmir_module.operation.attributes["triton_tsm.spm_use"].value

# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
Expand Down Expand Up @@ -244,17 +260,23 @@ def _llir_to_bin(llir: str, metadata):
matches = re.findall(pattern, llir)
assert len(matches) == 1
metadata["name"] = matches[0]
# Build kernel for simulator and hardware
sim_mode = os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "kernel.ll")
# FIXME: Hardcoded path
#dst_path = os.path.join(tmpdir, "kernel.so")
dst_path = "/tmp/kernel.o"
Path(src_path).write_text(llir)
clang_path = _get_llvm_bin_path("clang++")
subprocess.check_call([
clang_path, src_path, "-O2", "-c", "-fPIC", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-o",
dst_path
])

compile_args = [clang_path, src_path, "-O2", "-c", "-fPIC", "-o", dst_path]

# Add RISC-V specific flags when not in simulation mode
if not sim_mode:
compile_args.extend(["--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc"])

subprocess.check_call(compile_args)

_dump_ir_if_needed([dst_path])

Expand All @@ -269,6 +291,10 @@ class TXDAOptions:
num_warps: int = 0
num_ctas: int = 0
num_stages: int = 1
num_buffers_warp_spec: int = 0
num_consumer_groups: int = 0
reg_dec_producer: int = 0
reg_inc_consumer: int = 0
enable_warp_specialization: bool = False
enable_fp_fusion: bool = False
extern_libs = None
Expand Down Expand Up @@ -313,6 +339,7 @@ def pack_metadata(self, metadata):
metadata.cluster_dims[1], metadata.cluster_dims[2], metadata.name)

# Our compilation pipeline isn't in python like nvidia or amd, no need to load
# dialects. See `ztc.cc`
def load_dialects(self, ctx):
return

Expand Down
Loading