diff --git a/.github/workflows/tsingmicro-build-and-test.yml b/.github/workflows/tsingmicro-build-and-test.yml new file mode 100644 index 000000000..9323774a9 --- /dev/null +++ b/.github/workflows/tsingmicro-build-and-test.yml @@ -0,0 +1,60 @@ +name: Tsingmicro-Build-And-Test + +on: + push: + branches: [ "triton_v3.3.x" ] + pull_request: + branches: [ "triton_v3.3.x" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + tsingmicro-build-and-test: + runs-on: tsingmicro + steps: + - name: Checkout code (attempt 1) + id: checkout1 + uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before checkout2 + if: steps.checkout1.outcome == 'failure' + run: | + echo "First checkout attempt failed. Sleeping for 120 seconds before retry..." + sleep 120 + + - name: Checkout code (attempt 2) + id: checkout2 + if: steps.checkout1.outcome == 'failure' + uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before final checkout + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + run: | + echo "Second checkout attempt failed. Sleeping for 180 seconds before final retry..." + sleep 180 + + - name: Checkout code (final attempt) + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + uses: actions/checkout@v4 + + - name: Verify checkout success + if: success() + run: echo "Checkout completed successfully" + + - name: FlagTree Build on Tsingmicro + shell: bash + run: | + source ~/env.sh + export FLAGTREE_BACKEND=tsingmicro + cd python + python3.10 -m pip install . --no-build-isolation -v + + - name: FlagTree Test on Tsingmicro + shell: bash + run: | + source ~/env.sh + python3.10 -c 'import triton; print(triton.__path__)' diff --git a/.gitignore b/.gitignore index dd917eb2f..e15c39218 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,8 @@ third_party/cambricon/ third_party/iluvatar/iluvatarTritonPlugin.so third_party/triton_shared/ third_party/xpu/backend/xpu3 +third_party/tsingmicro/backend/lib +third_party/tsingmicro/backend/bin # Proton python/triton/profiler diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e6838c340..9fec66392 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,25 +39,6 @@ repos: hooks: - id: clang-format - # Expand YAML anchors in files used by github workflows, because github can't - # do this itself. This lets us use anchors, which avoids code duplication. - - repo: local - hooks: - - id: expand-yaml-anchors - name: Expand YAML anchors - language: golang - additional_dependencies: [github.com/mikefarah/yq/v4@latest] - entry: > - bash -c ' - OUT=".github/workflows/integration-tests.yml" - IN="$OUT.in" - echo "# AUTOGENERATED by pre-commit, modify the .in file instead." 
> "$OUT" && - echo >> "$OUT" - yq "explode(.)" "$IN" >> "$OUT" - ' - files: ^.github/workflows/integration-tests.yml.* - pass_filenames: false - exclude: | (?x)( ^include/triton/external/| diff --git a/CMakeLists.txt b/CMakeLists.txt index ea64b2752..1476127de 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,9 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads") elseif(FLAGTREE_BACKEND STREQUAL "aipu") add_definitions(-D__NVIDIA__) add_definitions(-D__AMD__) +elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro") + set(CMAKE_C_COMPILER clang) + set(CMAKE_CXX_COMPILER clang++) endif() set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}") if(FLAGTREE_PLUGIN) @@ -204,7 +207,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party) include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files # link_directories(${LLVM_LIBRARY_DIR}) -if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu)$") +if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu|tsingmicro)$") include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files add_subdirectory(include) @@ -364,6 +367,12 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMXPUCodeGen LLVMXPUAsmParser ) + elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro") + list(APPEND TRITON_LIBRARIES + # riscv + LLVMRISCVCodeGen + LLVMRISCVAsmParser + ) endif() if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 @@ -446,7 +455,7 @@ find_package(Threads REQUIRED) add_subdirectory(third_party/f2reduce) -if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND STREQUAL "aipu") +if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND MATCHES "^(aipu|tsingmicro)$") add_subdirectory(bin) add_subdirectory(test) endif() diff --git a/README.md b/README.md index a5f354c6c..68aa9d348 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,16 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=mthreads python3 -m pip install . --no-build-isolation -v ``` +```shell +# tsingmicro +# Recommended: Use Ubuntu 20.04 +mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro +wget https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/ +git checkout -b triton_v3.3.x origin/triton_v3.3.x +export FLAGTREE_BACKEND=tsingmicro +python3 -m pip install . --no-build-isolation -v +``` To build with default backends (nvidia, amd, triton_shared): ```shell diff --git a/README_cn.md b/README_cn.md index e205ad49b..11d9b2358 100644 --- a/README_cn.md +++ b/README_cn.md @@ -51,6 +51,16 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=mthreads python3 -m pip install . --no-build-isolation -v ``` +```shell +# tsingmicro +# 推荐使用镜像 Ubuntu 20.04 +mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro +wget https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/ +git checkout -b triton_v3.3.x origin/triton_v3.3.x +export FLAGTREE_BACKEND=tsingmicro +python3 -m pip install . 
--no-build-isolation -v +``` 使用默认的编译命令,可以编译安装 nvidia、amd、triton_shared 后端: ```shell diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 04d24dff8..3329aada1 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -69,8 +69,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::gpu::registerAllocateSharedMemoryPass(); mlir::triton::gpu::registerTritonGPUAllocateWarpGroups(); mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass(); - mlir::triton::registerConvertWarpSpecializeToLLVM(); #ifdef __NVIDIA__ + mlir::triton::registerConvertWarpSpecializeToLLVM(); mlir::triton::registerConvertTritonGPUToLLVMPass(); mlir::triton::registerConvertNVGPUToLLVMPass(); #endif diff --git a/python/setup.py b/python/setup.py index c9e623f9b..a84b896fd 100644 --- a/python/setup.py +++ b/python/setup.py @@ -432,6 +432,7 @@ def build_extension(self, ext): thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()]) thirdparty_cmake_args += self.get_pybind11_cmake_args() extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) + ext_base_dir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) # create build directories if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) @@ -471,6 +472,7 @@ def build_extension(self, ext): "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", + f"-DCMAKE_INSTALL_PREFIX={ext_base_dir}", ] # Note that asan doesn't work with binaries that use the GPU, so this is @@ -512,6 +514,7 @@ def build_extension(self, ext): subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) + subprocess.check_call(["cmake", "--install", "."], cwd=cmake_dir) nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json") @@ -597,7 +600,7 @@ def build_extension(self, ext): ) if helper.flagtree_backend: - if helper.flagtree_backend == "aipu": + if helper.flagtree_backend in ("aipu", "tsingmicro"): backends = [ *BackendInstaller.copy(helper.default_backends + helper.extend_backends), *BackendInstaller.copy_externals(), diff --git a/python/setup_helper.py b/python/setup_helper.py index fc99295fb..fc01e7627 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -10,13 +10,13 @@ import hashlib from dataclasses import dataclass +flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() +flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() use_triton_shared = False -necessary_third_party = ["flir"] +necessary_third_party = ["" if flagtree_backend == "tsingmicro" else "flir"] default_backends = ["nvidia", "amd"] extend_backends = [] ext_sourcedir = "triton/_C/" -flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() -flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() @dataclass @@ -39,7 +39,6 @@ class FlagTreeBackend: } set_llvm_env = lambda path: set_env({ - 'LLVM_BUILD_DIR': path, 'LLVM_INCLUDE_DIRS': Path(path) / "include", 'LLVM_LIBRARY_DIR': Path(path) / "lib", 'LLVM_SYSPATH': path, @@ -239,7 +238,7 @@ def skip_package_dir(package): @staticmethod def get_package_dir(packages): package_dict = {} - if flagtree_backend and flagtree_backend not in ("cambricon", "aipu"): + if flagtree_backend and flagtree_backend not in ("cambricon", 
"aipu", "tsingmicro"): connection = [] backend_triton_path = f"../third_party/{flagtree_backend}/python/" for package in packages: @@ -284,7 +283,8 @@ def git_clone(lib, lib_path): "so we couldn't compile triton_shared\n") third_partys = [] - third_partys.append(flagtree_backend_info["flir"]) + if flagtree_backend != "tsingmicro": + third_partys.append(flagtree_backend_info["flir"]) if os.environ.get("USE_TRITON_SHARED", "ON") == "ON": third_partys.append(flagtree_backend_info["triton_shared"]) else: @@ -305,9 +305,10 @@ def handle_flagtree_backend(): if flagtree_backend: print(f"flagtree_backend is {flagtree_backend}") extend_backends.append(flagtree_backend) - if "editable_wheel" in sys.argv and flagtree_backend != "aipu": + if "editable_wheel" in sys.argv and flagtree_backend not in ("aipu", "tsingmicro"): ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" - default_backends.append("flir") + if flagtree_backend != "tsingmicro": + default_backends.append("flir") if use_triton_shared: default_backends.append("triton_shared") @@ -335,7 +336,7 @@ def check_env(env_val): file="iluvatar-llvm18-x86_64", condition=("iluvatar" == flagtree_backend), url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatar-llvm18-x86_64.tar.gz", - pre_hock=lambda: check_env('LLVM_BUILD_DIR'), + pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) @@ -344,7 +345,7 @@ def check_env(env_val): file="XTDK-llvm18-ubuntu2004_x86_64", condition=("xpu" == flagtree_backend), url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/XTDK-llvm18-ubuntu2004_x86_64.tar", - pre_hock=lambda: check_env('LLVM_BUILD_DIR'), + pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) @@ -355,10 +356,10 @@ def check_env(env_val): cache.store( files=("clang", "xpu-xxd", "xpu3-crt.xpu", "xpu-kernel.t", "ld.lld", "llvm-readelf", "llvm-objdump", "llvm-objcopy"), condition=("xpu" == flagtree_backend), - copy_src_path=f"{os.environ.get('LLVM_BUILD_DIR','')}/bin", copy_dst_path="third_party/xpu/backend/xpu3/bin") + copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/bin", copy_dst_path="third_party/xpu/backend/xpu3/bin") cache.store(files=("libclang_rt.builtins-xpu3.a", "libclang_rt.builtins-xpu3s.a"), - condition=("xpu" == flagtree_backend), copy_src_path=f"{os.environ.get('LLVM_BUILD_DIR','')}/lib/linux", + condition=("xpu" == flagtree_backend), copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/lib/linux", copy_dst_path="third_party/xpu/backend/xpu3/lib/linux") cache.store(files=("include", "so"), condition=("xpu" == flagtree_backend), @@ -370,6 +371,16 @@ def check_env(env_val): condition=("mthreads" == flagtree_backend), url= "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64.tar.gz", - pre_hock=lambda: check_env('LLVM_BUILD_DIR'), + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# tsingmicro +cache.store( + file="tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64", + condition=("tsingmicro" == flagtree_backend), + url= + "https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) diff --git a/third_party/tsingmicro/CMakeLists.txt b/third_party/tsingmicro/CMakeLists.txt new file mode 100644 index 000000000..ef84b07e7 --- /dev/null +++ b/third_party/tsingmicro/CMakeLists.txt @@ -0,0 
+1,31 @@ +if(NOT DEFINED TX8_HOME) + if(DEFINED ENV{TX8_HOME}) + set(TX8_HOME $ENV{TX8_HOME}) + else() + message(FATAL_ERROR "TX8_HOME environment variable is not defined") + endif() +endif() + +set(XUANTIE_NAME Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/crt/include) +include_directories(${TX8_HOME}/include) +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(bin) +add_subdirectory(crt) +if(TRITON_BUILD_PYTHON_MODULE) + # FIXME: Unify the libraries for TsingMicro into fewer ones + add_triton_plugin(TritonTsingMicro ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_tsingmicro.cc + LINK_LIBS ZTCAnalysis ZTCAnalysisStructured MagicKernelIR + Tx81IR TritonTilingExtIR TritonStructuredIR TritonToCoreDialects + TritonToLinalg TritonToStructured StructuredToMemref LinalgToMagicKernel + TritonArithToLinalg CoreDialectsToMK Tx81ToLLVM Tx81MemrefToLLVM MKToTx81) + target_link_libraries(TritonTsingMicro PRIVATE Python3::Module pybind11::headers) +endif() +#if(TRITON_BUILD_UT) +# add_subdirectory(unittest) +#endif() +#add_subdirectory(test) diff --git a/third_party/tsingmicro/backend/compiler.py b/third_party/tsingmicro/backend/compiler.py new file mode 100644 index 000000000..5a60dd12d --- /dev/null +++ b/third_party/tsingmicro/backend/compiler.py @@ -0,0 +1,347 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes +from triton.runtime.cache import get_cache_manager +from dataclasses import dataclass +from typing import Any, Dict, Tuple +from types import ModuleType +import hashlib +import tempfile +import os +import re +import shutil +import subprocess +import functools +from pathlib import Path + + +def _get_tsm_opt_path() -> str: + return os.path.join(os.path.dirname(__file__), "bin", "tsingmicro-opt") + + +def _get_llvm_bin_path(bin_name: str) -> str: + path = os.getenv("LLVM_SYSPATH", "") + if path == "": + raise Exception("LLVM_SYSPATH is not set.") + return os.path.join(path, "bin", bin_name) + + +def _get_tx8_path(sub_name: str) -> str: + path = os.getenv("TX8_HOME", "") + if path == "": + raise Exception("TX8_HOME is not set.") + return os.path.join(path, sub_name) + + +def _dump_ir_if_needed(files): + path = os.getenv("TRITON_DUMP_PATH", "") + if not path: + return + + os.makedirs(path, exist_ok=True) + for f in files: + shutil.copy(f, os.path.join(path, os.path.basename(f))) + + +# Build a accelerator controller ELF +def compile_accelerator(): + # TODO : cache mechanism + # name = "npu_" + name + # key = hashlib.sha256(src.encode("utf-8")).hexdigest() + # cache = get_cache_manager(key) + # cache_path = cache.get_file(f"{name}.so") + + # if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + # FIXME: Hardcoded path + #dst_path = os.path.join(tmpdir, f"{name}.so") + dst_path = "/tmp/kernel.so" + libc_lib = os.path.join(_get_tx8_path("Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2"), "riscv64-unknown-elf", + "lib", "rv64imafdc", "lp64d") + libvr_path = os.path.join(os.path.dirname(__file__), "lib") + clang_path = _get_llvm_bin_path("clang") + lld_path = _get_llvm_bin_path("ld.lld") + tx8_lib = _get_tx8_path("lib") + subprocess.check_call([ + clang_path, "-shared", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-O2", + f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles", "-Wl,--allow-shlib-undefined", "-mabi=lp64d", + 
"-Wl,--no-dynamic-linker", + # FIXME: Hardcoded path + "/tmp/kernel.o", f"-L{libvr_path}", f"-L{libc_lib}", f"-L{tx8_lib}", "-Wl,--whole-archive", + "-linstr_tx81", # Tx81 intrinsic API + "-lvr", # Wrapper API of Tx81 intrinsic + "-Wl,--no-whole-archive", "-lm", "-o", dst_path + ]) + + _dump_ir_if_needed([dst_path]) + with open(dst_path, 'rb') as f: + so = f.read() + return so + + +def _ttir_to_coreir(mod): + # Get Triton-MLIR as string + ttir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "tt.mlir") + dst_path = os.path.join(tmpdir, "core.mlir") + Path(src_path).write_text(ttir_code) + tsm_opt_path = _get_tsm_opt_path() + _dump_ir_if_needed([src_path]) + subprocess.check_call([ + tsm_opt_path, src_path, "--triton-to-core-dialects", "--one-shot-bufferize=allow-return-allocs-from-loops", + #"--mlir-print-debuginfo", + "-o", dst_path + ]) + return Path(dst_path).read_text() + + +def _optimize_coreir(coreir: str): + # We don't apply any optimizations now, but we can add passes if needed. + return coreir + + +def _coreir_to_mkir(mod): + # Get core dialects as string + coreir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "core.mlir") + dst_path = os.path.join(tmpdir, "mk.mlir") + Path(src_path).write_text(coreir_code) + tsm_opt_path = _get_tsm_opt_path() + _dump_ir_if_needed([src_path]) + subprocess.check_call([ + tsm_opt_path, src_path, "--core-dialects-to-mk", + #"--mlir-print-debuginfo", + "-o", dst_path + ]) + return Path(dst_path).read_text() + + +def _optimize_mkir(mkir: str): + # We don't apply any optimizations now, but we can add passes if needed. + return mkir + + +def _coreir_to_txir(mod): + # Get core dialects as string + coreir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "core.mlir") + dst_path = os.path.join(tmpdir, "tx.mlir") + Path(src_path).write_text(coreir_code) + tsm_opt_path = _get_tsm_opt_path() + _dump_ir_if_needed([src_path]) + subprocess.check_call([ + tsm_opt_path, src_path, "--expand-strided-metadata", + "--lower-affine", # convert affine.load to memref.load, need exec before tx81-to-llvm since we will support spm offset to memref.load + "--mk-to-tx81", "--cse", # unused memref.subview/memref.reinterpret + #"--mlir-print-debuginfo", + "-o", dst_path + ]) + return Path(dst_path).read_text() + + +def _optimize_txir(txir: str): + # We don't apply any optimizations now, but we can add passes if needed. + return txir + + +def _txir_to_llir(mod, metadata): + txir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "tx.mlir") + llvmir_path = os.path.join(tmpdir, "ll.mlir") + llir_path = os.path.join(tmpdir, "ll.ir") + Path(src_path).write_text(txir_code) + tsm_opt_path = _get_tsm_opt_path() + _dump_ir_if_needed([src_path]) + # Tx81 and core dialects to LLVM-MLIR + args = [ + tsm_opt_path, src_path, + # Use tx81-memref-to-llvm to replace "--finalize-memref-to-llvm". + "--tx81-memref-to-llvm", "--convert-scf-to-cf", "--convert-math-to-llvm", + "--convert-cf-to-llvm", # need exec before "convert-func-to-llvm" + "--convert-func-to-llvm", # need exec before "kernel-arg-buffer", otherwise un-rank memref will translate to int(rank) + ptr + ] + + args.append( + "--kernel-arg-buffer" + ) # need exec before "tx81-to-llvm" which will declare other func. 
We want only replace the triton kernel + + # other pass + args += [ + "--tx81-to-llvm", "--convert-arith-to-llvm", # need exec last since arith.const conversion + # Remove all unrealized casts created + "--reconcile-unrealized-casts", "--canonicalize", + #"--mlir-print-debuginfo", + "-o", llvmir_path + ] + + subprocess.check_call(args) + + _dump_ir_if_needed([llvmir_path]) + + llvm_file = os.getenv("CUSTOMIZED_IR", "") + if (llvm_file != ""): + llvmir_path = os.getenv("TRITON_DUMP_PATH", "") + + if not llvmir_path: + return + + llvmir_path = os.path.join(llvmir_path, llvm_file) + + # LLVM-MLIR to LLVM-IR + mlir_translate_path = _get_llvm_bin_path("mlir-translate") + subprocess.check_call([mlir_translate_path, llvmir_path, "--mlir-to-llvmir", "-o", llir_path]) + + _dump_ir_if_needed([llir_path]) + return Path(llir_path).read_text() + + +def _mkir_to_llir(mkir: str): + with tempfile.TemporaryDirectory() as tmpdir: + mkir_path = os.path.join(tmpdir, "mk.mlir") + llvmir_path = os.path.join(tmpdir, "ll.mlir") + llir_path = os.path.join(tmpdir, "ll.ir") + Path(mkir_path).write_text(mkir) + mlir_opt_path = _get_llvm_bin_path("mlir-opt") + # MagicKernel-MLIR to LLVM-MLIR + subprocess.check_call([ + mlir_opt_path, mkir_path, "--convert-linalg-to-affine-loops", + # Note: eliminate-empty-tensors fails when there are multiple func.return ops + # in a single kernel which are the results of early returns. + # See python/examples/test_early_return.py for examples. + # We disable this pass for now since performance on CPU isn't the main + # focus at the moment. + # "--eliminate-empty-tensors", + "--empty-tensor-to-alloc-tensor", "--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--lower-affine", "--convert-linalg-to-loops", "--expand-strided-metadata", "--convert-scf-to-cf", + "--convert-arith-to-llvm", "--convert-math-to-llvm", "--convert-complex-to-llvm", + "--convert-vector-to-llvm", "--convert-index-to-llvm", "--memref-expand", "--finalize-memref-to-llvm", + "--convert-func-to-llvm", "--convert-cf-to-llvm", + # Lowering memrefs creates more affine.apply ops. + # Lowering these affine ops again creates further arith ops, + # so we have to run these two passes again here. + "--lower-affine", "--convert-arith-to-llvm", + # Remove all unrealized casts created + "--canonicalize", "--reconcile-unrealized-casts", "--mlir-print-debuginfo", "-o", llvmir_path + ]) + + # LLVM-MLIR to LLVM-IR + mlir_translate_path = _get_llvm_bin_path("mlir-translate") + subprocess.check_call([mlir_translate_path, llvmir_path, "--mlir-to-llvmir", "-o", llir_path]) + _dump_ir_if_needed([mkir_path, llvmir_path, llir_path]) + return Path(llir_path).read_text() + + +def _optimize_llir(llir: str): + # We don't apply any optimizations now, but we can add passes if needed. 
+ return llir + + +def _llir_to_bin(llir: str, metadata): + pattern = r"define void @(\w+)\(.+" + matches = re.findall(pattern, llir) + assert len(matches) == 1 + metadata["name"] = matches[0] + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ll") + # FIXME: Hardcoded path + #dst_path = os.path.join(tmpdir, "kernel.so") + dst_path = "/tmp/kernel.o" + Path(src_path).write_text(llir) + clang_path = _get_llvm_bin_path("clang++") + subprocess.check_call([ + clang_path, src_path, "-O2", "-c", "-fPIC", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-o", + dst_path + ]) + + _dump_ir_if_needed([dst_path]) + + # compile kernel and intrinsic wrapper to shared library + return compile_accelerator() + + +@dataclass(frozen=True) +class TXDAOptions: + debug: bool = False + arch: str = None + num_warps: int = 0 + num_ctas: int = 0 + num_stages: int = 1 + enable_warp_specialization: bool = False + enable_fp_fusion: bool = False + extern_libs = None + cluster_dims: tuple = (1, 1, 1) + shared: bool = False + allow_fp8e4nv: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) + sanitize_overflow: bool = True + + def __post_init__(self): + pass + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + + +class TXDABackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'txda' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.binary_ext = "so" + + def parse_options(self, opts) -> Any: + args = {'arch': self.target.arch} + args.update({k: opts[k] for k in TXDAOptions.__dataclass_fields__.keys() if k in opts}) + return TXDAOptions(**args) + + def get_codegen_implementation(self, options): + codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)} + return codegen_fns + + def pack_metadata(self, metadata): + # Note: We actually don't need any of these except for the name which is + # used in the launch function in driver.py. 
Putting these in so we're + # consistent with other backends + return (metadata.num_warps, metadata.num_ctas, metadata.shared, metadata.cluster_dims[0], + metadata.cluster_dims[1], metadata.cluster_dims[2], metadata.name) + + # Our compilation pipeline isn't in python like nvidia or amd, no need to load + def load_dialects(self, ctx): + return + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["coreir"] = lambda src, metadata: _optimize_coreir(_ttir_to_coreir(src)) + # stages["mkir"] = lambda src, metadata: _optimize_mkir(_coreir_to_mkir(src)) + stages["txir"] = lambda src, metadata: _optimize_txir(_coreir_to_txir(src)) + stages["llir"] = lambda src, metadata: _optimize_llir(_txir_to_llir(src, metadata)) + stages["so"] = lambda src, metadata: _llir_to_bin(src, metadata) + + @functools.lru_cache() + def hash(self): + return self.target + + # The CPU backend does not use any extra python modules, return an empty dictionary + def get_module_map(self) -> Dict[str, ModuleType]: + return {} diff --git a/third_party/tsingmicro/backend/driver.cpp b/third_party/tsingmicro/backend/driver.cpp new file mode 100644 index 000000000..a9af39cfc --- /dev/null +++ b/third_party/tsingmicro/backend/driver.cpp @@ -0,0 +1,74 @@ +//===---------------------------- driver.c --------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Tx81 platform device side runtime interface for python. 
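+// Note (summary of the code below): this extension module (tx81_utils) only
+// needs to expose load_binary and get_device_properties; driver.py compiles it
+// on the fly via compile_native(), and the values returned here are
+// placeholders rather than real device queries.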
+// +//===----------------------------------------------------------------------===// +#include +#include +#include +#include +#define PY_SSIZE_T_CLEAN +#include +#include +#include +#include + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + // Extract device properties + // Note: We're mapping Tx81 properties to fields expected by Triton + int max_shared_mem = 1024 * 1024 * 3; // Default 3MB + // int multiprocessor_count = device->tile_num; + int multiprocessor_count = 1; + int sm_clock_rate = 1000; // Placeholder + int mem_clock_rate = 2000; // Placeholder + int mem_bus_width = 256; // Placeholder + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "multiprocessor_count", + multiprocessor_count, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + + int32_t n_regs = 256; + int32_t n_spills = 0; + // Return values to Python including module, function, n_regs, n_spills + return Py_BuildValue("(KKii)", "module {}", "void @add_kernel() {}", n_regs, + n_spills); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided binary into Tx81 driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given Tx81 device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "tx81_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_tx81_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/tsingmicro/backend/driver.py b/third_party/tsingmicro/backend/driver.py new file mode 100644 index 000000000..dc646291b --- /dev/null +++ b/third_party/tsingmicro/backend/driver.py @@ -0,0 +1,745 @@ +# +# This file implements the triton kernel driver interfaces where are used in +# triton/python/triton/compiler/compiler.py. +# For how the interface in driver class is used, see the implementation of the +# file above. 
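+#
+# Overview (summary of the code below): TXDADriver wires Triton to the Tx81
+# runtime. It builds two native extensions with compile_native(): tx81_utils
+# (from driver.cpp, providing load_binary / get_device_properties) and
+# __triton_launcher (generated by make_launcher(), which copies tensor
+# arguments to the device and dispatches the kernel via TsmKernelLaunch).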
+# +import hashlib +import tempfile +import os +import subprocess +import importlib.util +import shutil +import sysconfig +import atexit +from pathlib import Path +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import GPUDriver +from triton.backends.compiler import GPUTarget + + +def extend_torch(): + import torch + from torch.utils import cpp_extension, rename_privateuse1_backend, generate_methods_for_privateuse1_backend + module = cpp_extension.load( + name="txda", + sources=[os.path.dirname(__file__) + "/txda_device.cpp"], + #runtime include path + extra_include_paths=[""], + #runtime *.so path + extra_ldflags=[""], + extra_cflags=["-g"], + verbose=True, + ) + torch.utils.rename_privateuse1_backend("txda") + torch._register_device_module("txda", module) + generate_methods_for_privateuse1_backend(for_storage=True) + + +def _get_tx8_path(bin_name: str) -> str: + path = os.getenv("TX8_HOME", "") + if path == "": + raise Exception("TX8_HOME is not set.") + return os.path.join(path, bin_name) + + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dirs = [ + os.path.join(dirname, "include"), + os.path.realpath(_get_tx8_path("include")), + os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"), + os.path.join(sysconfig.get_path('platlib'), "torch", "include"), + os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"), + os.path.join(sysconfig.get_path('platlib'), "numpy", "_core", "include") +] +library_dirs = [ + os.path.join(dirname, "lib"), + os.path.realpath(_get_tx8_path("lib")), + os.path.join(sysconfig.get_path('platlib'), "torch", "lib") +] +libraries = ['tx8_runtime', 'torch', 'torch_cpu', 'torch_python', 'c10'] + + +def _dump_ir_if_needed(files): + path = os.getenv("TRITON_DUMP_PATH", "") + if not path: + return + + os.makedirs(path, exist_ok=True) + for f in files: + shutil.copy(f, os.path.join(path, os.path.basename(f))) + + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + cc = clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. 
+ if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) + include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-std=c++17", "-Wno-psabi", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + cc_cmd += [f"-Wl,-rpath,{dir}" for dir in library_dirs] + subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + return so + + +# Build a native ELF on the platform running this python script +def compile_native(src, name): + fname = "native_" + name + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{fname}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, f"{name}.cpp") + with open(src_path, "w") as f: + f.write(src) + _dump_ir_if_needed([src_path]) + so = _build(name, src_path, tmpdir, library_dirs, include_dirs, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{fname}.so", binary=True) + _dump_ir_if_needed([cache_path]) + + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# -------------------- Launcher ---------------------------- +def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def _extracted_type(ty): + if isinstance(ty, tuple): + val = ','.join(map(_extracted_type, ty)) + return f"[{val}]" + if ty[0] == '*': + return "PyObject*" + if ty == "constexpr": + return "PyObject*" + return _ty_to_cpp(ty) + + +def _format_of(ty): + if isinstance(ty, tuple): + val = ''.join(map(format_of, ty)) + return f"({val})" + if ty[0] == '*': + return "O" + if ty in ("constexpr", "nvTmaDesc"): + return "O" + return { + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "L", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[_ty_to_cpp(ty)] + + +def make_launcher(constants, signature, kernel_name): + # Basic declarations + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") + args_format = ''.join([_format_of(ty) for ty in signature.values()]) + format = "iiiOOOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # Parameters to pass to the kernel function + kernel_parameters = ', '.join( + f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"tx81_ptr{i}, &ptr_arg{i}" + for i, ty in signature.items() + if ty != "constexpr") + kernel_parameters += ', ' if kernel_parameters else '' + + return f""" +#include +#include +#include +#include +#include +#include +#include +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +//#include +#include +#include +#include +#include +#include 
+#include "hrt_interface.h" +#include "hrt_common.h" + +// The design of kernel argument buffer: +// The offset starts from the whole kernel buffer +// +------------------------------------------------------------------------+ +// | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 8 bytes | 8 bytes | +// | No. kargs | gridX | gridY | gridZ | karg1 offset | karg2 offset | +// +------------------------------------------------------------------------+ +// .......................... Metadata buffer................................ +// +// +------------------------------------------------------------------------+ +// | ... | 8 bytes | n bytes | n bytes | | n bytes | +// | ... | kargn offset | karg1 data | karg2 data | ... | kargn data | +// +------------------------------------------------------------------------+ +// ^ ^ ^ +// karg1 offset karg2 offset kargn offset +// ... Metadata buffer... | ............ kernel arg buffer .................. + +enum DATA_TYPE {{ + SCALAR, + POINT, +}}; + +// A kernel argument +struct KernelArg {{ + // The actual kernel argument: tensor or scalar + union Data {{ + void* ptr; // Pointer to the tensor data + uint64_t scalar; // Scalar data + }} data; + size_t size; // The size of the kernel argument + int data_type; + + KernelArg(void *ptr, size_t s) : size(s) {{ + data.ptr = ptr; + data_type = POINT; + }} + + KernelArg(uint64_t v, size_t s) : size(0) {{ + data.scalar = v; + data_type = SCALAR; + }} + +}}; + +extern "C" {{ + // The kernel arguments includes: + // 1. The actual kernel argument in arg_decls + // 2. The group size: gridX, gridY, gridZ + // 3 The thread id in each direction: x, y, z + void {kernel_name}({arg_decls}, int, int, int, int, int, int); +}} + +// Global device vector +static std::vector g_tx81_devices; +static bool g_runtime_initialized = false; + +// FIXME: Hardcoded path +std::string chip_out = "/tmp/chip_out/node0/"; +std::string kernel_file = "/tmp/kernel.so"; +std::string kernel_fun_name = "{kernel_name}"; +uint32_t sharedMemBytes = 0; + +typedef void* Stream_t; + +static uint64_t get_phy_addr(uint64_t logic_addr) {{ + uint32_t card_id; + uint64_t addr; + uint64_t size; + TsmMemGetInfo(logic_addr, card_id, addr, size); + return addr; +}} + + +// Initialize Tx81 runtime +bool init_tx81_runtime() {{ + if (g_runtime_initialized) {{ + return true; // Already initialized + }} + + // Initialize the Tx81 runtime + if (TsmInitRuntime() != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 runtime"); + return false; + }} + + // Get device count + uint32_t device_num = 0; + if (TsmGetDeviceNum(device_num) != RET_SUCCESS || device_num == 0) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to get Tx81 device count or no devices found"); + TsmDeInitRuntime(); + return false; + }} + + // FIXME: Hardcoded + // Set up devices - for simplicity, we're using a 1x1 configuration + uint32_t first_phy_id = 0; + uint32_t card_x = 1; + uint32_t card_y = 1; + + TsmDevice *dev = new TsmDevice(); + if (TsmSetDevice(&dev, 0, first_phy_id) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices"); + TsmDeInitRuntime(); + return false; + }} + g_tx81_devices.push_back(dev); + + // FIXME: Hardcoded + TsmModel *new_model = new TsmModel(); + // Create a vector of models + std::vector kmodel_vec = {{new_model}}; + std::string option = "-O2"; + CompileOption compl_option = {{}}; + compl_option.comp_enable = 0; // Use prebuilt binary + compl_option.chip_x = 1; //单卡 + compl_option.chip_y = 1; + compl_option.check_enable = 
true; + compl_option.enable_kcore_bin = 1; + compl_option.enable_kcore_so = 1; + new_model->case_dir = chip_out; + + for (TsmDevice * dev : g_tx81_devices) {{ + if (TsmCompileMultiGraph(dev, *new_model, option, compl_option) != RET_SUCCESS) {{ + for (uint32_t dev_index = 0; dev_index < g_tx81_devices.size(); ++dev_index) {{ + if (TsmResetDevice(g_tx81_devices[dev_index]) != RET_SUCCESS) {{ + return false; + }} + }} + return false; + }} + }} + + // Initialize all devices + for (auto* dev : g_tx81_devices) {{ + if (TsmLaunch(dev, *new_model) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "[Chip id] TsmLaunch failed."); + TsmReleaseDevice(dev); + TsmResetDevice(dev); + return false; + }} + + if (TsmSetMonitorInfo(dev) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "[Chip id] TsmLaunch failed."); + TsmReleaseDevice(dev); + TsmResetDevice(dev); + return false; + }} + }} + + delete new_model; + g_runtime_initialized = true; + + return true; +}} + +// Clean up Tx81 runtime resources +static PyObject* cleanup_tx81_runtime(PyObject* self, PyObject* args) {{ + if (!g_runtime_initialized) {{ + Py_RETURN_NONE; + }} + + for (auto* dev : g_tx81_devices) {{ + if (TsmSetTerminate(dev) != RET_SUCCESS) {{ + Py_RETURN_NONE; + }} + // Reset and release each device + TsmReleaseDevice(dev); + TsmResetDevice(dev); + delete dev; + }} + g_tx81_devices.clear(); + TsmDeInitRuntime(); + g_runtime_initialized = false; + Py_RETURN_NONE; +}} + +TSM_RETCODE argsToDevMemArray(TsmDevice *dev, std::vector &kargs, + std::vector &rtKargs, std::vector &devAddrs) {{ + int count = 0; + for (KernelArg& karg : kargs) {{ + if (karg.data_type == POINT) {{ + TsmDevicePtr dev_buffer; + if (TsmDeviceMalloc(dev, dev_buffer, karg.size) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmDeviceMalloc"); + return RET_ERROR; + }} + + if (TsmMemcpyH2D(dev_buffer, karg.data.ptr, karg.size) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyH2D"); + return RET_ERROR; + }} + devAddrs.push_back(dev_buffer); + // FIXME: rank + rtKargs.push_back(1); + rtKargs.push_back(get_phy_addr(dev_buffer)); + + count++; + }} + else {{ + rtKargs.push_back(karg.data.scalar); + }} + }} + return RET_SUCCESS; +}} + +TSM_RETCODE devMemArrayToArgs(TsmDevice *dev, std::vector &kargs, + std::vector &devAddrs) {{ + + int count = 0; + for (KernelArg& karg : kargs) {{ + if (karg.data_type == POINT) {{ + uint64_t dev_buffer = devAddrs[count++]; + if (TsmMemcpyD2H(karg.data.ptr, dev_buffer, karg.size) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyH2D"); + return RET_ERROR; + }} + }} + }} + return RET_SUCCESS; +}} + +TSM_RETCODE devMemFree(TsmDevice *dev, std::vector &devAddrs) {{ + for (uint64_t dev_buffer : devAddrs) {{ + if (TsmDeviceFree(dev_buffer) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmDeviceFree"); + return RET_ERROR; + }} + }} + return RET_SUCCESS; +}} + +TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) {{ + if (bootpm_dev != 0) {{ + if (TsmDeviceFree(bootpm_dev) != RET_SUCCESS) {{ + return RET_ERROR; + }} + bootpm_dev = 0; + }} + return RET_SUCCESS; +}} + +static void _launch(int gridX, int gridY, int gridZ, std::vector kargs) {{ + std::vector &devices = g_tx81_devices; + + if (gridX*gridY*gridZ <= 0) {{ + return; // No work to do + }} + + // TODO::mv + uint64_t kernel_len = 0; + uint8_t* kernel_ptr = read_file_data(kernel_file, kernel_len); + if (kernel_ptr == nullptr) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed 
to read kernel so"); + TsmDeInitRuntime(); + return; + }} + + // prepare data/ load kernel/run/unload kernel/get out data/release memory + for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {{ + // Allocate the device memory for all kernel arguments + std::vector devAddrs; + std::vector rtKargs; + + if (argsToDevMemArray(devices[dev_index], kargs, rtKargs, devAddrs) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to argsToDevMemArray"); + TsmDeInitRuntime(); + return; + }} + + rtKargs.push_back(gridX); + rtKargs.push_back(gridY); + rtKargs.push_back(gridZ); + rtKargs.push_back(0); + rtKargs.push_back(0); + rtKargs.push_back(0); + + // TSM_RETCODE TsmKernelLaunch(TsmDevice *dev, const char *func_name, void *kernel_ptr, uint32_t kernel_len, + // uint32_t grid_dim, uint32_t block_dim, void *args, uint32_t args_len); + if (TsmKernelLaunch(devices[dev_index], kernel_fun_name.c_str(), (void*)kernel_ptr, kernel_len, + gridX, 1, (void*)(&rtKargs[0]), rtKargs.size()*sizeof(uint64_t)) != RET_SUCCESS){{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmKernelLaunch"); + TsmDeInitRuntime(); + }} + if (devMemArrayToArgs(devices[dev_index], kargs, devAddrs) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to devMemArrayToArgs"); + TsmDeInitRuntime(); + return; + }} + + // getchar(); + + // TsmUnloadKernel(devices[dev_index], kmodel_vec); + + if (devMemFree(devices[dev_index], devAddrs) != RET_SUCCESS) {{ + return; + }} + }} +}} + +// Structure to represent a device pointer +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static size_t getTensorStorageSize(PyObject* tensor_obj) {{ + const at::Tensor& tensor = THPVariable_Unpack(tensor_obj); + return tensor.storage().nbytes(); +}} + +// Extract tensor raw ptr +static void* extractTensor(PyObject* tensor_obj) {{ + const at::Tensor& tensor = THPVariable_Unpack(tensor_obj); + torch::Tensor contiguous_tensor = tensor.contiguous(); + return contiguous_tensor.data_ptr(); +}} + +static PyObject* init_runtime(PyObject* self, PyObject* args) {{ + const char* _chip_out; + if (!PyArg_ParseTuple(args, "s", &_chip_out)) {{ + return NULL; + }} + chip_out = _chip_out; + + // Initialize Tx81 runtime during module import + if (!init_tx81_runtime()) {{ + return NULL; + }} + + return Py_None; +}} + +// Python module launch function +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + // FIXME: Extra 2 args: + PyObject *dummy1 = NULL; + PyObject *dummy2 = NULL; + // Define the actual kernel arguments + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + + // Init kernel arguments from python side + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook, + &dummy1, &dummy2{args_list})) {{ + return NULL; + }} + +#if 0 // FIXME: kernel_metadata is not correctly inited + // Extract metadata for consistency with other drivers + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, "iiiiii", &num_warps, &num_ctas, + &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + }} + + // Call the enter hook if provided + if (launch_enter_hook != 
Py_None) {{ + PyObject* hook_args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, hook_args); + Py_DECREF(hook_args); + if (!ret) + return NULL; + }} +#endif + + // Construct a data kernel arguments list data structure + std::vector kargs; + //{' '.join([f"kargs.emplace_back(_arg{i}, PyObject_Size(_arg{i})*4);" if ty[0]=="*" else f"kargs.emplace_back(_arg{i}, sizeof(_arg{i}));" for i, ty in signature.items() if ty != "constexpr"])} + {' '.join([f"kargs.emplace_back(extractTensor(_arg{i}), getTensorStorageSize(_arg{i}));" + if ty[0]=="*" else f"kargs.emplace_back(_arg{i}, sizeof(_arg{i}));" + for i, ty in signature.items() if ty != "constexpr"])} + + // Launch the kernel + _launch(gridX, gridY, gridZ, kargs); + if (PyErr_Occurred()) {{ + return NULL; + }} + + // Call the exit hook if provided + if (launch_exit_hook != Py_None) {{ + PyObject* hook_args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, hook_args); + Py_DECREF(hook_args); + if (!ret) + return NULL; + }} + + // Return None to Python + Py_INCREF(Py_None); + return Py_None; +}} + +// Python module method definitions +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{"init_runtime", init_runtime, METH_VARARGS, "Init runtime with chip_out dir"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +// Python module definition +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, // documentation + -1, // size + ModuleMethods +}}; + +static PyMethodDef cleanup_method = {{ + "cleanup_tx81_runtime", + (PyCFunction)cleanup_tx81_runtime, + METH_NOARGS, + "Cleanup Tx81 runtime resources" +}}; + +// Python module initialization function +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) {{ + return NULL; + }} + + PyModule_AddFunctions(m, ModuleMethods); + + // Register an atexit handler to cleanup Tx81 runtime + PyObject* atexit_module = PyImport_ImportModule("atexit"); + if (atexit_module) {{ + PyObject* cleanup_func = PyCFunction_New(&cleanup_method, NULL); + if (cleanup_func) {{ + PyObject* result = PyObject_CallMethod(atexit_module, "register", "O", cleanup_func); + Py_XDECREF(result); + Py_DECREF(cleanup_func); + }} + Py_DECREF(atexit_module); + }} + + return m; +}} +""" + + +class TXDAUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(TXDAUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + src = Path(os.path.join(dirname, "driver.cpp")).read_text() + mod = compile_native(src, "tx81_utils") + # # NOTE: The triton compiler.py framework requires these 2 interface. 
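        # load_binary and get_device_properties come from the tx81_utils
        # extension compiled above from driver.cpp; both currently return
        # placeholder values.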
+ self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +# Launch cross compiled runtime program on controller +class TXDALauncher(object): + + def __init__(self, src, metadata): + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + + # Compiler runtime kernel launcher source code + launcher_src = make_launcher(constants, signature, src.fn.__name__) + mod = compile_native(launcher_src, "__triton_launcher") + self.launch = mod.launch + chip_out = os.path.join(_get_tx8_path("chip_out"), "node0") + chip_out = chip_out + os.sep + mod.init_runtime(chip_out) + + def __call__(self, *args, **kwargs): + # args: 0: gridX, 1: gridY, 2: gridZ, + # 3: kernel_metadata?, 4: launch_metadata?, + # 5: a tuple(0, 0, False, 1, 1, 1, 'add_kernel'), # this is probably kernel metadata + # 6: None, 7: None, 8: None, + # 9~N: Actual triton kernel args. + self.launch(*args, **kwargs) + + +class TXDADriver(GPUDriver): + + def __init__(self): + super().__init__() + extend_torch() + self.utils = TXDAUtils() + self.launcher_cls = TXDALauncher + import torch + # Needs to overwrite GPUDriver base methods + self.get_current_stream = torch.txda.current_stream + self.get_current_device = torch.txda.current_device + self.set_current_device = torch.txda.set_device + atexit.register(torch.txda.cleanup_device) + + @staticmethod + def is_active(): + try: + #import torch + #return torch.txda.is_available() + return True + except ImportError: + return False + + def get_current_target(self): + capability = 1 + warp_size = 16 + return GPUTarget("txda", capability, warp_size) + + def get_active_torch_device(self): + import torch + # torch.txda.init_device() + return torch.device("txda", self.get_current_device()) + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_device_interface(self): + import torch + return torch.txda diff --git a/third_party/tsingmicro/backend/name.conf b/third_party/tsingmicro/backend/name.conf new file mode 100644 index 000000000..1340763be --- /dev/null +++ b/third_party/tsingmicro/backend/name.conf @@ -0,0 +1 @@ +tsingmicro diff --git a/third_party/tsingmicro/backend/txda_device.cpp b/third_party/tsingmicro/backend/txda_device.cpp new file mode 100644 index 000000000..ac46d67ac --- /dev/null +++ b/third_party/tsingmicro/backend/txda_device.cpp @@ -0,0 +1,180 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace at { +namespace detail { + +C10_REGISTER_GUARD_IMPL( + PrivateUse1, c10::impl::NoOpDeviceGuardImpl); + +} +} // namespace at + +struct TXDADeviceAllocator final : at::Allocator { + TXDADeviceAllocator() {} + + at::DataPtr allocate(size_t nbytes) override { + void *data = c10::alloc_cpu(nbytes); + return {data, nullptr, &ReportAndDelete, + at::Device(at::DeviceType::PrivateUse1, 0)}; + } + + static void ReportAndDelete(void *ptr) { + if (!ptr) { + return; + } + // TsmDeviceFree((uint64_t)ptr) + c10::free_cpu(ptr); + } + + at::DeleterFnPtr raw_deleter() const override { return &ReportAndDelete; } + void copy_data(void *dest, const void *src, std::size_t count) const final { + default_copy_data(dest, src, count); + } +}; + +// register device allocator +static TXDADeviceAllocator 
global_txda_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_txda_alloc); + +// to.Device +at::Tensor txda_to_device(const at::Tensor &self, at::Device device, + at::ScalarType dtype, bool non_blocking, bool copy, + c10::optional memory_format) { + // TsmMemcpyH2D(); + + TORCH_CHECK(self.is_cpu() || + self.device().type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + TORCH_CHECK(device.is_cpu() || device.type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + // Some dummy asserts for the basic use case: inputs are the same size / + // dtype, all contiguous. + TORCH_CHECK(self.scalar_type() == dtype); + TORCH_CHECK(self.is_contiguous()); + + if (device != at::DeviceType::CPU) { + return at::empty(self.sizes(), self.options()); + } + + auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, + false, memory_format); + memcpy(out.mutable_data_ptr(), self.mutable_data_ptr(), self.nbytes()); + return out; +} + +// _copy_from +at::Tensor txda__copy_from(const at::Tensor &self, const at::Tensor &dst, + bool non_blocking) { + // TsmMemcpyD2H(); + + TORCH_CHECK(self.is_cpu() || + self.device().type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + TORCH_CHECK(dst.is_cpu() || + dst.device().type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + + // Some dummy asserts for the basic use case: inputs are the same size / + // dtype, all contiguous. + TORCH_CHECK(self.sizes() == dst.sizes()); + TORCH_CHECK(self.scalar_type() == dst.scalar_type()); + TORCH_CHECK(self.is_contiguous() && dst.is_contiguous()); + + std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), + self.storage().nbytes()); + return dst; +} + +at::Tensor txda_empty_memory_format( + at::IntArrayRef size, std::optional dtype, + std::optional layout, std::optional device, + std::optional pin_memory, + std::optional memory_format) { + constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic(size, &global_txda_alloc, private_use_ks, + c10::dtype_or_default(dtype), memory_format); +} + +at::Tensor txda_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + + constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); + auto dtype = c10::dtype_or_default(dtype_opt); + return at::detail::empty_strided_generic(size, stride, &global_txda_alloc, + private_use_ks, dtype); +} + +at::Tensor txda_as_strided(const at::Tensor &input, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { + return at::cpu::as_strided(input, size, stride, storage_offset); +} + +at::Tensor &txda_fill__scalar(at::Tensor &self, const at::Scalar &value) { + TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, + "only support txda"); + TORCH_CHECK(self.is_contiguous()); + TORCH_CHECK(self.scalar_type() == c10::ScalarType::Float); + + auto _data = static_cast(self.mutable_data_ptr()); + for (size_t idx = 0; idx < self.numel(); idx++) { + _data[idx] = value.toFloat(); + } + + return self; +} + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("to.Device", &txda_to_device); + m.impl("fill_.Scalar", &txda_fill__scalar); + m.impl("_copy_from", &txda__copy_from); + m.impl("empty.memory_format", &txda_empty_memory_format); + 
m.impl("empty_strided", &txda_empty_strided); + m.impl("as_strided", &txda_as_strided); +} + +bool init_device() { + // return init_txda_runtime(); + return true; +} + +bool cleanup_device() { + // cleanup_txda_runtime(); + return true; +} + +int current_device() { return 0; } + +int current_stream(int id) { return 0; } + +void set_device(int id) {} + +c10::Device get_txda_device() { + return c10::Device(c10::DeviceType::PrivateUse1, 0); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("current_device", ¤t_device, "get current tx device"); + m.def("current_stream", ¤t_stream, "get current tx stream"); + m.def("set_device", &set_device, "set tx device"); + m.def("get_txda_device", &get_txda_device, "get tx device"); + m.def("init_device", &init_device, "initialize tx device"); + m.def("cleanup_device", &cleanup_device, "cleanup tx device"); +} diff --git a/third_party/tsingmicro/bin/CMakeLists.txt b/third_party/tsingmicro/bin/CMakeLists.txt new file mode 100644 index 000000000..38c4ea7b6 --- /dev/null +++ b/third_party/tsingmicro/bin/CMakeLists.txt @@ -0,0 +1,88 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_llvm_executable(tsingmicro-opt tsingmicro-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this? +llvm_update_compile_flags(tsingmicro-opt) +target_link_libraries(tsingmicro-opt PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ${triton_libs} + ZTCAnalysis + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIROptLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(tsingmicro-opt) + +add_llvm_executable(tsingmicro-reduce tsingmicro-reduce.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(tsingmicro-reduce) + +llvm_update_compile_flags(tsingmicro-reduce) +target_link_libraries(tsingmicro-reduce PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIRReduceLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(tsingmicro-reduce) + +add_llvm_executable(tsingmicro-lsp tsingmicro-lsp.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(tsingmicro-lsp) + +llvm_update_compile_flags(tsingmicro-lsp) +target_link_libraries(tsingmicro-lsp PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(tsingmicro-lsp) + + +add_llvm_executable(tsingmicro-llvm-opt + tsingmicro-llvm-opt.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(tsingmicro-llvm-opt PRIVATE + TritonLLVMIR + + LLVMAnalysis + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(tsingmicro-llvm-opt) diff --git a/third_party/tsingmicro/bin/RegisterTritonDialects.h b/third_party/tsingmicro/bin/RegisterTritonDialects.h new file mode 100644 index 000000000..95cb7fbd5 --- /dev/null +++ b/third_party/tsingmicro/bin/RegisterTritonDialects.h @@ -0,0 +1,181 @@ +#pragma once +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include 
"amd/include/TritonAMDGPUTransforms/Passes.h" +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +// Below headers will allow registration to ROCm passes +#include "TritonAMDGPUToLLVM/Passes.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "TritonAMDGPUTransforms/TritonGPUConversion.h" + +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "nvidia/include/NVGPUToLLVM/Passes.h" +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" +#include "triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h" +#include "triton-shared/Conversion/TritonToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToStructured/Passes.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" + +#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h" +#include "magic-kernel/Conversion/LinalgToMK/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h" + +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" + +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/InitAllPasses.h" + +namespace mlir { +namespace test { +void registerTestAliasPass(); +void registerTestAlignmentPass(); +void registerTestAllocationPass(); +void registerTestMembarPass(); +void registerTestTritonAMDGPURangeAnalysis(); +} // namespace test +} // namespace mlir + +inline void registerTritonDialects(mlir::DialectRegistry ®istry) { + mlir::registerAllPasses(); + mlir::registerTritonPasses(); + mlir::triton::gpu::registerTritonGPUPasses(); + mlir::registerLinalgPasses(); + mlir::registerTritonNvidiaGPUPasses(); + mlir::test::registerTestAliasPass(); + mlir::test::registerTestAlignmentPass(); + mlir::test::registerTestAllocationPass(); + mlir::test::registerTestMembarPass(); + mlir::test::registerTestTritonAMDGPURangeAnalysis(); + mlir::triton::registerTritonToLinalgPass(); + mlir::triton::registerTritonToStructuredPass(); + mlir::triton::registerTritonArithToLinalgPasses(); + 
mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::registerStructuredToMemrefPasses(); + mlir::triton::registerTritonToCoreDialectsPass(); + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::gpu::registerAllocateSharedMemoryPass(); + mlir::triton::gpu::registerTritonGPUAllocateWarpGroups(); + mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass(); + mlir::triton::registerConvertWarpSpecializeToLLVM(); + mlir::triton::registerConvertTritonGPUToLLVMPass(); + mlir::triton::registerConvertNVGPUToLLVMPass(); + mlir::registerLLVMDIScope(); + + // Core dialects to MK layer conversion passes + mlir::triton::registerTx81MemrefToLLVMPass(); + mlir::triton::registerLinalgToMKPass(); + mlir::triton::registerCoreDialectsToMKPass(); + + // TsingMicro specific conversion passes + mlir::triton::registerMKToTx81Pass(); + mlir::triton::registerTx81ToLLVMPass(); + mlir::triton::registerKernelArgBufferPass(); + + // TritonAMDGPUToLLVM passes + mlir::triton::registerConvertTritonAMDGPUToLLVM(); + mlir::triton::registerConvertBuiltinFuncToLLVM(); + mlir::triton::registerDecomposeUnsupportedAMDConversions(); + mlir::triton::registerOptimizeAMDLDSUsage(); + + // TritonAMDGPUTransforms passes + mlir::registerTritonAMDGPUAccelerateMatmul(); + mlir::registerTritonAMDGPUOptimizeEpilogue(); + mlir::registerTritonAMDGPUHoistLayoutConversions(); + mlir::registerTritonAMDGPUReorderInstructions(); + mlir::registerTritonAMDGPUBlockPingpong(); + mlir::registerTritonAMDGPUStreamPipeline(); + mlir::registerTritonAMDGPUCanonicalizePointers(); + mlir::registerTritonAMDGPUConvertToBufferOps(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + + // FIXME: May not need all of these + // mlir::registerAllDialects(registry); + // Register all external models. 
+ // affine::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); + mlir::arith::registerBufferViewFlowOpInterfaceExternalModels(registry); + mlir::arith::registerShardingInterfaceExternalModels(registry); + mlir::arith::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + mlir::builtin::registerCastOpInterfaceExternalModels(registry); + mlir::cf::registerBufferizableOpInterfaceExternalModels(registry); + mlir::cf::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::gpu::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::LLVM::registerInlinerInterface(registry); + mlir::NVVM::registerInlinerInterface(registry); + mlir::linalg::registerAllDialectInterfaceImplementations(registry); + mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + mlir::memref::registerAllocationOpInterfaceExternalModels(registry); + mlir::memref::registerBufferViewFlowOpInterfaceExternalModels(registry); + mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + mlir::memref::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::memref::registerMemorySlotExternalModels(registry); + mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::scf::registerBufferizableOpInterfaceExternalModels(registry); + mlir::scf::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::shape::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerFindPayloadReplacementOpInterfaceExternalModels( + registry); + mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry); + mlir::tensor::registerSubsetOpInterfaceExternalModels(registry); + mlir::tensor::registerTilingInterfaceExternalModels(registry); + mlir::tensor::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::vector::registerBufferizableOpInterfaceExternalModels(registry); + mlir::vector::registerSubsetOpInterfaceExternalModels(registry); + mlir::vector::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); + // This is need for the Bufferization pass(one-shot bufferization) + mlir::registerAllExtensions(registry); + mlir::mk::registerBufferizableOpInterfaceExternalModels(registry); + + registry.insert(); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-llvm-opt.cpp b/third_party/tsingmicro/bin/tsingmicro-llvm-opt.cpp new file mode 100644 index 000000000..1ec804cb5 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-llvm-opt.cpp @@ -0,0 +1,121 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. 
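+///
+/// A minimal usage sketch (the file names here are only illustrative):
+///   tsingmicro-llvm-opt --break-struct-phi-nodes input.ll -o out.ll
+/// reads LLVM IR from input.ll, runs the BreakStructPhiNodes pass over each
+/// function, and writes the transformed module to out.ll (stdout when -o is
+/// omitted).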
+#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt OutputFilename("o", + cl::desc("Override output filename"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr Out; + // Default to standard output. 
+ if (OutputFilename.empty()) + OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + Out->keep(); + return 0; +} diff --git a/third_party/tsingmicro/bin/tsingmicro-lsp.cpp b/third_party/tsingmicro/bin/tsingmicro-lsp.cpp new file mode 100644 index 000000000..f95036dc6 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-lsp.cpp @@ -0,0 +1,10 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-opt.cpp b/third_party/tsingmicro/bin/tsingmicro-opt.cpp new file mode 100644 index 000000000..2d2570771 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-opt.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-reduce.cpp b/third_party/tsingmicro/bin/tsingmicro-reduce.cpp new file mode 100644 index 000000000..8235f8fc8 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-reduce.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-reduce/MlirReduceMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::mlirReduceMain(argc, argv, context)); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-tensor-layout.cpp b/third_party/tsingmicro/bin/tsingmicro-tensor-layout.cpp new file mode 100644 index 000000000..cc121b3e1 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-tensor-layout.cpp @@ -0,0 +1,232 @@ +#include "RegisterTritonDialects.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/AsmParser/AsmParserState.h" +#include "mlir/IR/MLIRContext.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace mlir; + +// A CLI tool to print the layout of a tensor. 
+// +// clang-format off +// Example usage: +// +// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view +// +// An input file usually looks like: +// ''' +// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +// ''' +// clang-format on + +//===--------------------------------------------------------------------===// +// CLI options +//===--------------------------------------------------------------------===// + +cl::OptionCategory PrinterCategory("Available Print Options", + "Options for the tensor layout printing."); + +static cl::opt InputFile( + "i", cl::desc("File that contains the tensor data layout attributes"), + cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory)); + +static cl::opt + OutputFile("o", cl::desc("Output file to write the layout into"), + cl::init(""), cl::value_desc("filename"), + cl::cat(PrinterCategory)); + +static cl::opt + DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"), + cl::value_desc("layout-string"), cl::init(""), + cl::cat(PrinterCategory)); + +static cl::list + AliasName("alias-names", + cl::desc("A list of alias names (separated by comma) of the " + "layout attributes in the input file"), + cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated, + cl::ZeroOrMore, cl::cat(PrinterCategory)); + +static cl::opt UseHWPointOfView( + "use-hw-view", + llvm::cl::desc( + "Print the layout in hardware point of view. This means the output is " + "from the warp's perspective. Otherwise, the output is from the " + "tensor's perspective (e.g., each element maps to xxx thread)."), + cl::init(false), cl::cat(PrinterCategory)); + +static cl::opt TensorStr( + "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"), + cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory)); + +//===--------------------------------------------------------------------===// +// Helper functions +//===--------------------------------------------------------------------===// + +LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { + // DistributedEncodingTrait and SharedEncodingTrait implements the + // toLinearLayout interface. 
+ mlir::Attribute layout = tensorType.getEncoding(); + if (isa(layout)) { + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); + } + + llvm::errs() << "Unsupported tensor layout attribute: " + << tensorType.getEncoding() << "\n"; + return failure(); +} + +LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, + ArrayRef names, + TensorType tensorTy, raw_string_ostream &ss) { + if (filename.empty()) + return success(); + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + ParserConfig config(context); + auto asmState = AsmParserState(); + + Block parsedIR; + if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) { + llvm::errs() << "Fail to parse the input file: " << filename << "\n"; + return failure(); + } + + auto printLambda = [&](StringRef name, mlir::Attribute attr) { + ss << "Print layout attribute: #" << name << " = " << attr << "\n"; + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), attr); + + return layoutPrint(rankedTensorTy, ss); + }; + + if (names.empty()) + // If no alias name is given, we print all layout attributes in the file. + for (const auto &def : asmState.getAttributeAliasDefs()) { + if (failed(printLambda(def.name, def.value))) + return failure(); + } + else { + // Print the layout attributes with the given alias names. + for (const auto &alias : names) { + auto def = asmState.getAttributeAliasDef(alias); + if (!def) { + llvm::errs() << "Can't find the layout attribute: " << alias << "\n"; + return failure(); + } + + if (failed(printLambda(alias, def->value))) + return failure(); + + ss << "\n"; + } + } + + return success(); +} + +LogicalResult printLayoutFromString(MLIRContext *context, + StringRef layoutAttrStr, + TensorType tensorTy, + raw_string_ostream &ss) { + if (layoutAttrStr.empty()) + return success(); + + mlir::Attribute layout = parseAttribute(layoutAttrStr, context); + if (!layout) { + llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n"; + return failure(); + } + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), layout); + + ss << "Print layout attribute: " << layout << "\n"; + + return layoutPrint(rankedTensorTy, ss); +} + +//===--------------------------------------------------------------------===// +// Main entry point +//===--------------------------------------------------------------------===// + +int main(int argc, char **argv) { + cl::HideUnrelatedOptions(PrinterCategory); + cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n"); + + DialectRegistry registry; + registerTritonDialects(registry); + + MLIRContext ctx(registry); + ctx.loadAllAvailableDialects(); + + if (TensorStr.empty()) { + llvm::errs() << "Must specify the tensor type argument\n"; + return 1; + } + + mlir::Type parsedTy = parseType(TensorStr, &ctx); + if (!parsedTy) { + llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr + << "\n"; + return 1; + } + + TensorType tensorType = dyn_cast(parsedTy); + if (!tensorType) { + llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n"; + return 1; + } + + std::string storage; + raw_string_ostream ss(storage); + + if 
(failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss))) + return 1; + + if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss))) + return 1; + + if (OutputFile.empty()) { + llvm::outs() << ss.str(); + } else { + std::error_code ec; + llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text); + if (ec) { + llvm::errs() << "Error: " << ec.message() << " : unable to open " + << OutputFile << " for output\n"; + return 1; + } + outFs << ss.str(); + outFs.close(); + } + + return 0; +} diff --git a/third_party/tsingmicro/crt/CMakeLists.txt b/third_party/tsingmicro/crt/CMakeLists.txt new file mode 100644 index 000000000..f294fab8c --- /dev/null +++ b/third_party/tsingmicro/crt/CMakeLists.txt @@ -0,0 +1,105 @@ +cmake_minimum_required(VERSION 3.18) + +set(TARGET Tx81) +# Set TARGET from environment variable +if(NOT DEFINED TARGET) + if(DEFINED ENV{CRT_TARGET}) + set(TARGET $ENV{CRT_TARGET}) + else() + message(FATAL_ERROR "CRT_TARGET environment variable is not defined") + endif() +endif() + +if(NOT DEFINED XUANTIE_NAME) + if(DEFINED ENV{XUANTIE_NAME}) + set(XUANTIE_NAME $ENV{XUANTIE_NAME}) + else() + message(FATAL_ERROR "XUANTIE_NAME environment variable is not defined") + endif() +endif() + +# Set LLVM_SYSPATH from environment variable +if(NOT DEFINED LLVM_SYSPATH) + if(DEFINED ENV{LLVM_SYSPATH}) + set(LLVM_SYSPATH $ENV{LLVM_SYSPATH}) + else() + message(FATAL_ERROR "LLVM_SYSPATH environment variable is not defined") + endif() +endif() + +if(NOT DEFINED TX8_HOME) + if(DEFINED ENV{TX8_HOME}) + set(TX8_HOME $ENV{TX8_HOME}) + else() + message(FATAL_ERROR "TX8_HOME environment variable is not defined") + endif() +endif() + +# Project name and version +project(VendorRuntime LANGUAGES CXX C) + +# Define RISC-V target triple +set(RISCV_TRIPLE "riscv64-unknown-elf") +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_SYSTEM_PROCESSOR riscv) +set(CMAKE_C_COMPILER ${LLVM_SYSPATH}/bin/clang) +set(CMAKE_CXX_COMPILER ${LLVM_SYSPATH}/bin/clang++) + +# Define standard include directories +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include/${TARGET}) +include_directories(${TX8_HOME}/include) +include_directories(${TX8_HOME}/${XUANTIE_NAME}/riscv64-unknown-elf/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) + +# Set build type default +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type (default Release)" FORCE) +endif() + +# Library name: vr stands for Vendor Runtime +set(VENDOR_RUNTIME_LIB vr) + +# Collect all source files from the vendor directory +file(GLOB_RECURSE VENDOR_SOURCES lib/${TARGET}/*.c) + +# Define RISC-V specific compile options +set(RISCV_COMPILE_OPTIONS + --target=${RISCV_TRIPLE} + -march=rv64gc + -mabi=lp64d + -mcmodel=medany +) + +# Add the library target +add_library(${VENDOR_RUNTIME_LIB} STATIC ${VENDOR_SOURCES}) + +# Apply RISC-V specific settings to our target +target_compile_options(${VENDOR_RUNTIME_LIB} PRIVATE ${RISCV_COMPILE_OPTIONS}) +target_link_options(${VENDOR_RUNTIME_LIB} PRIVATE --target=${RISCV_TRIPLE}) + +# Setup compiler and environment for RISC-V compilation +if(CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + # Use the existing Clang installation with target triple + message(STATUS "Using Clang with RISC-V target triple") +else() + # Override compiler paths if using explicit RISC-V toolchain + message(STATUS "Setting explicit RISC-V compiler from LLVM_SYSPATH") + + foreach(source ${VENDOR_SOURCES}) + if(source MATCHES 
"\\.(c)$") + set_source_files_properties(${source} PROPERTIES + COMPILE_FLAGS "-xc --target=${RISCV_TRIPLE}" + LANGUAGE C) + elseif(source MATCHES "\\.(cpp)$") + set_source_files_properties(${source} PROPERTIES + COMPILE_FLAGS "-xc++ --target=${RISCV_TRIPLE}" + LANGUAGE CXX) + endif() + endforeach() + + # Set compiler launch commands for the target + add_custom_command(TARGET ${VENDOR_RUNTIME_LIB} PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E echo "Building ${VENDOR_RUNTIME_LIB} for RISC-V target" + ) +endif() diff --git a/third_party/tsingmicro/crt/README.md b/third_party/tsingmicro/crt/README.md new file mode 100644 index 000000000..3750c3c7c --- /dev/null +++ b/third_party/tsingmicro/crt/README.md @@ -0,0 +1,2 @@ +This folder contains the low level API implementation for various ML +controller or accelerator. diff --git a/third_party/tsingmicro/crt/gcc_flash_smartl.ld b/third_party/tsingmicro/crt/gcc_flash_smartl.ld new file mode 100644 index 000000000..6786fb002 --- /dev/null +++ b/third_party/tsingmicro/crt/gcc_flash_smartl.ld @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2017-2024 Alibaba Group Holding Limited + */ + +/****************************************************************************** + * @file gcc_csky.ld + * @brief csky linker file + * @version V1.0 + * @date 02. June 2017 + ******************************************************************************/ +MEMORY +{ + ISRAM : ORIGIN = 0x00000000 , LENGTH = 0x20000 /* ISRAM 128KB*/ + DSRAM : ORIGIN = 0x20000000 , LENGTH = 0x80000 /* DSRAM 512KB*/ +} + +__min_heap_size = 0x200; +PROVIDE (__ram_end = 0x20020000); +PROVIDE (__heap_end = __ram_end); + +REGION_ALIAS("REGION_TEXT", ISRAM); +REGION_ALIAS("REGION_RODATA", ISRAM); +REGION_ALIAS("REGION_DATA", DSRAM); +REGION_ALIAS("REGION_BSS", DSRAM); + +ENTRY(Reset_Handler) +SECTIONS +{ + .text : { + . = ALIGN(0x8) ; + __stext = . ; + KEEP(*startup.o(*.text)) + KEEP(*startup.o(*.vectors)) + KEEP(*vectors.o(*.text)) + KEEP(*(.text.entry)) + *(.text*) + *(.gnu.warning) + *(.stub) + *(.gnu.linkonce.t*) + *(.glue_7t) + *(.glue_7) + *(.jcr) + KEEP (*(.init)) + KEEP (*(.fini)) + . = ALIGN (0x4) ; + PROVIDE(__ctbp = .); + *(.call_table_data) + *(.call_table_text) + . = ALIGN(0x10) ; + __etext = . ; + } > REGION_TEXT + .rodata : { + . = ALIGN(0x8) ; + __srodata = .; + *(.rdata) + *(.rdata*) + *(.rdata1) + *(.rdata.*) + *(.rodata*) + *(.srodata*) + . = ALIGN(0x8) ; + __init_array_start = .; + __ctors_start__ = .; + KEEP (*(SORT(.init_array.*))) + KEEP (*(.init_array)) + __init_array_end = .; + __ctors_end__ = .; + + __fini_array_start = .; + __dtors_start__ = .; + KEEP (*(SORT(.fini_array.*))) + KEEP (*(.fini_array)) + __fini_array_end = .; + __dtors_end__ = .; + . = ALIGN(0x8) ; + + __ctor_start__ = .; + KEEP (*(SORT(.ctors.*))) + KEEP (*(.ctors)) + __ctor_end__ = .; + KEEP (*(SORT(.dtors.*))) + KEEP (*(.dtors)) + __dtor_end__ = .; + . = ALIGN(0x8) ; +/*****************************************/ + /* section information for finsh shell */ + . = ALIGN(0x8); + __fsymtab_start = .; + KEEP(*(FSymTab)) + __fsymtab_end = .; + . = ALIGN(0x8); + __vsymtab_start = .; + KEEP(*(VSymTab)) + __vsymtab_end = .; + . = ALIGN(0x8); + + /* section information for initial. */ + __rt_init_start = .; + KEEP(*(SORT(.rti_fn*))) + __rt_init_end = .; + . = ALIGN(0x8) ; + + /* section information for at utest */ + __rt_utest_tc_tab_start = .; + KEEP(*(UtestTcTab)) + __rt_utest_tc_tab_end = .; + . = ALIGN(0x8); + + /* section information for at server */ + . 
= ALIGN(0x8); + __rtatcmdtab_start = .; + KEEP(*(RtAtCmdTab)) + __rtatcmdtab_end = .; + . = ALIGN(0x8); + + /* section information for modules */ + . = ALIGN(0x8); + __rtmsymtab_start = .; + KEEP(*(RTMSymTab)) + __rtmsymtab_end = .; + + /* section information for uPRC */ + . = ALIGN(0x8); + __uRPCSvcTab_start = .; + KEEP(*(uRPCSvcTab)) + __uRPCSvcTab_end = .; + + /* section information for var export */ + . = ALIGN(0x8); + __ve_table_start = .; + KEEP(*(SORT(*.VarExpTab.*))) + __ve_table_end = .; +/*****************************************/ +/************** added drivers **************/ + _cli_region_begin = .; + KEEP(*(CliRegion)) + . = ALIGN(0x8); + _cli_region_end = .; + + __core_driver_start__ = .; + KEEP(*(.core_driver_entry)) + . = ALIGN(0x8); + __core_driver_end__ = .; + + __bus_driver_start__ = .; + KEEP(*(*.bus_driver_entry)) + __bus_driver_end__ = .; + + __early_driver_start__ = .; + KEEP(*(*.early_driver_entry)) + __early_driver_end__ = .; + + __vfs_driver_start__ = .; + KEEP(*(*.vfs_driver_entry)) + __vfs_driver_end__ = .; + + __level0_driver_start__ = .; + KEEP(*(*.level0_driver_entry)) + __level0_driver_end__ = .; + + __level1_driver_start__ = .; + KEEP(*(*.level1_driver_entry)) + __level1_driver_end__ = .; + + __level2_driver_start__ = .; + KEEP(*(*.level2_driver_entry)) + __level2_driver_end__ = .; + + __level3_driver_start__ = .; + KEEP(*(*.level3_driver_entry)) + __level3_driver_end__ = .; + + __post_driver_start__ = .; + KEEP(*(*.post_driver_entry)) + __post_driver_end__ = .; +/************** end of drivers *********/ + . = ALIGN(0x8) ; + __erodata = .; + __rodata_end__ = .; + } > REGION_RODATA + .data : { + . = ALIGN(0x8) ; + __sdata = . ; + __data_start__ = . ; + data_start = . ; + *(.got.plt) + *(.got) + *(.gnu.linkonce.r*) + *(.data*) + *(.gnu.linkonce.d*) + *(.gcc_except_table*) + __start_init_call = .; + *(.initcall.init) + __stop_init_call = .; + __start_cmd = .; + *(.bootloaddata.cmd) + . = ALIGN(0x8) ; + __stop_cmd = .; + __global_pointer$ = .; + *(.sdata) + *(.sdata.*) + *(.sdata2.*) + *(.gnu.linkonce.s.*) + *(__libc_atexit) + *(__libc_subinit) + *(__libc_subfreeres) + *(.note.ABI-tag) + __edata = .; + __data_end__ = .; + . = ALIGN(0x8) ; + } > REGION_DATA AT > REGION_RODATA + ._ram_code : { + . = ALIGN(0x8) ; + __ram_code_start__ = .; + *(.ram.code*) + . = ALIGN(0x8) ; + __ram_code_end__ = .; + } > REGION_DATA AT > REGION_RODATA + .bss : { + . = ALIGN(0x8) ; + __sbss = ALIGN(0x8) ; + __bss_start__ = . ; + *(.dynsbss) + *(.sbss) + *(.sbss.*) + *(.scommon) + *(.dynbss) + *(.bss*) + *(COMMON) + . = ALIGN(0x8) ; + __ebss = . ; + __bss_end__ = .; + __end = . ; + end = . ; + } > REGION_BSS AT > REGION_BSS + ._user_heap (NOLOAD): { + . = ALIGN(0x8) ; + *(.stack*) + . = ALIGN(0x8) ; + __heap_start = .; + . += __min_heap_size; + . 
= ALIGN(0x8) ; + } > REGION_BSS AT > REGION_BSS +} diff --git a/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld b/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld new file mode 100644 index 000000000..db5c48fc0 --- /dev/null +++ b/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2017-2024 Alibaba Group Holding Limited + */ + +MEMORY +{ + DRAM : ORIGIN = 0x50000000, LENGTH = 0x100000 /* on-chip DRAM 1*1MB */ +} + +__min_heap_size = 0x200; +PROVIDE (__ram_end = 0x50100000 - 0x8); +PROVIDE (__heap_end = __ram_end); + +REGION_ALIAS("REGION_TEXT", DRAM); +REGION_ALIAS("REGION_RODATA", DRAM); +REGION_ALIAS("REGION_DATA", DRAM); +REGION_ALIAS("REGION_BSS", DRAM); + +ENTRY(Reset_Handler) +SECTIONS +{ + .text : { + . = ALIGN(0x8) ; + __stext = . ; + KEEP(*startup.o(*.text)) + KEEP(*startup.o(*.vectors)) + KEEP(*vectors.o(*.text)) + KEEP(*(.text.entry)) + *(.text) + *(.text*) + *(.text.*) + *(.gnu.warning) + *(.stub) + *(.gnu.linkonce.t*) + *(.glue_7t) + *(.glue_7) + *(.jcr) + KEEP (*(.init)) + KEEP (*(.fini)) + . = ALIGN(0x8) ; + PROVIDE(__ctbp = .); + *(.call_table_data) + *(.call_table_text) + . = ALIGN(0x8) ; + __etext = . ; + } > REGION_TEXT + .gcc_except_table : ONLY_IF_RO { + *(.gcc_except_table .gcc_except_table.*) + } > REGION_TEXT + .rodata : { + . = ALIGN(0x8) ; + __srodata = .; + *(.rdata) + *(.rdata*) + *(.rdata1) + *(.rdata.*) + *(.rodata) + *(.rodata1) + *(.rodata*) + *(.rodata.*) + *(.rodata.str1.4) + *(.srodata*) + . = ALIGN(0x8) ; + + __init_array_start = .; + __ctors_start__ = .; + KEEP (*(SORT(.init_array.*))) + KEEP (*(.init_array)) + __init_array_end = .; + __ctors_end__ = .; + + __fini_array_start = .; + __dtors_start__ = .; + KEEP (*(SORT(.fini_array.*))) + KEEP (*(.fini_array)) + __fini_array_end = .; + __dtors_end__ = .; + + __ctor_start__ = .; + KEEP (*(SORT(.ctors.*))) + KEEP (*(.ctors)) + __ctor_end__ = .; + KEEP (*(SORT(.dtors.*))) + KEEP (*(.dtors)) + __dtor_end__ = .; + . = ALIGN(0x8) ; +/*****************************************/ + /* section information for finsh shell */ + . = ALIGN(0x8); + __fsymtab_start = .; + KEEP(*(FSymTab)) + __fsymtab_end = .; + . = ALIGN(0x8); + __vsymtab_start = .; + KEEP(*(VSymTab)) + __vsymtab_end = .; + . = ALIGN(0x8); + + /* section information for initial. */ + __rt_init_start = .; + KEEP(*(SORT(.rti_fn*))) + __rt_init_end = .; + . = ALIGN(0x8) ; + + /* section information for at utest */ + __rt_utest_tc_tab_start = .; + KEEP(*(UtestTcTab)) + __rt_utest_tc_tab_end = .; + . = ALIGN(0x8); + + /* section information for at server */ + . = ALIGN(0x8); + __rtatcmdtab_start = .; + KEEP(*(RtAtCmdTab)) + __rtatcmdtab_end = .; + . = ALIGN(0x8); + + /* section information for modules */ + . = ALIGN(0x8); + __rtmsymtab_start = .; + KEEP(*(RTMSymTab)) + __rtmsymtab_end = .; + + /* section information for uPRC */ + . = ALIGN(0x8); + __uRPCSvcTab_start = .; + KEEP(*(uRPCSvcTab)) + __uRPCSvcTab_end = .; + + /* section information for var export */ + . = ALIGN(0x8); + __ve_table_start = .; + KEEP(*(SORT(*.VarExpTab.*))) + __ve_table_end = .; +/*****************************************/ +/************** added drivers **************/ + _cli_region_begin = .; + KEEP(*(CliRegion)) + . = ALIGN(0x8) ; + _cli_region_end = .; + + __core_driver_start__ = .; + KEEP(*(.core_driver_entry)) + . 
= ALIGN(0x8) ; + __core_driver_end__ = .; + + __bus_driver_start__ = .; + KEEP(*(*.bus_driver_entry)) + __bus_driver_end__ = .; + + __early_driver_start__ = .; + KEEP(*(*.early_driver_entry)) + __early_driver_end__ = .; + + __vfs_driver_start__ = .; + KEEP(*(*.vfs_driver_entry)) + __vfs_driver_end__ = .; + + __level0_driver_start__ = .; + KEEP(*(*.level0_driver_entry)) + __level0_driver_end__ = .; + + __level1_driver_start__ = .; + KEEP(*(*.level1_driver_entry)) + __level1_driver_end__ = .; + + __level2_driver_start__ = .; + KEEP(*(*.level2_driver_entry)) + __level2_driver_end__ = .; + + __level3_driver_start__ = .; + KEEP(*(*.level3_driver_entry)) + __level3_driver_end__ = .; + + __post_driver_start__ = .; + KEEP(*(*.post_driver_entry)) + __post_driver_end__ = .; +/************** end of drivers *********/ + . = ALIGN(0x8) ; + __erodata = .; + __rodata_end__ = .; + } > REGION_RODATA + .data : { + . = ALIGN(0x8) ; + __sdata = . ; + __data_start__ = . ; + data_start = . ; + *(.got.plt) + *(.got) + *(.gnu.linkonce.r*) + *(.data) + *(.data*) + *(.data1) + *(.data.*) + *(.gnu.linkonce.d*) + *(.data1) + *(.gcc_except_table) + *(.gcc_except_table*) + __start_init_call = .; + *(.initcall.init) + __stop_init_call = .; + __start_cmd = .; + *(.bootloaddata.cmd) + . = ALIGN(0x8) ; + __stop_cmd = .; + __global_pointer$ = .; + *(.sdata) + *(.sdata.*) + *(.sdata2.*) + *(.gnu.linkonce.s.*) + *(__libc_atexit) + *(__libc_subinit) + *(__libc_subfreeres) + *(.note.ABI-tag) + __edata = .; + __data_end__ = .; + . = ALIGN(0x8) ; + } > REGION_DATA + .gcc_except_table : ONLY_IF_RW { + *(.gcc_except_table .gcc_except_table.*) + __edata = .; + __data_end__ = .; + } > REGION_DATA + .bss : { + . = ALIGN(0x8) ; + __sbss = ALIGN(0x8) ; + __bss_start__ = . ; + *(.dynsbss) + *(.sbss) + *(.sbss.*) + *(.scommon) + *(.dynbss) + *(.bss) + *(.bss.*) + *(COMMON) + . = ALIGN(0x8) ; + __ebss = . ; + __bss_end__ = .; + __end = . ; + end = . ; + } > REGION_BSS + ._user_heap (NOLOAD): { + . = ALIGN(0x8) ; + *(.stack*) + . = ALIGN(0x8) ; + __heap_start = .; + . += __min_heap_size; + . = ALIGN(0x8) ; + } > REGION_BSS +} diff --git a/third_party/tsingmicro/crt/gcc_tx8_smarth.ld b/third_party/tsingmicro/crt/gcc_tx8_smarth.ld new file mode 100644 index 000000000..eb2aacb2a --- /dev/null +++ b/third_party/tsingmicro/crt/gcc_tx8_smarth.ld @@ -0,0 +1,279 @@ +/* + * Copyright (C) 2017-2024 Alibaba Group Holding Limited + */ + +/****************************************************************************** + * @file gcc_csky.ld + * @brief csky linker file + * @version V1.0 + * @date 02. June 2017 + ******************************************************************************/ +MEMORY +{ + mem0 (rwx) : ORIGIN = 0x00000000, LENGTH = (20*1024*1024) +} + +REGION_ALIAS("r", mem0); +REGION_ALIAS("w", mem0); +REGION_ALIAS("x", mem0); + +ENTRY(Reset_Handler) +SECTIONS +{ + +.text.startup 0x0:{ + . = ALIGN(0x8) ; + KEEP(*startup.o(*.text.startup)) + *(.text.startup) + } > x + + .text : { + . = ALIGN(0x8) ; + __ram_code_start__ = .; + __stext = . ; + KEEP(*startup.o(*.text)) + KEEP(*startup.o(*.vectors)) + KEEP(*vectors.o(*.text)) + KEEP(*(.text.entry)) + *(.text) + *(.vectors) + *(.text*) + *(.text.*) + *(.gnu.warning) + *(.stub) + *(.gnu.linkonce.t*) + *(.glue_7t) + *(.glue_7) + *(.jcr) + KEEP (*(.init)) + KEEP (*(.fini)) + . = ALIGN(0x8) ; + PROVIDE(__ctbp = .); + *(.call_table_data) + *(.call_table_text) + . = ALIGN(0x8) ; + __etext = . 
; + __ram_code_end__ = .; + } > x + + .gcc_except_table : ONLY_IF_RO { + *(.gcc_except_table .gcc_except_table.*) + } > x + + .rodata : { + . = ALIGN(0x8) ; + __srodata = .; + *(.rdata) + *(.rdata*) + *(.rdata1) + *(.rdata.*) + *(.rodata) + *(.rodata1) + *(.rodata*) + *(.rodata.*) + *(.rodata.str1.4) + *(.srodata*) + . = ALIGN(0x8) ; + + __init_array_start = .; + __ctors_start__ = .; + KEEP (*(SORT(.init_array.*))) + KEEP (*(.init_array)) + __init_array_end = .; + __ctors_end__ = .; + + __fini_array_start = .; + __dtors_start__ = .; + KEEP (*(SORT(.fini_array.*))) + KEEP (*(.fini_array)) + __fini_array_end = .; + __dtors_end__ = .; + + __ctor_start__ = .; + KEEP (*(SORT(.ctors.*))) + KEEP (*(.ctors)) + __ctor_end__ = .; + KEEP (*(SORT(.dtors.*))) + KEEP (*(.dtors)) + __dtor_end__ = .; + . = ALIGN(0x8) ; +/*****************************************/ + /* section information for finsh shell */ + . = ALIGN(0x8); + __fsymtab_start = .; + KEEP(*(FSymTab)) + __fsymtab_end = .; + . = ALIGN(0x8); + __vsymtab_start = .; + KEEP(*(VSymTab)) + __vsymtab_end = .; + . = ALIGN(0x8); + + /* section information for initial. */ + __rt_init_start = .; + KEEP(*(SORT(.rti_fn*))) + __rt_init_end = .; + . = ALIGN(0x8) ; + + /* section information for at utest */ + __rt_utest_tc_tab_start = .; + KEEP(*(UtestTcTab)) + __rt_utest_tc_tab_end = .; + . = ALIGN(0x8); + + /* section information for at server */ + . = ALIGN(0x8); + __rtatcmdtab_start = .; + KEEP(*(RtAtCmdTab)) + __rtatcmdtab_end = .; + . = ALIGN(0x8); + + /* section information for modules */ + . = ALIGN(0x8); + __rtmsymtab_start = .; + KEEP(*(RTMSymTab)) + __rtmsymtab_end = .; + + /* section information for uPRC */ + . = ALIGN(0x8); + __uRPCSvcTab_start = .; + KEEP(*(uRPCSvcTab)) + __uRPCSvcTab_end = .; + + /* section information for var export */ + . = ALIGN(0x8); + __ve_table_start = .; + KEEP(*(SORT(*.VarExpTab.*))) + __ve_table_end = .; +/*****************************************/ +/************** added drivers **************/ + _cli_region_begin = .; + KEEP(*(CliRegion)) + . = ALIGN(0x8) ; + _cli_region_end = .; + + __core_driver_start__ = .; + KEEP(*(.core_driver_entry)) + . = ALIGN(0x8) ; + __core_driver_end__ = .; + + __bus_driver_start__ = .; + KEEP(*(*.bus_driver_entry)) + __bus_driver_end__ = .; + + __early_driver_start__ = .; + KEEP(*(*.early_driver_entry)) + __early_driver_end__ = .; + + __vfs_driver_start__ = .; + KEEP(*(*.vfs_driver_entry)) + __vfs_driver_end__ = .; + + __level0_driver_start__ = .; + KEEP(*(*.level0_driver_entry)) + __level0_driver_end__ = .; + + __level1_driver_start__ = .; + KEEP(*(*.level1_driver_entry)) + __level1_driver_end__ = .; + + __level2_driver_start__ = .; + KEEP(*(*.level2_driver_entry)) + __level2_driver_end__ = .; + + __level3_driver_start__ = .; + KEEP(*(*.level3_driver_entry)) + __level3_driver_end__ = .; + + __post_driver_start__ = .; + KEEP(*(*.post_driver_entry)) + __post_driver_end__ = .; +/************** end of drivers *********/ + . = ALIGN(0x8) ; + __erodata = .; + __rodata_end__ = .; + } > r + + .data : { + . = ALIGN(0x8) ; + __sdata = . ; + __data_start__ = . ; + data_start = . ; + *(.got.plt) + *(.got) + *(.gnu.linkonce.r*) + *(.data) + *(.data*) + *(.data1) + *(.data.*) + *(.gnu.linkonce.d*) + *(.data1) + *(.gcc_except_table) + *(.gcc_except_table*) + __start_init_call = .; + *(.initcall.init) + __stop_init_call = .; + __start_cmd = .; + *(.bootloaddata.cmd) + . 
= ALIGN(0x8) ; + __stop_cmd = .; + __global_pointer$ = .; + *(.sdata) + *(.sdata.*) + *(.sdata2.*) + *(.gnu.linkonce.s.*) + *(__libc_atexit) + *(__libc_subinit) + *(__libc_subfreeres) + *(.note.ABI-tag) + __edata = .; + __data_end__ = .; + . = ALIGN(0x8) ; + } > w AT> r + + .gcc_except_table : ONLY_IF_RW { + *(.gcc_except_table .gcc_except_table.*) + __edata = .; + __data_end__ = .; + } > w AT> r + + .rela.dyn : { + . = ALIGN(0x8) ; + __rel_dyn_start = .; + *(.rela*) + __rel_dyn_end = .; + } + .dynsym : { + . = ALIGN(0x8) ; + __dyn_sym_start = .; + *(.dynsym) + __dyn_sym_end = .; + } + .bss : { + . = ALIGN(0x8) ; + __sbss = ALIGN(0x8) ; + __bss_start__ = . ; + *(.dynsbss) + *(.sbss) + *(.sbss.*) + *(.scommon) + *(.dynbss) + *(.bss) + *(.bss.*) + *(COMMON) + . = ALIGN(0x8) ; + __ebss = . ; + __bss_end__ = .; + __end = . ; + end = . ; + } > w + ._user_heap (NOLOAD): { + . = ALIGN(0x8) ; + *(.stack*) + . = ALIGN(0x8) ; + __heap_start = ABSOLUTE(.); + . = ORIGIN(w) + LENGTH(w); + __heap_end = ABSOLUTE(.); + + } > w +} diff --git a/third_party/tsingmicro/crt/include/Tx81/tx81.h b/third_party/tsingmicro/crt/include/Tx81/tx81.h new file mode 100644 index 000000000..b0af0b73d --- /dev/null +++ b/third_party/tsingmicro/crt/include/Tx81/tx81.h @@ -0,0 +1,33 @@ +//===----------------------- tx81.h ---------------------------*- C -*-----===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef CRT_TARGET_TX81_H +#define CRT_TARGET_TX81_H + +#include "instr_adapter.h" +#include "instr_def.h" +#include "lib_log.h" +#include +#include +#include + +enum MemorySpace : int32_t { + UNKNOWN = 0, + SPM = 1, + DDR = 2, +}; + +// Neural engine activate mode +enum ActFuncMode : int32_t { + None = 0, + ENRelu = 1, + ENLeakRelu = 2, +}; + +inline uint64_t spm_print_offset(uint64_t addr) { + return (uint64_t)addr + 0x030400000; +} +#endif // CRT_TARGET_TX81_H diff --git a/third_party/tsingmicro/crt/lib/Tx81/abs.c b/third_party/tsingmicro/crt/lib/Tx81/abs.c new file mode 100644 index 000000000..3b79549b1 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/abs.c @@ -0,0 +1,32 @@ +//===------------------------- abs.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::AbsVVOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __AbsVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->AbsVV(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmax.c b/third_party/tsingmicro/crt/lib/Tx81/argmax.c new file mode 100644 index 000000000..a982f8ff9 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/argmax.c @@ -0,0 +1,33 @@ +//===------------------------ argmax.c ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::ArgMax see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __ArgMax(uint64_t *src, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + cmd->ArgMax(&inst, (uint64_t)src, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmin.c b/third_party/tsingmicro/crt/lib/Tx81/argmin.c new file mode 100644 index 000000000..856854d3b --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/argmin.c @@ -0,0 +1,33 @@ +//===------------------------ argmin.c ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::ArgMin see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __ArgMin(uint64_t *src, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + cmd->ArgMin(&inst, (uint64_t)src, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/arith.c b/third_party/tsingmicro/crt/lib/Tx81/arith.c new file mode 100644 index 000000000..a0040e82e --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/arith.c @@ -0,0 +1,232 @@ +//===------------------------ arith.c ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::ArithOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __AddVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->AddVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + round, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __SubVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->SubVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + round, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __MulVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. 
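+  // Same create/encode/execute/delete sequence as the other arith helpers in
+  // this file: MulVV() only fills in the zero-initialized descriptor, and
+  // TsmExecute() is what actually hands it to the accelerator.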
+ TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->MulVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + round, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __DivVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->DivVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + round, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __AddVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->AddVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __SubVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->SubVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __MulVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->MulVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __DivVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->DivVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __MaxVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE reserved, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->MaxVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + reserved, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __MinVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE reserved, uint16_t fmt) { + // Create command buffer. 
+ TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->MinVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + reserved, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c new file mode 100644 index 000000000..eef93082f --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c @@ -0,0 +1,31 @@ +//===------------------------ bf16_fp16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_FP16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BF16_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c new file mode 100644 index 000000000..957fb60c4 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c @@ -0,0 +1,31 @@ +//===------------------------ bf16_fp32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_FP32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BF16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c new file mode 100644 index 000000000..930c3f395 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c @@ -0,0 +1,32 @@ +//===------------------------ bf16_int16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_INT16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. 
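+  // Unlike the bf16->fp16/fp32 conversions above, this integer conversion
+  // takes an explicit RND_MODE, which is forwarded to the convert engine
+  // unchanged.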
+ TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BF16_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c new file mode 100644 index 000000000..b9b5b38da --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c @@ -0,0 +1,32 @@ +//===------------------------ bf16_int32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_INT32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BF16_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c new file mode 100644 index 000000000..194c104b4 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c @@ -0,0 +1,31 @@ +//===------------------------ bf16_int8.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_INT8 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BF16_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c new file mode 100644 index 000000000..e5a60ca55 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c @@ -0,0 +1,31 @@ +//===------------------------ bf16_tf32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_TF32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. 
+ TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BF16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bilinear.c b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c new file mode 100644 index 000000000..43ddf791c --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c @@ -0,0 +1,39 @@ +//===------------------------ bilinear.c ----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Bilinear see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Bilinear(uint64_t *src, uint64_t *dst, uint16_t src_n, uint16_t src_h, + uint16_t src_w, uint16_t src_c, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Bilinear(&inst, (uint64_t)src, (uint64_t)dst, shape1, shape2, + (src_w - 1) / (dst_w - 1), (src_h - 1) / (dst_h - 1), + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c new file mode 100644 index 000000000..0a9344213 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c @@ -0,0 +1,34 @@ +//===------------------------ bit2fp.c ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Bit2Fp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Bit2Fp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + cmd->Bit2Fp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/common.c b/third_party/tsingmicro/crt/lib/Tx81/common.c new file mode 100644 index 000000000..c1b6085d9 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/common.c @@ -0,0 +1,19 @@ +//===----------------------- common.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Implement common helper functions in this file. 
+// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +// WORKAROUND for undefined symbols in libkcorert.a +int main(int argc, char **argv) { return 0; } + +int get_app_version() { return 1; } + +int nvram_get_val() { return 1; } diff --git a/third_party/tsingmicro/crt/lib/Tx81/concat.c b/third_party/tsingmicro/crt/lib/Tx81/concat.c new file mode 100644 index 000000000..8bd489bff --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/concat.c @@ -0,0 +1,40 @@ +//===------------------------ concat.c ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Concat see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Concat(uint64_t *src1, uint16_t src1_n, uint16_t src1_h, uint16_t src1_w, + uint16_t src1_c, uint64_t *src2, uint16_t src2_n, uint16_t src2_h, + uint16_t src2_w, uint16_t src2_c, uint64_t *dst, uint16_t dst_n, + uint16_t dst_h, uint16_t dst_w, uint16_t dst_c, uint32_t dim, + uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src1_n, src1_h, src1_w, src1_c}; + Data_Shape shape2 = {src2_n, src2_h, src2_w, src2_c}; + Data_Shape shape3 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Concat(&inst, (uint64_t)src1, shape1, (uint64_t)src2, shape2, + (uint64_t)dst, shape3, dim, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/conv.c b/third_party/tsingmicro/crt/lib/Tx81/conv.c new file mode 100644 index 000000000..d4f7dcae2 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/conv.c @@ -0,0 +1,66 @@ +//===------------------------ conv.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TsmConv, see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +// The arguments list is aligned with TsmConv in Tx81Ops.td +void __Conv(int64_t opType, int64_t *srcAct, int64_t *srcActDims, + int64_t *weight, int64_t *weightDims, bool enBias, int64_t *bias, + bool enNegScale, int64_t *negScale, bool enPosScale, + int64_t *posScale, bool enSparse, int64_t *sparse, bool enPsum, + int64_t *psum, int64_t *pads, int64_t *unpads, int64_t *strides, + int64_t *dilations, bool enLeakyRelu, int64_t srcActFmt, + int64_t weightFmt, int64_t dstFmt, int64_t *dst, int64_t *dstDims) { + // Create convolution command buffer. 
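+  // Overview of the sequence below: a single TsmNeInstr (I_NEUR) describes
+  // the whole convolution. Activations, weights, optional bias and the
+  // output are attached with their shapes and formats; the optional per-axis
+  // scales, sparse and psum buffers are configured; pads/unpads, kernel
+  // strides and dilations are set; and the descriptor is dispatched with one
+  // TsmExecute call.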
+ TsmConv *conv = TsmNewConv(); + TsmNeInstr inst = {I_NEUR, + { + 0, + }, + { + 0, + }}; + + // Convert to nhwc format + Data_Shape shape = {(uint16_t)srcActDims[0], (uint16_t)srcActDims[1], + (uint16_t)srcActDims[2], (uint16_t)srcActDims[3]}; + + Data_Shape wshape = {(uint16_t)weightDims[0], (uint16_t)weightDims[1], + (uint16_t)weightDims[2], (uint16_t)weightDims[3]}; + + Data_Shape dstShape = {(uint16_t)dstDims[0], (uint16_t)dstDims[1], + (uint16_t)dstDims[2], (uint16_t)dstDims[3]}; + + conv->AddInput(&inst, (int64_t)srcAct, shape, (Data_Format)srcActFmt); + conv->AddWeight(&inst, (uint64_t)weight, wshape, (Data_Format)weightFmt); + conv->AddBias(&inst, enBias, (uint64_t)bias); + conv->AddOutput(&inst, (uint64_t)dst, dstShape, (Data_Format)dstFmt); + conv->SetOpType(&inst, opType); + conv->SetNegativeAxisScale(&inst, enNegScale, (uint64_t)negScale); + conv->SetPositiveAxisScale(&inst, enPosScale, (uint64_t)posScale); + conv->SetSparse(&inst, enSparse, (uint64_t)sparse); + // FIXME: Should we have psum format instead? + conv->SetPsum(&inst, enPsum, (uint64_t)psum, (Data_Format)dstFmt); + conv->SetPads(&inst, pads[0], pads[1], pads[2], pads[3]); + conv->SetUnPads(&inst, unpads[0], unpads[1], unpads[2], unpads[3]); + conv->SetKernelStrides(&inst, strides[0], strides[1], strides[2], strides[3]); + conv->SetDilations(&inst, dilations[0], dilations[1]); + if (enLeakyRelu) + conv->EnableLeakyRelu(&inst); + else + conv->EnableRelu(&inst); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConv(conv); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/cos.c b/third_party/tsingmicro/crt/lib/Tx81/cos.c new file mode 100644 index 000000000..0ea6f096b --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/cos.c @@ -0,0 +1,32 @@ +//===------------------------ cos.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Cos see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Cos(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Cos(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteTranscendental(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/count.c b/third_party/tsingmicro/crt/lib/Tx81/count.c new file mode 100644 index 000000000..9f37bdaa0 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/count.c @@ -0,0 +1,33 @@ +//===------------------------ count.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Count see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Count(uint64_t *src, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. 
+ TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + cmd->Count(&inst, (uint64_t)src, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/exp.c b/third_party/tsingmicro/crt/lib/Tx81/exp.c new file mode 100644 index 000000000..37b52b1f0 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/exp.c @@ -0,0 +1,32 @@ +//===------------------------ exp.c ---------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Exp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Exp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Exp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteTranscendental(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/explp.c b/third_party/tsingmicro/crt/lib/Tx81/explp.c new file mode 100644 index 000000000..e917ae61e --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/explp.c @@ -0,0 +1,32 @@ +//===------------------------ explp.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Explp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Explp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Explp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteTranscendental(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c new file mode 100644 index 000000000..2fe6b4d6d --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c @@ -0,0 +1,32 @@ +//===------------------------ fp16_bf16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP16_BF16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. 
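+  // NOTE (assumption): unlike __BF16_FP16, this direction narrows the
+  // mantissa (10 -> 7 bits), which is presumably why an explicit RND_MODE is
+  // taken here and forwarded to the convert command.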
+ TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP16_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c new file mode 100644 index 000000000..8493d00ee --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c @@ -0,0 +1,31 @@ +//===------------------------ fp16_fp32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP16_FP32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c new file mode 100644 index 000000000..f6a34594b --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c @@ -0,0 +1,32 @@ +//===------------------------ fp16_int16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP16_INT16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP16_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c new file mode 100644 index 000000000..9cd30e5ef --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c @@ -0,0 +1,32 @@ +//===------------------------ fp16_int32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP16_INT32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. 
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
+
+  cmd->FP16_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c
new file mode 100644
index 000000000..a099d8948
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c
@@ -0,0 +1,32 @@
+//===------------------------ fp16_int8.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_INT8 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
+
+  cmd->FP16_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c
new file mode 100644
index 000000000..f36fbe943
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c
@@ -0,0 +1,31 @@
+//===------------------------ fp16_tf32.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
+
+  cmd->FP16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c
new file mode 100644
index 000000000..c18f95f7c
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c
@@ -0,0 +1,32 @@
+//===------------------------ fp32_bf16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP32_BF16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+ TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c new file mode 100644 index 000000000..0c27fba95 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c @@ -0,0 +1,32 @@ +//===------------------------ fp32_fp16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP32_FP16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c new file mode 100644 index 000000000..5ec2e93c4 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c @@ -0,0 +1,32 @@ +//===------------------------ fp32_int16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP32_INT16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP32_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c new file mode 100644 index 000000000..9d8e0622c --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c @@ -0,0 +1,32 @@ +//===------------------------ fp32_int32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP32_INT32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP32_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. 
+ TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP32_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c new file mode 100644 index 000000000..23d4306f4 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c @@ -0,0 +1,32 @@ +//===------------------------ fp32_int8.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP32_INT8 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP32_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c new file mode 100644 index 000000000..ecf2436ae --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c @@ -0,0 +1,32 @@ +//===------------------------ fp32_tf32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::FP32_TF32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __FP32_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->FP32_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c new file mode 100644 index 000000000..032290c1f --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c @@ -0,0 +1,40 @@ +//===------------------------ gatherscatter.c -----------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::GatherScatter see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t size, + uint32_t src_s0, uint32_t src_i0, uint32_t src_s1, + uint32_t src_i1, uint32_t src_s2, uint32_t src_i2, + uint32_t dst_s0, uint32_t dst_i0, uint32_t dst_s1, + uint32_t dst_i1, uint32_t dst_s2, uint32_t dst_i2) { + // Create command buffer. 
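+  // The twelve stride/iteration scalars are packed below into two
+  // St_StrideIteration descriptors, three {stride, iteration} pairs each:
+  // one describing the source access pattern and one the destination.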
+ TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + St_StrideIteration src_si = {src_s0, src_i0, src_s1, src_i1, src_s2, src_i2}; + St_StrideIteration dst_si = {dst_s0, dst_i0, dst_s1, dst_i1, dst_s2, dst_i2}; + + cmd->GatherScatter(&inst, (uint64_t)src, (uint64_t)dst, size, &src_si, + &dst_si); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/gemm.c b/third_party/tsingmicro/crt/lib/Tx81/gemm.c new file mode 100644 index 000000000..ff0c5114f --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/gemm.c @@ -0,0 +1,59 @@ +//===------------------------ gemm.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TsmGemm, see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +// The arguments list is aligned with TsmConv in Tx81Ops.td +void __Gemm(int64_t *srcA, int64_t *srcB, int64_t *srcBias, int64_t *dst, + int32_t *dims, bool enPsum, int64_t *psum, bool enTransA, + bool enTransB, int64_t batchSizeA, int64_t batchSizeB, + int32_t reluMode, bool enBias, bool enNegScale, int64_t *negScale, + bool enPosScale, int64_t *posScale, int64_t srcFmt, + int64_t dstFmt) { + // Create gemm command buffer. + TsmGemm *gemm = TsmNewGemm(); + TsmNeInstr inst = {I_NEUR, + { + 0, + }, + { + 0, + }}; + + gemm->AddInput(&inst, (uint64_t)srcA, (uint64_t)srcB, (Data_Format)srcFmt); + gemm->ConfigMKN(&inst, (uint32_t)dims[0], (uint32_t)dims[1], + (uint32_t)dims[2]); + gemm->AddOutput(&inst, (uint64_t)dst, (Data_Format)dstFmt); + gemm->SetPsum(&inst, enPsum, (uint64_t)psum, (Data_Format)dstFmt); + gemm->SetTransflag(&inst, (uint8_t)enTransA, (uint8_t)enTransB); + // TODO: + // gemm->SetQuant(); + gemm->ConfigBatch(&inst, (uint32_t)batchSizeA, (uint32_t)batchSizeB); + gemm->AddBias(&inst, enBias, (uint64_t)srcBias); + gemm->SetNegativeAxisScale(&inst, enNegScale, (uint64_t)negScale); + gemm->SetPositiveAxisScale(&inst, enPosScale, (uint64_t)posScale); + switch (reluMode) { + case ENRelu: + gemm->EnableRelu(&inst); + break; + case ENLeakRelu: + gemm->EnableLeakyRelu(&inst); + break; + default: + break; + } + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteGemm(gemm); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/img2col.c b/third_party/tsingmicro/crt/lib/Tx81/img2col.c new file mode 100644 index 000000000..a578e1351 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/img2col.c @@ -0,0 +1,42 @@ +//===------------------------ img2col.c -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Img2col see Tx81Ops.td for detail. 
+// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Img2col(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint64_t src_elem_num, + uint64_t dst_elem_num, uint16_t swr_n, uint16_t swr_h, + uint16_t swr_w, uint16_t swr_c, uint16_t pdr_n, uint16_t pdr_h, + uint16_t pdr_w, uint16_t pdr_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + Data_Shape shape3 = {swr_n, swr_h, swr_w, swr_c}; + Data_Shape shape4 = {pdr_n, pdr_h, pdr_w, pdr_c}; + cmd->Img2col(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + src_elem_num, dst_elem_num, shape3, shape4, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c new file mode 100644 index 000000000..213681aad --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c @@ -0,0 +1,32 @@ +//===------------------------ int16_bf16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT16_BF16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT16_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c new file mode 100644 index 000000000..a23297033 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c @@ -0,0 +1,31 @@ +//===------------------------ int16_fp16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT16_FP16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT16_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c new file mode 100644 index 000000000..9e5ba8e3f --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c @@ -0,0 +1,32 @@ +//===------------------------ int16_fp32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT16_FP32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c new file mode 100644 index 000000000..1a08f227f --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c @@ -0,0 +1,32 @@ +//===------------------------ int16_tf32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT16_TF32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c new file mode 100644 index 000000000..5b9949719 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c @@ -0,0 +1,32 @@ +//===------------------------ int32_bf16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT32_BF16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c
new file mode 100644
index 000000000..e9c9f14ee
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c
@@ -0,0 +1,32 @@
+//===------------------------ int32_fp16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT32_FP16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
+
+  cmd->INT32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c
new file mode 100644
index 000000000..6fb7778f4
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c
@@ -0,0 +1,32 @@
+//===------------------------ int32_fp32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT32_FP32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT32_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
+
+  cmd->INT32_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c
new file mode 100644
index 000000000..6c65087da
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c
@@ -0,0 +1,32 @@
+//===------------------------ int32_tf32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT32_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT32_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
+
+  cmd->INT32_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+ TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c new file mode 100644 index 000000000..bce9dfa27 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c @@ -0,0 +1,32 @@ +//===------------------------ int8_bf16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT8_BF16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT8_BF16(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT8_BF16(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c new file mode 100644 index 000000000..94061ebc4 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c @@ -0,0 +1,32 @@ +//===------------------------ int8_fp16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT8_FP16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT8_FP16(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT8_FP16(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c new file mode 100644 index 000000000..b2be9df3c --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c @@ -0,0 +1,32 @@ +//===------------------------ int8_fp32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT8_FP32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT8_FP32(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT8_FP32(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c new file mode 100644 index 000000000..3fb5fcfae --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c @@ -0,0 +1,32 @@ +//===------------------------ int8_tf32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::INT8_TF32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __INT8_TF32(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->INT8_TF32(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c new file mode 100644 index 000000000..c1cdb81f0 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c @@ -0,0 +1,34 @@ +//===------------------------ leakyrelu.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Leakyrelu see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Leakyrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Leakyrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/ln.c b/third_party/tsingmicro/crt/lib/Tx81/ln.c new file mode 100644 index 000000000..01776e243 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/ln.c @@ -0,0 +1,32 @@ +//===------------------------ ln.c ----------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Ln see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Ln(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Ln(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteTranscendental(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/log2.c b/third_party/tsingmicro/crt/lib/Tx81/log2.c new file mode 100644 index 000000000..8dcdfc3e8 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/log2.c @@ -0,0 +1,32 @@ +//===------------------------ log2.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Log2 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Log2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Log2(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteTranscendental(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/logic.c b/third_party/tsingmicro/crt/lib/Tx81/logic.c new file mode 100644 index 000000000..62e36b694 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/logic.c @@ -0,0 +1,78 @@ +//===------------------------ logic.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::LogicOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __AndVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmLogic *cmd = TsmNewLogic(); + TsmLogicInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->AndVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteLogic(cmd); +} + +void __OrVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmLogic *cmd = TsmNewLogic(); + TsmLogicInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->OrVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteLogic(cmd); +} + +void __XorVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmLogic *cmd = TsmNewLogic(); + TsmLogicInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->XorVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteLogic(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut16.c b/third_party/tsingmicro/crt/lib/Tx81/lut16.c new file mode 100644 index 000000000..d4e8bea10 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/lut16.c @@ -0,0 +1,35 @@ +//===------------------------ lut16.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Lut16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Lut16(uint64_t *src, uint64_t *dst, uint64_t *lut16, + uint32_t src_elem_count, uint32_t lut_elem_count) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + cmd->Lut16(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut16, + src_elem_count, lut_elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut32.c b/third_party/tsingmicro/crt/lib/Tx81/lut32.c new file mode 100644 index 000000000..16f6df9ba --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/lut32.c @@ -0,0 +1,35 @@ +//===------------------------ lut32.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Lut32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Lut32(uint64_t *src, uint64_t *dst, uint64_t *lut32, + uint32_t src_elem_count, uint32_t lut_elem_count) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + cmd->Lut32(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut32, + src_elem_count, lut_elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/mask_move.c b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c new file mode 100644 index 000000000..05d7989e7 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c @@ -0,0 +1,31 @@ +//===------------------------ mask_move.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::MaskMoveOp see Tx81Ops.td for detail. 
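+// Presumed behaviour (not stated here): elements of `src` are moved into
+// `target` wherever the corresponding `mask` element is set; the
+// TsmMaskDataMove command in the Tx81 runtime defines the authoritative
+// semantics.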
+// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __MaskMove(uint64_t *src, uint64_t *target, uint32_t elem_count, + uint64_t *mask, int32_t fmt) { + TsmMaskDataMove *move = TsmNewMaskDataMove(); + TsmMaskDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + move->MaskMove(&inst, (uint64_t)src, (uint64_t)mask, (uint64_t)target, + elem_count, (Data_Format)fmt); + + TsmExecute(&inst); + + TsmDeleteMaskDataMove(move); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/memset.c b/third_party/tsingmicro/crt/lib/Tx81/memset.c new file mode 100644 index 000000000..f21651138 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/memset.c @@ -0,0 +1,47 @@ +//===------------------------ memset.c ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Memset see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Memset(uint64_t *dst, uint32_t value, uint32_t elem_count, uint32_t s0, + uint32_t i0, uint32_t s1, uint32_t i1, uint32_t s2, uint32_t i2, + uint16_t fmt) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + // TODO: Use real stride and iteration, now accumulate all data to elem_count + int stride0 = 0; + int stride1 = 0; + int stride2 = 0; + + int iteration0 = 1; + int iteration1 = 1; + int iteration2 = 1; + + elem_count *= i0 * i1 * i2; + + St_StrideIteration si = {stride0, iteration0, stride1, + iteration1, stride1, iteration2}; + cmd->Memset(&inst, (uint64_t)dst, value, elem_count, &si, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/mirror.c b/third_party/tsingmicro/crt/lib/Tx81/mirror.c new file mode 100644 index 000000000..543fec808 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/mirror.c @@ -0,0 +1,37 @@ +//===------------------------ mirror.c ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Mirror see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Mirror(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Mirror(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c new file mode 100644 index 000000000..ed916cb60 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c @@ -0,0 +1,37 @@ +//===------------------------ nchw2nhwc.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Nchw2nhwc see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Nchw2nhwc(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Nchw2nhwc(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c new file mode 100644 index 000000000..932b71599 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c @@ -0,0 +1,37 @@ +//===------------------------ nhwc2nchw.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Nhwc2nchw see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Nhwc2nchw(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Nhwc2nchw(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/pad.c b/third_party/tsingmicro/crt/lib/Tx81/pad.c new file mode 100644 index 000000000..3ccde1221 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/pad.c @@ -0,0 +1,39 @@ +//===------------------------ pad.c ---------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Pad see Tx81Ops.td for detail. 
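+// Presumed behaviour: pad_n/pad_h/pad_w/pad_c give the per-dimension padding
+// applied when expanding the source NHWC shape into the destination shape;
+// the exact padding rule is defined by the Tx81 Pad command.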
+// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Pad(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t pad_n, uint16_t pad_h, + uint16_t pad_w, uint16_t pad_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + Data_Shape shape3 = {pad_n, pad_h, pad_w, pad_c}; + cmd->Pad(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, shape3, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/pow2.c b/third_party/tsingmicro/crt/lib/Tx81/pow2.c new file mode 100644 index 000000000..9ed3fa0ae --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/pow2.c @@ -0,0 +1,32 @@ +//===------------------------ pow2.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Pow2 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Pow2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Pow2(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteTranscendental(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/randgen.c b/third_party/tsingmicro/crt/lib/Tx81/randgen.c new file mode 100644 index 000000000..85382f17d --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/randgen.c @@ -0,0 +1,35 @@ +//===------------------------ randgen.c -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::RandGen see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __RandGen(uint64_t *src0, uint64_t *src1, uint64_t *dst0, uint64_t *dst1, + uint64_t *dst2, uint32_t src_elem_num, uint16_t fmt) { + // Create command buffer. + TsmPeripheral *cmd = TsmNewPeripheral(); + TsmPeripheralInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + ; + + cmd->RandGen(&inst, *src0, *src1, *dst0, *dst1, *dst2, src_elem_num, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeletePeripheral(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rdma.c b/third_party/tsingmicro/crt/lib/Tx81/rdma.c new file mode 100644 index 000000000..f000df052 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rdma.c @@ -0,0 +1,41 @@ +//===------------------------ rdma.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Rdma, see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +// The arguments list is aligned with TsmConv in Tx81Ops.td +void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w, + int shape_c, int stride_n, int stride_h, int stride_w, + uint32_t fmt) { + // Dynamic shape, kernel implementation will cause shape equal to 0 + if (shape_n == 0 || shape_h == 0 || shape_w == 0 || shape_c == 0) + return; + + // Create gemm command buffer. + TsmRdma *rdma = TsmNewRdma(); + TsmRdmaInstr inst = {I_RDMA, + { + 0, + }, + { + 0, + }}; + + rdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt); + rdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h, + shape_h, stride_n, shape_n); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRdma(rdma); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/reduce.c b/third_party/tsingmicro/crt/lib/Tx81/reduce.c new file mode 100644 index 000000000..ebaeceb8e --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/reduce.c @@ -0,0 +1,103 @@ +//===---------------------- reduce.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TsmReduce, see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +// The arguments list is aligned with TsmConv in Tx81Ops.td +void __ReduceSum(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n, + uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) { + // Create reduce command buffer. + TsmReduce *cmd = TsmNewReduce(); + TsmReduceInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + // TODO + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + cmd->ReduceSum(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1, + (Data_Format)fmt); + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteReduce(cmd); +} + +void __ReduceAvg(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n, + uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) { + // Create reduce command buffer. + TsmReduce *cmd = TsmNewReduce(); + TsmReduceInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + // TODO + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + cmd->ReduceAvg(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1, + (Data_Format)fmt); + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteReduce(cmd); +} + +void __ReduceMax(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n, + uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) { + // Create reduce command buffer. 
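+  // As with __ReduceSum/__ReduceAvg above, `dim` selects the axis of the NHWC
+  // source described by (src_n, src_h, src_w, src_c); e.g. an illustrative
+  // __ReduceMax(src, dst, dim, 1, 8, 8, 16, fmt) reduces a 1x8x8x16 tensor.
+  // Which index `dim` maps to which axis is defined by the TsmReduce API in
+  // tx81.h.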
+ TsmReduce *cmd = TsmNewReduce(); + TsmReduceInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + // TODO + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + cmd->ReduceMax(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1, + (Data_Format)fmt); + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteReduce(cmd); +} + +void __ReduceMin(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n, + uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) { + // Create reduce command buffer. + TsmReduce *cmd = TsmNewReduce(); + TsmReduceInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + // TODO + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + cmd->ReduceMin(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1, + (Data_Format)fmt); + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteReduce(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/relation.c b/third_party/tsingmicro/crt/lib/Tx81/relation.c new file mode 100644 index 000000000..10069deb6 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/relation.c @@ -0,0 +1,144 @@ +//===------------------------ relation.c-----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::RelationOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BoolEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolUnEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolUnEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolGreaterEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolGreaterEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolGreaterVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolGreaterVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteRelation(cmd); +} + +void __BoolLessEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolLessEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolLessThenVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolLessThenVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/relu.c b/third_party/tsingmicro/crt/lib/Tx81/relu.c new file mode 100644 index 000000000..ccaf77ec7 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/relu.c @@ -0,0 +1,32 @@ +//===------------------------ relu.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Relu see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Relu(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Relu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate180.c b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c new file mode 100644 index 000000000..1b068458b --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c @@ -0,0 +1,37 @@ +//===------------------------ rotate180.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Rotate180 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Rotate180(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Rotate180(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate270.c b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c new file mode 100644 index 000000000..15d84f28f --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c @@ -0,0 +1,37 @@ +//===------------------------ rotate270.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Rotate270 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Rotate270(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Rotate270(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate90.c b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c new file mode 100644 index 000000000..15c87d429 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c @@ -0,0 +1,37 @@ +//===------------------------ rotate90.c ----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Rotate90 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Rotate90(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Rotate90(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c b/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c new file mode 100644 index 000000000..d0468966d --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c @@ -0,0 +1,34 @@ +//===------------------------ rsqrt.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::RsqrtVVOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __RsqrtVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. 
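+  // Elementwise reciprocal square root over `elem_count` values; an
+  // illustrative call with a placeholder format code:
+  //   __RsqrtVV(src, dst, 1024, fmt);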
+ TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->RsqrtVV(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/satrelu.c b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c new file mode 100644 index 000000000..e9d67dfee --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c @@ -0,0 +1,34 @@ +//===------------------------ satrelu.c -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Satrelu see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Satrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Satrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c new file mode 100644 index 000000000..d92a42e22 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c @@ -0,0 +1,34 @@ +//===------------------------ sigmoid.c -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Sigmoid see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Sigmoid(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Sigmoid(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/sin.c b/third_party/tsingmicro/crt/lib/Tx81/sin.c new file mode 100644 index 000000000..7bb37c6d2 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/sin.c @@ -0,0 +1,32 @@ +//===------------------------ Sin.c ---------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Sin see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Sin(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. 
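+  // Elementwise sine over `elem_count` values, issued to the transcendental
+  // unit like __Pow2 above; illustrative call: __Sin(src, dst, 1024, fmt),
+  // with `fmt` a valid Data_Format code.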
+ TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Sin(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteTranscendental(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/softplus.c b/third_party/tsingmicro/crt/lib/Tx81/softplus.c new file mode 100644 index 000000000..d384cf701 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/softplus.c @@ -0,0 +1,35 @@ +//===------------------------ softplus.cpp +//------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Softplus see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Softplus(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Softplus(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/sqrt.c b/third_party/tsingmicro/crt/lib/Tx81/sqrt.c new file mode 100644 index 000000000..5701513cb --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/sqrt.c @@ -0,0 +1,33 @@ +//===------------------------ sqrt.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::SqrtVVOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __SqrtVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->SqrtVV(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tanh.c b/third_party/tsingmicro/crt/lib/Tx81/tanh.c new file mode 100644 index 000000000..e91cf229d --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tanh.c @@ -0,0 +1,32 @@ +//===------------------------ tanh.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Tanh see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Tanh(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. 
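+  // Like __Relu/__Sigmoid/__Softplus, an elementwise activation over
+  // `elem_count` values; illustrative call: __Tanh(src, dst, 1024, fmt).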
+ TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->Tanh(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c new file mode 100644 index 000000000..d141faa44 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c @@ -0,0 +1,37 @@ +//===------------------------ tensornorm.c --------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TensorNorm see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TensorNorm(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->TensorNom(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c new file mode 100644 index 000000000..6b9e2b5d3 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c @@ -0,0 +1,32 @@ +//===------------------------ tf32_bf16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_BF16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->TF32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c new file mode 100644 index 000000000..9e78a4bda --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c @@ -0,0 +1,31 @@ +//===------------------------ tf32_fp16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_FP16 see Tx81Ops.td for detail. 
+// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->TF32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c new file mode 100644 index 000000000..92550fb2a --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c @@ -0,0 +1,31 @@ +//===------------------------ tf32_fp32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_FP32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->TF32_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c new file mode 100644 index 000000000..b6e2951c2 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c @@ -0,0 +1,32 @@ +//===------------------------ tf32_int16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_INT16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->TF32_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c new file mode 100644 index 000000000..de1ae6725 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c @@ -0,0 +1,32 @@ +//===------------------------ tf32_int32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_INT32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. 
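+  // Same pattern as __TF32_INT16 above, converting `elem_count` TF32 values
+  // to int32 with the caller-supplied RND_MODE; e.g. an illustrative
+  // __TF32_INT32(src, dst, 1024, round) where `round` is any valid RND_MODE.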
+ TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->TF32_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c new file mode 100644 index 000000000..4c60fbf98 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c @@ -0,0 +1,32 @@ +//===------------------------ tf32_int8.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_INT8 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->TF32_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/transpose.c b/third_party/tsingmicro/crt/lib/Tx81/transpose.c new file mode 100644 index 000000000..54e2ee584 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/transpose.c @@ -0,0 +1,37 @@ +//===------------------------ transpose.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Transpose see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Transpose(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Transpose(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/wdma.c b/third_party/tsingmicro/crt/lib/Tx81/wdma.c new file mode 100644 index 000000000..93bfe6e89 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/wdma.c @@ -0,0 +1,43 @@ +//===------------------------ wdma.c --------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Wdma, see Tx81Ops.td for detail. 
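+//
+// Write-direction DMA counterpart of __Rdma: the (stride, shape) pairs are
+// handed to ConfigStrideIteration innermost-first (C, then W, H, N). An
+// illustrative call for a dense 1x8x8x16 NHWC block, assuming strides are
+// counted in elements (the real unit is defined by the TsmWdma API):
+//
+//   __Wdma(src, dst, 1, 8, 8, 16,
+//          /*stride_n=*/1024, /*stride_h=*/128, /*stride_w=*/16, fmt);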
+// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +// The arguments list is aligned with TsmConv in Tx81Ops.td +void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w, + int shape_c, int stride_n, int stride_h, int stride_w, + uint32_t fmt) { + + // Dynamic shape, kernel implementation will cause shape equal to 0 + if (shape_n == 0 || shape_h == 0 || shape_w == 0 || shape_c == 0) + return; + + // Create gemm command buffer. + TsmWdma *wdma = TsmNewWdma(); + TsmWdmaInstr inst = {I_WDMA, + { + 0, + }, + { + 0, + }}; + + wdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt); + + wdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h, + shape_h, stride_n, shape_n); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteWdma(wdma); +} diff --git a/third_party/tsingmicro/examples/bare_matmul.py b/third_party/tsingmicro/examples/bare_matmul.py new file mode 100644 index 000000000..84b9c9a87 --- /dev/null +++ b/third_party/tsingmicro/examples/bare_matmul.py @@ -0,0 +1,52 @@ +# this is a benchmark which multiplies square matrices with maximum block size +# to check the performance of tl.dot operation + +import torch +import triton +import triton.language as tl +import benchmark + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def bare_matmul(X, Y, Z, M, N, K, BLOCK_SIZE: tl.constexpr): + pid_x = tl.program_id(0) # block row id + pid_y = tl.program_id(1) # block column id + + offs_x = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_y = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(X + offs_x[:, None] * K + offs_y[None, :]) + y = tl.load(Y + offs_x[:, None] * N + offs_y[None, :]) + + z = tl.dot(x, y) + + tl.store(Z + offs_x[:, None] * N + offs_y[None, :], z) + + +# @benchmark.measure() +def bench_matmul(N, provider): + device = 'cpu' + dtype = torch.float32 + a = torch.randint(0, 10, (N, N), dtype=torch.int32).to(dtype) + b = torch.randint(0, 10, (N, N), dtype=torch.int32).to(dtype) + # a = torch.randn((N, N), device=device, dtype=dtype) + # b = torch.randn((N, N), device=device, dtype=dtype) + c = torch.empty((N, N), device=device, dtype=dtype) + if provider == 'torch' or provider == 'test': + c_ref = torch.matmul(a, b) + # print("====cref:",c_ref) + if provider == 'triton' or provider == 'test': + bare_matmul[(1, )](a, b, c, N, N, N, N) + if provider == 'test': + torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=0) + print("expected", c_ref) + print("actual", c) + print("======test====") + + +if __name__ == "__main__": + # benchmark.select_cpu_backend() + for provider in ['test']: + bench_matmul(16, provider) diff --git a/third_party/tsingmicro/examples/benchmark.py b/third_party/tsingmicro/examples/benchmark.py new file mode 100644 index 000000000..4a2284dd2 --- /dev/null +++ b/third_party/tsingmicro/examples/benchmark.py @@ -0,0 +1,65 @@ +import time +import numpy as np +from functools import wraps +import triton + +# Unfortunately, we can't use triton.testing.perf_report and triton.testing.do_bench for CPU backend because +# they are very specific to cuda + + +def measure(repeats=20, percentiles=(), timers={'Wall': time.perf_counter, 'CPU': time.process_time}): + """ + Decorator to benchmark a function. + + Parameters: + - repeats (int): The number of times the function should be executed for each set of parameters. 
+ - percentiles (tuple): The percentiles to compute on the execution times (e.g., (50, 90, 99)). + - timers (dict): A dictionary where keys are timer names (e.g., 'Wall', 'CPU') and values are timer functions + that measure elapsed time. By default: + * 'Wall': Uses time.perf_counter for high-resolution wall-clock time. + * 'CPU': Uses time.process_time for CPU time spent by the process. + + Returns: + - A decorated function that prints: + * Average execution time. + * Standard deviation time. + * Minimum and maximum times. + * Computed percentiles for each timer. + """ + + def decorator(func): + + @wraps(func) + def wrapper(*args, **kwargs): + print(f"{func.__name__}{args} {kwargs}, {repeats} times, all results in seconds") + times = {} + for t, _ in timers.items(): + times[t] = [] + + for _ in range(repeats): + starts = {} + for t, f in timers.items(): + starts[t] = f() + + result = func(*args, **kwargs) + + for t, f in timers.items(): + times[t].append(f() - starts[t]) + + for t, _ in timers.items(): + average_time = np.mean(times[t]) + min_time = np.min(times[t]) + max_time = np.max(times[t]) + computed_percentiles = np.percentile(times[t], percentiles) + std_dev_time = np.std(times[t]) + + print(f"{t}: Avg={average_time:.6f}, min={min_time:.6f}, std={std_dev_time:.6f},", end=" ") + for p, value in zip(percentiles, computed_percentiles): + print(f"{p}pp={value:.6f},", end=" ") + print(f"max={max_time:.6f}") + + return result + + return wrapper + + return decorator diff --git a/third_party/tsingmicro/examples/test_vec_add.py b/third_party/tsingmicro/examples/test_vec_add.py new file mode 100644 index 000000000..b75f5aa42 --- /dev/null +++ b/third_party/tsingmicro/examples/test_vec_add.py @@ -0,0 +1,90 @@ +import torch + +import triton +import triton.language as tl +import benchmark + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + # We need to preallocate the output. + output = torch.empty_like(x) + # assert x.is_cuda and y.is_cuda and output.is_cuda + n_elements = output.numel() + # The SPMD launch grid denotes the number of kernel instances that run in parallel. + # It is analogous to CUDA launch grids. 
It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. + # In this case, we use a 1D grid where the size is the number of blocks: + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + # NOTE: + # - Each torch.tensor object is implicitly converted into a pointer to its first element. + # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. + # - Don't forget to pass meta-parameters as keywords arguments. + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously at this point. + return output + + +def test(device): + # torch.manual_seed(0) + size = 1024 + x = torch.rand(size, device="cpu") + y = torch.rand(size, device="cpu") + output_torch = x + y + x = x.to(device) + y = y.to(device) + output_triton = add(x, y) + # TODO: need to check some conditions otherwise the code below does not make any difference for the test + print("expected", output_torch) + output_triton = output_triton.to("cpu") + print("actual", output_triton) + print(f"The maximum difference between torch and triton is " + f"{torch.max(torch.abs(output_torch - output_triton))}") + + +@benchmark.measure() +def bench_vecadd(size, provider): + a = torch.rand(size, device='cpu', dtype=torch.float32) + b = torch.rand(size, device='cpu', dtype=torch.float32) + if provider == 'torch': + a + b + if provider == 'triton': + a = a.to(DEVICE) + b = b.to(DEVICE) + add(a, b) + + +if __name__ == "__main__": + # test(DEVICE) + for X in [2**i for i in range(8, 25, 1)]: + for provider in ['torch', 'triton']: + bench_vecadd(X, provider) diff --git a/third_party/tsingmicro/include/CMakeLists.txt b/third_party/tsingmicro/include/CMakeLists.txt new file mode 100644 index 000000000..76c90b65b --- /dev/null +++ b/third_party/tsingmicro/include/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(triton-shared) +add_subdirectory(magic-kernel) +add_subdirectory(tsingmicro-tx81) +# The following 2 dialects are currently unused. +#add_subdirectory(magic-kernel-func) +#add_subdirectory(magic-kernel-instr) diff --git a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp new file mode 100644 index 000000000..87e47027f --- /dev/null +++ b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp @@ -0,0 +1,189 @@ +//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements basic functions to manipulate structured MLIR types at +// runtime. Entities in this file are meant to be retargetable, including on +// targets without a C++ runtime, and must be kept C compatible. 
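+//
+// These are the symbols that lowered MLIR resolves against at run time: the
+// print* functions back the vector.print lowering and memrefCopy backs
+// memref.copy. For example, printing the value 42 from generated code ends up
+// as a call sequence like printI64(42); printNewline();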
+//
+//===----------------------------------------------------------------------===//
+
+#include "CRunnerUtils.h"
+#include "Msan.h"
+
+#ifndef _WIN32
+#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) ||    \
+    defined(__DragonFly__)
+#include <cstdlib>
+#else
+#include <alloca.h>
+#endif
+#include <sys/time.h>
+#else
+#include "malloc.h"
+#endif // _WIN32
+
+#include <algorithm>
+#include <cinttypes>
+#include <cstdio>
+#include <cstdlib>
+#include <numeric>
+#include <random>
+
+#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
+
+namespace {
+template <typename V> void stdSort(uint64_t n, V *p) { std::sort(p, p + n); }
+
+} // namespace
+
+// Small runtime support "lib" for vector.print lowering.
+// By providing elementary printing methods only, this
+// library can remain fully unaware of low-level implementation
+// details of our vectors. Also useful for direct LLVM IR output.
+extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
+extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
+extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
+extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
+extern "C" void printString(char const *s) { fputs(s, stdout); }
+extern "C" void printOpen() { fputs("( ", stdout); }
+extern "C" void printClose() { fputs(" )", stdout); }
+extern "C" void printComma() { fputs(", ", stdout); }
+extern "C" void printNewline() { fputc('\n', stdout); }
+
+extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *srcArg,
+                           UnrankedMemRefType<char> *dstArg) {
+  DynamicMemRefType<char> src(*srcArg);
+  DynamicMemRefType<char> dst(*dstArg);
+
+  int64_t rank = src.rank;
+  MLIR_MSAN_MEMORY_IS_INITIALIZED(src.sizes, rank * sizeof(int64_t));
+
+  // Handle empty shapes -> nothing to copy.
+  for (int rankp = 0; rankp < rank; ++rankp)
+    if (src.sizes[rankp] == 0)
+      return;
+
+  char *srcPtr = src.data + src.offset * elemSize;
+  char *dstPtr = dst.data + dst.offset * elemSize;
+
+  if (rank == 0) {
+    memcpy(dstPtr, srcPtr, elemSize);
+    return;
+  }
+
+  int64_t *indices = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+  int64_t *srcStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+  int64_t *dstStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+
+  // Initialize index and scale strides.
+  for (int rankp = 0; rankp < rank; ++rankp) {
+    indices[rankp] = 0;
+    srcStrides[rankp] = src.strides[rankp] * elemSize;
+    dstStrides[rankp] = dst.strides[rankp] * elemSize;
+  }
+
+  int64_t readIndex = 0, writeIndex = 0;
+  for (;;) {
+    // Copy over the element, byte by byte.
+    memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize);
+    // Advance index and read position.
+    for (int64_t axis = rank - 1; axis >= 0; --axis) {
+      // Advance at current axis.
+      auto newIndex = ++indices[axis];
+      readIndex += srcStrides[axis];
+      writeIndex += dstStrides[axis];
+      // If this is a valid index, we have our next index, so continue copying.
+      if (src.sizes[axis] != newIndex)
+        break;
+      // We reached the end of this axis. If this is axis 0, we are done.
+      if (axis == 0)
+        return;
+      // Else, reset to 0 and undo the advancement of the linear index that
+      // this axis had. Then continue with the axis one outer.
+      indices[axis] = 0;
+      readIndex -= src.sizes[axis] * srcStrides[axis];
+      writeIndex -= dst.sizes[axis] * dstStrides[axis];
+    }
+  }
+}
+
+/// Prints GFLOPS rating.
+extern "C" void printFlops(double flops) {
+  fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9);
+}
+
+/// Returns the number of seconds since Epoch 1970-01-01 00:00:00 +0000 (UTC).
+extern "C" double rtclock() {
+#ifndef _WIN32
+  struct timeval tp;
+  int stat = gettimeofday(&tp, nullptr);
+  if (stat != 0)
+    fprintf(stderr, "Error returning time from gettimeofday: %d\n", stat);
+  return (tp.tv_sec + tp.tv_usec * 1.0e-6);
+#else
+  fprintf(stderr, "Timing utility not implemented on Windows\n");
+  return 0.0;
+#endif // _WIN32
+}
+
+extern "C" void *mlirAlloc(uint64_t size) { return malloc(size); }
+
+extern "C" void *mlirAlignedAlloc(uint64_t alignment, uint64_t size) {
+#ifdef _WIN32
+  return _aligned_malloc(size, alignment);
+#elif defined(__APPLE__)
+  // aligned_alloc was added in MacOS 10.15. Fall back to posix_memalign to also
+  // support older versions.
+  void *result = nullptr;
+  (void)::posix_memalign(&result, alignment, size);
+  return result;
+#else
+  return aligned_alloc(alignment, size);
+#endif
+}
+
+extern "C" void mlirFree(void *ptr) { free(ptr); }
+
+extern "C" void mlirAlignedFree(void *ptr) {
+#ifdef _WIN32
+  _aligned_free(ptr);
+#else
+  free(ptr);
+#endif
+}
+
+extern "C" void *rtsrand(uint64_t s) {
+  // Standard mersenne_twister_engine seeded with s.
+  return new std::mt19937(s);
+}
+
+extern "C" uint64_t rtrand(void *g, uint64_t m) {
+  std::mt19937 *generator = static_cast<std::mt19937 *>(g);
+  std::uniform_int_distribution<uint64_t> distrib(0, m);
+  return distrib(*generator);
+}
+
+extern "C" void rtdrand(void *g) {
+  std::mt19937 *generator = static_cast<std::mt19937 *>(g);
+  delete generator;
+}
+
+#define IMPL_STDSORT(VNAME, V)                                                 \
+  extern "C" void _mlir_ciface_stdSort##VNAME(uint64_t n,                      \
+                                              StridedMemRefType<V, 1> *vref) { \
+    assert(vref);                                                              \
+    assert(vref->strides[0] == 1);                                             \
+    V *values = vref->data + vref->offset;                                     \
+    stdSort(n, values);                                                        \
+  }
+IMPL_STDSORT(I64, int64_t)
+IMPL_STDSORT(F64, double)
+IMPL_STDSORT(F32, float)
+#undef IMPL_STDSORT
+
+#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
diff --git a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h
new file mode 100644
index 000000000..1e55ca923
--- /dev/null
+++ b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h
@@ -0,0 +1,482 @@
+//===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares basic classes and functions to manipulate structured MLIR
+// types at runtime. Entities in this file must be compliant with C++11 and be
+// retargetable, including on targets without a C++ runtime.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
+#define MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
+
+#ifdef _WIN32
+#ifndef MLIR_CRUNNERUTILS_EXPORT
+#ifdef mlir_c_runner_utils_EXPORTS
+// We are building this library
+#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllexport)
+#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
+#else
+// We are using this library
+#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllimport)
+#endif // mlir_c_runner_utils_EXPORTS
+#endif // MLIR_CRUNNERUTILS_EXPORT
+#else // _WIN32
+// Non-windows: use visibility attributes.
+#define MLIR_CRUNNERUTILS_EXPORT __attribute__((visibility("default")))
+#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
+#endif // _WIN32
+
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <initializer_list>
+#include <vector>
+
+//===----------------------------------------------------------------------===//
+// Codegen-compatible structures for Vector type.
+//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace detail {
+
+constexpr bool isPowerOf2(int n) { return (!(n & (n - 1))); }
+
+constexpr unsigned nextPowerOf2(int n) {
+  return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2)));
+}
+
+template <typename T, int Dim, bool IsPowerOf2> struct Vector1D;
+
+template <typename T, int Dim>
+struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
+  Vector1D() {
+    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]),
+                  "size error");
+  }
+  inline T &operator[](unsigned i) { return vector[i]; }
+  inline const T &operator[](unsigned i) const { return vector[i]; }
+
+private:
+  T vector[Dim];
+};
+
+// 1-D vector, padded to the next power of 2 allocation.
+// Specialization occurs to avoid zero size arrays (which fail in -Werror).
+template <typename T, int Dim>
+struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
+  Vector1D() {
+    static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
+    static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
+                  "size error");
+  }
+  inline T &operator[](unsigned i) { return vector[i]; }
+  inline const T &operator[](unsigned i) const { return vector[i]; }
+
+private:
+  T vector[Dim];
+  char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
+};
+} // namespace detail
+} // namespace mlir
+
+// N-D vectors recurse down to 1-D.
+template <typename T, int Dim, int... Dims> struct Vector {
+  inline Vector<T, Dims...> &operator[](unsigned i) { return vector[i]; }
+  inline const Vector<T, Dims...> &operator[](unsigned i) const {
+    return vector[i];
+  }
+
+private:
+  Vector<T, Dims...> vector[Dim];
+};
+
+// 1-D vectors in LLVM are automatically padded to the next power of 2.
+// We insert explicit padding in to account for this.
+template <typename T, int Dim>
+struct Vector<T, Dim>
+    : public mlir::detail::Vector1D<T, Dim,
+                                    mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
+};
+
+template <typename T, int Dim> using Vector1D = Vector<T, Dim>;
+template <typename T, int Dim1, int Dim2>
+using Vector2D = Vector<T, Dim1, Dim2>;
+template <typename T, int Dim1, int Dim2, int Dim3>
+using Vector3D = Vector<T, Dim1, Dim2, Dim3>;
+template <typename T, int Dim1, int Dim2, int Dim3, int Dim4>
+using Vector4D = Vector<T, Dim1, Dim2, Dim3, Dim4>;
+
+template <int N> void dropFront(int64_t arr[N], int64_t *res) {
+  for (unsigned i = 1; i < N; ++i)
+    *(res + i - 1) = arr[i];
+}
+
+//===----------------------------------------------------------------------===//
+// Codegen-compatible structures for StridedMemRef type.
+//===----------------------------------------------------------------------===//
+template <typename T, int Rank> class StridedMemrefIterator;
+
+/// StridedMemRef descriptor type with static rank.
+template <typename T, int N> struct StridedMemRefType {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  int64_t sizes[N];
+  int64_t strides[N];
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range &&indices) {
+    assert(indices.size() == N &&
+           "indices should match rank in memref subscript");
+    int64_t curOffset = offset;
+    for (int dim = N - 1; dim >= 0; --dim) {
+      int64_t currentIndex = *(indices.begin() + dim);
+      assert(currentIndex < sizes[dim] && "Index overflow");
+      curOffset += currentIndex * strides[dim];
+    }
+    return data[curOffset];
+  }
+
+  StridedMemrefIterator<T, N> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, N> end() { return {*this, -1}; }
+
+  // This operator[] is extremely slow and only for sugaring purposes.
+  StridedMemRefType<T, N - 1> operator[](int64_t idx) {
+    StridedMemRefType<T, N - 1> res;
+    res.basePtr = basePtr;
+    res.data = data;
+    res.offset = offset + idx * strides[0];
+    dropFront<N>(sizes, res.sizes);
+    dropFront<N>(strides, res.strides);
+    return res;
+  }
+};
+
+/// StridedMemRef descriptor type specialized for rank 1.
+template <typename T> struct StridedMemRefType<T, 1> {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  int64_t sizes[1];
+  int64_t strides[1];
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert(indices.size() == 1 &&
+           "indices should match rank in memref subscript");
+    return (*this)[*indices.begin()];
+  }
+
+  StridedMemrefIterator<T, 1> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
+
+  T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
+};
+
+/// StridedMemRef descriptor type specialized for rank 0.
+template <typename T> struct StridedMemRefType<T, 0> {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert((indices.size() == 0) &&
+           "Expect empty indices for 0-rank memref subscript");
+    return data[offset];
+  }
+
+  StridedMemrefIterator<T, 0> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, 0> end() { return {*this, offset + 1}; }
+};
+
+/// Iterate over all elements in a strided memref.
+template <typename T, int Rank> class StridedMemrefIterator {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
+                        int64_t offset = 0)
+      : offset(offset), descriptor(&descriptor) {}
+  StridedMemrefIterator &operator++() {
+    int dim = Rank - 1;
+    while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor->strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+    if (dim < 0) {
+      offset = -1;
+      return *this;
+    }
+    ++indices[dim];
+    offset += descriptor->strides[dim];
+    return *this;
+  }
+
+  reference operator*() { return descriptor->data[offset]; }
+  pointer operator->() { return &descriptor->data[offset]; }
+
+  const std::array<int64_t, Rank> &getIndices() { return indices; }
+
+  bool operator==(const StridedMemrefIterator &other) const {
+    return other.offset == offset && other.descriptor == descriptor;
+  }
+
+  bool operator!=(const StridedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor.
+  int64_t offset = 0;
+
+  /// Array of indices in the multi-dimensional memref.
+  std::array<int64_t, Rank> indices = {};
+
+  /// Descriptor for the strided memref.
+  StridedMemRefType<T, Rank> *descriptor;
+};
+
+/// Iterate over all elements in a 0-ranked strided memref.
+template <typename T> class StridedMemrefIterator<T, 0> {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor,
+                        int64_t offset = 0)
+      : elt(descriptor.data + offset) {}
+
+  StridedMemrefIterator &operator++() {
+    ++elt;
+    return *this;
+  }
+
+  reference operator*() { return *elt; }
+  pointer operator->() { return elt; }
+
+  // There are no indices for a 0-ranked memref, but this API is provided for
+  // consistency with the general case.
+  const std::array<int64_t, 0> &getIndices() {
+    // Since this is a 0-array of indices we can keep a single global const
+    // copy.
+    static const std::array<int64_t, 0> indices = {};
+    return indices;
+  }
+
+  bool operator==(const StridedMemrefIterator &other) const {
+    return other.elt == elt;
+  }
+
+  bool operator!=(const StridedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Pointer to the single element in the zero-ranked memref.
+  T *elt;
+};
+
+//===----------------------------------------------------------------------===//
+// Codegen-compatible structure for UnrankedMemRef type.
+//===----------------------------------------------------------------------===//
+// Unranked MemRef
+template <typename T> struct UnrankedMemRefType {
+  int64_t rank;
+  void *descriptor;
+};
+
+//===----------------------------------------------------------------------===//
+// DynamicMemRefType type.
+//===----------------------------------------------------------------------===//
+template <typename T> class DynamicMemRefIterator;
+
+// A reference to one of the StridedMemRef types.
+template <typename T> class DynamicMemRefType {
+public:
+  int64_t rank;
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  const int64_t *sizes;
+  const int64_t *strides;
+
+  explicit DynamicMemRefType(const StridedMemRefType<T, 0> &memRef)
+      : rank(0), basePtr(memRef.basePtr), data(memRef.data),
+        offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
+  template <int N>
+  explicit DynamicMemRefType(const StridedMemRefType<T, N> &memRef)
+      : rank(N), basePtr(memRef.basePtr), data(memRef.data),
+        offset(memRef.offset), sizes(memRef.sizes), strides(memRef.strides) {}
+  explicit DynamicMemRefType(const ::UnrankedMemRefType<T> &memRef)
+      : rank(memRef.rank) {
+    auto *desc = static_cast<StridedMemRefType<T, 1> *>(memRef.descriptor);
+    basePtr = desc->basePtr;
+    data = desc->data;
+    offset = desc->offset;
+    sizes = rank == 0 ? nullptr : desc->sizes;
+    strides = sizes + rank;
+  }
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range &&indices) {
+    assert(indices.size() == rank &&
+           "indices should match rank in memref subscript");
+    if (rank == 0)
+      return data[offset];
+
+    int64_t curOffset = offset;
+    for (int dim = rank - 1; dim >= 0; --dim) {
+      int64_t currentIndex = *(indices.begin() + dim);
+      assert(currentIndex < sizes[dim] && "Index overflow");
+      curOffset += currentIndex * strides[dim];
+    }
+    return data[curOffset];
+  }
+
+  DynamicMemRefIterator<T> begin() { return {*this, offset}; }
+  DynamicMemRefIterator<T> end() { return {*this, -1}; }
+
+  // This operator[] is extremely slow and only for sugaring purposes.
+  DynamicMemRefType<T> operator[](int64_t idx) {
+    assert(rank > 0 && "can't make a subscript of a zero ranked array");
+
+    DynamicMemRefType<T> res(*this);
+    --res.rank;
+    res.offset += idx * res.strides[0];
+    ++res.sizes;
+    ++res.strides;
+    return res;
+  }
+
+  // This operator* can be used in conjunction with the previous operator[] in
+  // order to access the underlying value in case of zero-ranked memref.
+  T &operator*() {
+    assert(rank == 0 && "not a zero-ranked memRef");
+    return data[offset];
+  }
+};
+
+/// Iterate over all elements in a dynamic memref.
+template <typename T> class DynamicMemRefIterator {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
+      : offset(offset), descriptor(&descriptor) {
+    indices.resize(descriptor.rank, 0);
+  }
+
+  DynamicMemRefIterator &operator++() {
+    if (descriptor->rank == 0) {
+      offset = -1;
+      return *this;
+    }
+
+    int dim = descriptor->rank - 1;
+
+    while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor->strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+
+    if (dim < 0) {
+      offset = -1;
+      return *this;
+    }
+
+    ++indices[dim];
+    offset += descriptor->strides[dim];
+    return *this;
+  }
+
+  reference operator*() { return descriptor->data[offset]; }
+  pointer operator->() { return &descriptor->data[offset]; }
+
+  const std::vector<int64_t> &getIndices() { return indices; }
+
+  bool operator==(const DynamicMemRefIterator &other) const {
+    return other.offset == offset && other.descriptor == descriptor;
+  }
+
+  bool operator!=(const DynamicMemRefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor.
+  int64_t offset = 0;
+
+  /// Array of indices in the multi-dimensional memref.
+  std::vector<int64_t> indices = {};
+
+  /// Descriptor for the dynamic memref.
+  DynamicMemRefType<T> *descriptor;
+};
+
+//===----------------------------------------------------------------------===//
+// Small runtime support library for memref.copy lowering during codegen.
+//===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void
+memrefCopy(int64_t elemSize, ::UnrankedMemRefType<char> *src,
+           ::UnrankedMemRefType<char> *dst);
+
+//===----------------------------------------------------------------------===//
+// Small runtime support library for vector.print lowering during codegen.
+//===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
+
+//===----------------------------------------------------------------------===//
+// Small runtime support library for timing execution and printing GFLOPS
+//===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops);
+extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock();
+
+//===----------------------------------------------------------------------===//
+// Runtime support library for random number generation.
+//===----------------------------------------------------------------------===//
+// Uses a seed to initialize a random generator and returns the generator.
+extern "C" MLIR_CRUNNERUTILS_EXPORT void *rtsrand(uint64_t s);
+// Returns a random number in the range of [0, m).
+extern "C" MLIR_CRUNNERUTILS_EXPORT uint64_t rtrand(void *, uint64_t m); +// Deletes the random number generator. +extern "C" MLIR_CRUNNERUTILS_EXPORT void rtdrand(void *); + +//===----------------------------------------------------------------------===// +// Runtime support library to allow the use of std::sort in MLIR program. +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortI64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF32(uint64_t n, StridedMemRefType *vref); +#endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H diff --git a/third_party/tsingmicro/include/ExecutionEngine/Msan.h b/third_party/tsingmicro/include/ExecutionEngine/Msan.h new file mode 100644 index 000000000..ee94660ae --- /dev/null +++ b/third_party/tsingmicro/include/ExecutionEngine/Msan.h @@ -0,0 +1,35 @@ +//===- Msan.h - Utils related to the memory sanitizer ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares and defines macros related to msan. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_MSAN_H +#define MLIR_EXECUTIONENGINE_MSAN_H + +// Memory sanitizer currently can't be enabled for the jit-compiled code, and +// to suppress msan warnings we need to unpoison pointers and pointed-to +// datastructures before they can be accessed. 
+
+#ifndef __has_feature
+#define __has_feature(x) 0
+#endif
+
+#if __has_feature(memory_sanitizer) && !defined(MLIR_MEMORY_SANITIZER)
+#define MLIR_MEMORY_SANITIZER
+#endif
+
+#if defined(MLIR_MEMORY_SANITIZER)
+#include <sanitizer/msan_interface.h>
+#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s) __msan_unpoison((p), (s))
+#else // Memory sanitizer: OFF
+#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s)
+#endif // MLIR_MEMORY_SANITIZER
+
+#endif // MLIR_EXECUTIONENGINE_MSAN_H
diff --git a/third_party/tsingmicro/include/ExecutionEngine/version.txt b/third_party/tsingmicro/include/ExecutionEngine/version.txt
new file mode 100644
index 000000000..c3f15e55e
--- /dev/null
+++ b/third_party/tsingmicro/include/ExecutionEngine/version.txt
@@ -0,0 +1 @@
+https://github.com/llvm/llvm-project/commit/3be3883e6d67bf908fd12b51219075293ebb3dff
diff --git a/third_party/tsingmicro/include/magic-kernel-func/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-func/CMakeLists.txt
new file mode 100644
index 000000000..629c08af6
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel-func/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Conversion)
+add_subdirectory(Dialect)
diff --git a/third_party/tsingmicro/include/magic-kernel-func/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-func/Dialect/CMakeLists.txt
new file mode 100644
index 000000000..f33061b2d
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel-func/Dialect/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td b/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td
new file mode 100644
index 000000000..e930ab73a
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td
@@ -0,0 +1,19 @@
+//===------------------- MagicKernelFuncOps.td ----------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Common abstraction layer for non-instruction driven ML accelerator.
+//
+// The target glue layer that translates target-independent kernel operations
+// into NPU-style API calls (other common terms include intrinsics or driver
+// functions).
+//
+// The NPU APIs are categorized by data type: as in traditional compilers, the
+// integer and floating-point function units are separate, so every MK
+// (MagicKernel) op is lowered to two MKF (MagicKernelFunc) ops, an integer
+// version and a floating-point version.
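+//
+// For example (names illustrative only, not actual MKF symbols): an mk.add on
+// f32 tensors would lower to a call into the floating-point API (e.g. something
+// like mkf_add_f32), while the same mk.add on i32 tensors would target the
+// integer API (e.g. mkf_add_i32), with the choice presumably driven by the
+// element type of the operands.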
+// +//===----------------------------------------------------------------------===// diff --git a/third_party/tsingmicro/include/magic-kernel-instr/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-instr/CMakeLists.txt new file mode 100644 index 000000000..629c08af6 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel-instr/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/third_party/tsingmicro/include/magic-kernel-instr/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td new file mode 100644 index 000000000..2dbad73eb --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td @@ -0,0 +1,13 @@ +//===------------------- MagicKernelInstrOps.td ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Common abstraction layer for instruction driven ML accelerator. +// +// The target glue layer that translates target independent kernel operations +// into intrinsics which fits LLVM dialect lowering path. +// +//===----------------------------------------------------------------------===// diff --git a/third_party/tsingmicro/include/magic-kernel/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/CMakeLists.txt new file mode 100644 index 000000000..629c08af6 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt new file mode 100644 index 000000000..cece7d89b --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(LinalgToMK) +add_subdirectory(CoreDialectsToMK) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CMakeLists.txt new file mode 100644 index 000000000..69690dec7 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CMakeLists.txt @@ -0,0 +1,10 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +# All rights reserved. 
+# +#===------------------------------------------------------------------------===# + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name CoreDialectsMK) +add_public_tablegen_target(CoreDialectsToMKConversionPassIncGen) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h new file mode 100644 index 000000000..69750d402 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h @@ -0,0 +1,27 @@ +//===------------------- CoreDialectsToMK.h -------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This pass is the wrap all pass that populates all the conversion patterns +// from core dialects such as linalg, memref, buf etc to mk dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_CONVERSION_CORE_DIALECTS_TO_MK_H +#define TRITON_CONVERSION_CORE_DIALECTS_TO_MK_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr> createCoreDialectsToMKPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_CORE_DIALECTS_TO_MK_H diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.h b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.h new file mode 100644 index 000000000..7e1982c3b --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.h @@ -0,0 +1,26 @@ +//===------------------- CoreDialectsToMK.h -------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Wrap all the conversion from core dialects to backend dialects(MK etc). +// +//===----------------------------------------------------------------------===// + +#ifndef CORE_DIALECTS_TO_MK_CONVERSION_PASSES_H +#define CORE_DIALECTS_TO_MK_CONVERSION_PASSES_H + +#include "magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // CORE_DIALECTS_TO_MK_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.td b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.td new file mode 100644 index 000000000..d4f5fa677 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.td @@ -0,0 +1,18 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#ifndef CORE_DIALECTS_TO_MK_CONVERSION_PASSES +#define CORE_DIALECTS_TO_MK_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def CoreDialectsToMK : Pass<"core-dialects-to-mk", "mlir::ModuleOp"> { + let summary = "Convert core dialects including Linalg, Memref etc to MK"; + let constructor = "triton::createCoreDialectsToMKPass()"; +} + +#endif diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/CMakeLists.txt new file mode 100644 index 000000000..76b9d9114 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name LinalgToMK) +add_public_tablegen_target(LinalgToMKConversionPassIncGen) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h new file mode 100644 index 000000000..031593c26 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h @@ -0,0 +1,35 @@ +//===------------------- LinalgToMK.h -------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Lowering all linalg ops into mk ops. +// +//===----------------------------------------------------------------------===// + +#ifndef ZTC_CONVERSION_LINALG_TO_MK_H +#define ZTC_CONVERSION_LINALG_TO_MK_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_DECL +#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc" + +void populateLinalgToMKCanonicalizationPatterns(RewritePatternSet &patterns); + +void populateLinalgToMKConversionPatterns(RewritePatternSet &patterns); + +std::unique_ptr> createLinalgToMKPass(); + +} // namespace triton +} // namespace mlir + +#endif // ZTC_CONVERSION_MEMREF_TO_MAGICKERNEL_H diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.h b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.h new file mode 100644 index 000000000..7c45210e6 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.h @@ -0,0 +1,22 @@ +//===------------------- Passes.h -----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_TO_MK_CONVERSION_PASSES_H +#define LINALG_TO_MK_CONVERSION_PASSES_H + +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // LINALG_TO_MK_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.td b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.td new file mode 100644 index 000000000..b4f39500c --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.td @@ -0,0 +1,19 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_TO_MK_CONVERSION_PASSES +#define LINALG_TO_MK_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def LinalgToMK : Pass<"linalg-to-mk", "mlir::ModuleOp"> { + let summary = "Convert linalg operations into magic kernel operations"; + + let options = []; +} + +#endif diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Dialect/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/CMakeLists.txt new file mode 100644 index 000000000..437811f2a --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +set(LLVM_TARGET_DEFINITIONS MagicKernelOps.td) +mlir_tablegen(MagicKernelDialect.h.inc -gen-dialect-decls -dialect=mk) +mlir_tablegen(MagicKernelDialect.cpp.inc -gen-dialect-defs -dialect=mk) +mlir_tablegen(MagicKernelOps.h.inc -gen-op-decls) +mlir_tablegen(MagicKernelOps.cpp.inc -gen-op-defs) + +set(LLVM_TARGET_DEFINITIONS MagicKernelTypes.td) +mlir_tablegen(MagicKernelTypes.h.inc -gen-typedef-decls) +mlir_tablegen(MagicKernelTypes.cpp.inc -gen-typedef-defs) + +add_public_tablegen_target(MagicKernelTableGen) diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td new file mode 100644 index 000000000..666a7f414 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td @@ -0,0 +1,15 @@ +//===------------------- MagicKernelAttrDefs.td ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MAGIC_KERNEL_ATTR_DEFS +#define MAGIC_KERNEL_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + + + +#endif // MAGIC_KERNEL_ATTR_DEFS diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h new file mode 100644 index 000000000..06bd269a3 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h @@ -0,0 +1,32 @@ +//===------------------- MagicKernelDialect.h -----------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_ +#define MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +//===----------------------------------------------------------------------===// +// MagicKernel Operations +//===----------------------------------------------------------------------===// +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h.inc" + +// Include the auto-generated header file containing the declarations of the +// TritonStructured operations. +#define GET_OP_CLASSES +#include "magic-kernel/Dialect/IR/MagicKernelOps.h.inc" + +#endif // MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_ diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td new file mode 100644 index 000000000..4aee43bb8 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td @@ -0,0 +1,44 @@ +//===------------------- MagicKernelDialect.td ----------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MAGIC_KERNEL_DIALECT +#define MAGIC_KERNEL_DIALECT + +include "mlir/IR/OpBase.td" + +def MagicKernelDialect : Dialect { + let name = "mk"; + + let cppNamespace = "::mlir::mk"; + + let summary = "The Magic Kernel IR in MLIR"; + + let description = [{ + Magic Kernel Dialect. 
+ + Dependent Dialects: + * Memref + * copy, alloc + * Bufferization + * to_tensor + }]; + + let dependentDialects = [ + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + // let hasConstantMaterializer = 1; + // let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +include "magic-kernel/Dialect/IR/MagicKernelTypes.td" + +#endif // MAGIC_KERNEL_DIALECT diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td new file mode 100644 index 000000000..f49785fa5 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td @@ -0,0 +1,284 @@ +//===------------------- MagicKernelOps.td --------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// The abstract layer between MLIR core dialects and the lower target specific +// dialects of MagicKernelFunc and MagicKernelInstr. +// +// Compare to higher level MLIR dialects such as memref, arith, affine etc, the +// granularity of MK dialect is more suitable to map into ML accelerators. For +// example, tt.load is lowered to arith + memref.reinterpret_cast + memref.alloc +// + memref.copy + bufferization.to_tensor by decoding hidden high level info +// into detailed info carried in those core MLIR dialects. +// If we convert tt.load to mk.alloc + mk.load, we have to redo all the analysis +// and info constructions which triton-shared already does, so that we should +// generate mk.alloc + mk.load from the core dialects to avoid reconstructing +// the information. +// By doing so, we can lower arith + memref.reinterpret_cast + memref.copy + +// buf.to_tensor into mk.load, and lower arith + memref.alloc into mk.alloc. +// +//===----------------------------------------------------------------------===// + +#ifndef MAGIC_KERNEL_OPS +#define MAGIC_KERNEL_OPS + +include "magic-kernel/Dialect/IR/MagicKernelTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Bufferable type. 
+//===----------------------------------------------------------------------===//
+
+def TensorOrMemref :
+  AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
+
+//
+// Interfaces
+//
+def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
+
+
+class MKOp traits = []> :
+    Op {
+}
+
+class MKUnElemWiseOp : MKOp {
+  let summary = "Element wise unary operation: $mnemonic";
+
+  let arguments = (
+    ins
+    AnyTensor:$src,
+    BoolAttr:$is_atomic
+  );
+
+  let results = (outs AnyTensor:$dst);
+}
+
+class MKBinElemWiseOp : MKOp {
+  let summary = "Element wise binary operation: $mnemonic";
+
+  let arguments = (
+    ins
+    AnyTensor:$src0,
+    AnyTensor:$src1,
+    BoolAttr:$is_atomic
+  );
+
+  let results = (outs AnyTensor:$dst);
+}
+
+class MKTerElemWiseOp : MKOp {
+  let summary = "Element wise ternary operation: $mnemonic";
+
+  let arguments = (
+    ins
+    AnyTensor:$src0,
+    AnyTensor:$src1,
+    AnyTensor:$src2,
+    BoolAttr:$is_atomic
+  );
+
+  let results = (outs AnyTensor:$dst);
+}
+
+
+// =============================================================================
+// Memory allocation ops
+// =============================================================================
+
+def AllocOp : MKOp<"alloc", []> {
+  let summary = "Allocate consecutive memory from the given address space";
+
+  let description = [{
+    It may or may not generate a target intrinsic call or instruction; the
+    lowering from this operator to lower-level operators is target specific.
+  }];
+
+  let arguments = (
+    ins
+    I32Attr:$addr_space, // The address space
+    I64ArrayAttr:$dims // The size of memory to be allocated
+  );
+
+  // Return the pointer of the allocated memory
+  let results = (outs AnyRankedOrUnrankedMemRef:$ptr);
+}
+
+// =============================================================================
+// Load/Store Ops
+// =============================================================================
+
+// Unit and strided memory load
+def LoadOp : MKOp<"load", []> {
+  let summary = "Load from a memory with optional strides";
+
+  let description = [{ See RISC-V RVV unit/strided memory load for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space, // The address space
+    I64ArrayAttr:$dims, // The shape to be loaded, can be dynamic
+    I64ArrayAttr:$strides, // The strides in each rank, can be dynamic
+    BoolAttr:$mask // element is not loaded if mask[i] == 0
+  );
+
+  let results = (outs AnyTensor:$result);
+
+  let assemblyFormat = [{
+    $ptr `,` attr-dict `:` type($ptr) `->` type($result)
+  }];
+}
+
+// Index memory load
+def IndexLoadOp : MKOp<"iload", [
+]> {
+  let summary = "Load from a memory with indexed offset";
+
+  let description = [{ See RISC-V RVV index memory load for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space, // The address space
+    I64ArrayAttr:$dims, // The shape to be loaded
+    AnyTensor:$index, // The tensor containing the memory offset for each element
+    BoolAttr:$mask // element is not loaded if mask[i] == 0
+  );
+
+  let results = (outs MKType:$result);
+}
+
+// Unit and strided memory store
+def StoreOp : MKOp<"store", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Store to a memory with optional strides";
+
+  let description = [{ See RISC-V RVV unit/strided memory store for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space, // The address space
+    I64ArrayAttr:$dims, // The shape to be stored
+    I64ArrayAttr:$strides, // The strides in each rank
+    BoolAttr:$mask // element is not written to dest if mask[i] == 0
+  );
+
+  let assemblyFormat = [{
+    $ptr `,` attr-dict `:` type($ptr)
+  }];
+}
+
+// Index memory store
+def IndexStoreOp : MKOp<"istore", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Store to a memory with indexed offset";
+
+  let description = [{ See RISC-V RVV index memory store for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space, // The address space
+    I64ArrayAttr:$dims, // The shape to be stored
+    AnyTensor:$index, // The tensor containing the memory offset for each element
+    BoolAttr:$mask // element is not written to dest if mask[i] == 0
+  );
+}
+
+// =============================================================================
+// Dot op
+// =============================================================================
+
+def DotOp : MKOp<"dot", [DestinationStyleOpInterface]> {
+  let summary = "Inner product of 2 vectors";
+
+  let description = [{
+    TODO: It is currently a one-to-one mapping from the upper dialect tt.dot.
+  }];
+
+  let arguments = (
+    ins
+    TensorOrMemref:$a, // Matrix A
+    TensorOrMemref:$b, // Matrix B
+    Optional:$c, // Optional accumulation matrix C
+    // Zeroes buffer which can be used to fill $d
+    // FIXME: Do we need to add side effects to the source operands?
+    Arg:$zeroes
+    //DefaultValuedAttr:$inputPrecision,
+    // DefaultValuedAttr:$maxNumImpreciseAcc
+  );
+
+  let results = (outs Variadic:$d);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() {
+      return getZeroesMutable();
+    }
+  }];
+
+  // let hasVerifier = 1;
+}
+
+// =============================================================================
+// Reduction ops
+// =============================================================================
+
+def ArgMaxOp : MKOp<"argmax", [Pure]> {}
+def ArgMinOp : MKOp<"argmin", [Pure]> {}
+def ReduceMaxOp : MKOp<"reduce_max", [Pure]> {}
+def ReduceMinOp : MKOp<"reduce_min", [Pure]> {}
+def ReduceOp : MKOp<"reduce", [Pure]> {}
+def SumOp : MKOp<"sum", [Pure]> {}
+def XorSumOp : MKOp<"xor_sum", [Pure]> {}
+
+
+// =============================================================================
+// Scan/Sort Ops
+// =============================================================================
+
+def SortOp : MKOp<"sort", [Pure]> {}
+def GatherOp : MKOp<"gather", [Pure]> {}
+
+
+// =============================================================================
+// Unary/Binary/Ternary Element-wise Math Ops
+// =============================================================================
+
+def AbsOp : MKUnElemWiseOp<"abs">;
+def AddOp : MKBinElemWiseOp<"add">;
+def AndOp : MKBinElemWiseOp<"and">;
+def CDivOp : MKBinElemWiseOp<"cdiv">;
+def CeilOp : MKUnElemWiseOp<"ceil">;
+def ClampOp : MKUnElemWiseOp<"clamp">;
+def CosOp : MKUnElemWiseOp<"cos">;
+def DivOp : MKBinElemWiseOp<"div">;
+def ErfOp : MKUnElemWiseOp<"erf">;
+def ExpOp : MKUnElemWiseOp<"exp">;
+def Exp2Op : MKUnElemWiseOp<"exp2">;
+def FdivOp : MKBinElemWiseOp<"fdiv">;
+def FloorOp : MKUnElemWiseOp<"floor">;
+def FmaOp : MKTerElemWiseOp<"fma">;
+def LogOp : MKUnElemWiseOp<"log">;
+def Log2Op : MKUnElemWiseOp<"log2">;
+def MaxOp : MKUnElemWiseOp<"max">;
+def MinOp : MKUnElemWiseOp<"min">;
+def OrOp : MKBinElemWiseOp<"or">;
+def RsqrtOp : MKUnElemWiseOp<"rsqrt">;
+def SigmoidOp : MKUnElemWiseOp<"sigmoid">;
+def SinOp : MKUnElemWiseOp<"sin">;
+def SqrtOp : MKUnElemWiseOp<"sqrt">;
+def SqrtRnOp : MKUnElemWiseOp<"sqrt_rn">;
+def XorOp : MKBinElemWiseOp<"xor">;
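+
+// Illustrative lowering sketch (pseudo-IR, operands and attributes abbreviated;
+// see the rationale in the file header comment): the core-dialect expansion of
+// a tt.load such as
+//   %view = memref.reinterpret_cast %base to offset: [...], sizes: [...], strides: [...]
+//   %buf  = memref.alloc() : memref<...>
+//   memref.copy %view, %buf
+//   %t    = bufferization.to_tensor %buf
+// is intended to be pattern-matched and folded into the coarser
+//   %buf = mk.alloc ...
+//   %t   = mk.load %base ...
+// rather than converting tt.load to mk ops directly.
+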
+// def UmulhiOp : MKOp<"umulhi", [Pure]> {} + + +#endif // MAGIC_KERNEL_OPS diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td new file mode 100644 index 000000000..19fb9e1b6 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td @@ -0,0 +1,102 @@ +//===------------------- MagicKernelTypes.td ------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MAGIC_KERNEL_TYPES_TD +#define MAGIC_KERNEL_TYPES_TD + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "magic-kernel/Dialect/IR/MagicKernelDialect.td" + +// +// Types +// +class MKTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def MKFloat : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def MKFloatTensor : RankedTensorOf<[MKFloat]>; +def MKFloatLike : AnyTypeOf<[MKFloat, MKFloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def MKBoolTensor : RankedTensorOf<[I1]>; +def MKBoolLike : AnyTypeOf<[I1, MKBoolTensor]>; + +// Integer Type +def I4 : I<4>; +def MKInt : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">; +def MKIntTensor : RankedTensorOf<[MKInt]>; +def MKIntLike : AnyTypeOf<[MKInt, MKIntTensor]>; + +// I32 Type +// MKI32 -> I32 +// MKI32Tensor -> I32Tensor +def MKI32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// MKI64 -> I64 +// MKI64Tensor -> I64Tensor +def MKI64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class MKPtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `MKPtrOf`) +def MKPtrType : MKTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. 
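+    For example (illustrative of the Triton pointer notation, not new types
+    added here), `!tt.ptr<f32>` points to a scalar f32, a tensor of pointers is
+    written `tensor<128x!tt.ptr<f32>>`, and a block pointer to a whole tensor is
+    written `!tt.ptr<tensor<128x64xf32>>`.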
+ }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def MKPtr : MKPtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def MKPtrTensor : RankedTensorOf<[MKPtr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def MKPtrLike : AnyTypeOf<[MKPtr, MKPtrTensor]>; + +// Tensor Type +def MKFpIntTensor : RankedTensorOf<[MKFloat, MKInt]>; +def MKTensor : RankedTensorOf<[MKFloat, MKInt, MKPtr]>; + +// Pointer Type to Tensor Type: `ptr>` +def MKTensorPtr : MKPtrOf<[MKTensor]>; + +// Any Type in Magic Kernel IR +def MKType : AnyTypeOf<[MKFloatLike, MKIntLike, MKPtrLike, MKTensorPtr]>; + +#endif // MAGIC_KERNEL_TYPES_TD diff --git a/third_party/tsingmicro/include/magic-kernel/Transforms/BufferizableOpInterfaceImpl.h b/third_party/tsingmicro/include/magic-kernel/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 000000000..789a0c8a8 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,26 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This file declares the implementation of the BufferizableOpInterface. +// +//===----------------------------------------------------------------------===// + +#ifndef _MK_DIALECT_BUFFERIZABLEOPINTERFACEIMPL_H +#define _MK_DIALECT_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace mk { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace mk +} // namespace mlir + +#endif // _MK_DIALECT_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h b/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h new file mode 100644 index 000000000..6d310d93d --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h @@ -0,0 +1,163 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSIS_MASKANALYSIS_H +#define TRITON_ANALYSIS_MASKANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/LogicalResult.h" + +#include + +namespace mlir { + +class OpBuilder; + +namespace triton { +// Data structure used to decode the pattern in a mask used for load and store. +// start and end field represent the start and end index of a range (produced +// by make_range, addi, etc.). 
While multi-dimensional data is possible, we +// assume range comparison can only be done on 1 dimension at a time (and +// results of range comparions across dimensions can be combined), hence start +// and end are not vectors. dims represents the real access size for ld/st +// (instead of the tensor/memref size specified by the IR). scalar is a shortcut +// used when the entire state contains a single scalar value. +// +// The general lifetime of this data structure is roughly: +// 1. A range is created by make_range and optionally operated on by addi w/ +// result of splat, expand_dims, etc. During this phase, either (1) both start +// and end are populated, or (2) scalar is populated. Only one of the dimensions +// (that contains the range) can have dim > 1. +// 2. Result from step 1 is compared with a another MaskState that represents a +// scalar value. The resulting state only has dims populated. +// 3. Optionally, result from step 2 can be broadcasted and anded with other +// results from step 2. The resulting state only has dims populated. +// +// Example of creating 2D mask: +// mask = (rows[:, None] < M) & (cols[None, :] < N) +struct MaskState { + OpFoldResult start; + OpFoldResult end; + SmallVector dims; + OpFoldResult scalar; + const bool useUnsafeMask; + + void dump() const; + + MaskState(bool useUnsafeMask = false) : useUnsafeMask(useUnsafeMask) {} + + int64_t getRank() const { return dims.size(); } + + bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; } + + bool isMask() const { return !start && !end && !scalar && dims.size() != 0; } + + // Recursively parse a Value; call the coresponding function based on the + // defining operation and Value type + LogicalResult parse(Value operand, const Location loc, OpBuilder &builder); + + tensor::ExtractSliceOp getExtractSlice(Value source, const Location loc, + OpBuilder &builder) const; + + memref::SubViewOp getSubview(Value source, const Location loc, + OpBuilder &builder) const; + + std::pair + getSideBySideSubviews(Value block1, Value block2, const Location loc, + OpBuilder &builder) const; + + std::pair + getStackedSubviews(Value block1, Value block2, const Location loc, + OpBuilder &builder) const; + +private: + // ------- + // Utility functions to operate on MaskState + // ------- + LogicalResult addStateScalar(const MaskState &state, + const OpFoldResult scalar, Location loc, + OpBuilder &builder); + + LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, + Location loc, OpBuilder &builder); + + LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, + Location loc, OpBuilder &builder); + // ------- + // Helper functions to parse values to populate MaskState + // ------- + + LogicalResult parseExtSI(arith::ExtSIOp op, const Location loc, + OpBuilder &builder); + + // Operand is the result of a constant + // Get the value of the constant and assign it to scalar. + LogicalResult parseConstant(arith::ConstantOp constOp, const Location loc, + OpBuilder &builder); + + // Operand is an integer scalar + LogicalResult parseIntScalar(Value scalar, const Location loc, + OpBuilder &builder); + + // Operand is the result of addi + // One and only one of the operands should be a scalar. Increment both start + // and end, dims remains unchanged, and scalar is empty. + LogicalResult parseAdd(arith::AddIOp addOp, const Location loc, + OpBuilder &builder); + // Operand is the result of andi + // Each of the result state dims is smaller of the two operands' dims. 
+ // Insert instruction if needed to get new dims. + LogicalResult parseAnd(arith::AndIOp andOp, const Location loc, + OpBuilder &builder); + + // Operand is the result of cmpi + // Assume only one of the dimensions has size > 1. Only support slt/ult, and + // sge against 0 for now. For that dimension, we have three cases: + // 1. Constant comparison with both left and right-hand sides being scalars. + // Calculate this new dim as a compare and select. + // I.e. dim = lhs < rhs ? end : 0 + // 2. Left-hand side is not a scalar, and the right-hand side is. + // 2.a. Predicate is slt/ult. Calculate this new dim as: + // dim = max(min(end, value), start) - start + // 2.b. Predicate is sge against 0. Mask analysis already has an + // assumption that the mask starts at 0, so evaluate this to true + // and calculate this new dim as: dim = end + LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc, + OpBuilder &builder); + // Operand is the result of make_range + // Set start and end accordingly; step size must be 1. + LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc, + OpBuilder &builder); + // Operand is the result of broadcast + // Change dims only; assume only applies to tensors. + LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp, + const Location loc, OpBuilder &builder); + // Operand is the result of splat + // Assume only applies to scalar. start and end are left empty; scalar will + // be assigned, and dims will be updated. + LogicalResult parseSplat(triton::SplatOp splatOp, const Location loc, + OpBuilder &builder); + // Operand is the result of expand_dims + // Insert additional dims; start and end do not change and correspond to the + // dimension that contains the range. + LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location loc, OpBuilder &builder); + + LogicalResult parseLoopIterArg(Value v, const Location loc, + OpBuilder &builder); +}; + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h b/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h new file mode 100644 index 000000000..1a66d6ee2 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H +#define TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OpDefinition.h" + +#include + +namespace mlir { + +class OpBuilder; + +// Return integer if ofr is an IntegerAttr. Note that this function differs +// from getConstantIntValue, which returns an integer if ofr is the constant +// result of an operation too. +std::optional getIntAttr(const OpFoldResult ofr); + +// Return if ofr contains a constant zero, either represented by an integer +// attribute or a constant value. +bool hasConstZero(const OpFoldResult ofr); + +// Create a value of index type if necessary from an OpFoldResult. +Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b); + +// Create a vector of values of index type if necessary from an array of +// OpFoldResults. 
+SmallVector ofrsToIndexValues(ArrayRef ofrs,
+                              const Location loc, OpBuilder &b);
+
+// Process addition of two OFRs. If both OFRs are Integer Attributes, result
+// is an Integer Attribute. Otherwise, insert the arith.addi instruction if
+// needed and use its result Value.
+OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+// Produce result = lhs - rhs. If both OFRs are Integer Attributes, result
+// is an Integer Attribute. Otherwise, insert the arith.subi instruction if
+// needed and use its result Value.
+OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+// Process multiplication of two OFRs. If both OFRs are Integer Attributes,
+// result is an Integer Attribute. Otherwise, insert the arith.muli
+// instruction if needed and use its result Value.
+OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
+                         const Location loc, OpBuilder &b);
+
+OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                         const arith::CmpIPredicate pred,
+                         const OpFoldResult trueVal,
+                         const OpFoldResult falseVal, const Location loc,
+                         OpBuilder &b);
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/PtrAnalysis.h b/third_party/tsingmicro/include/triton-shared/Analysis/PtrAnalysis.h
new file mode 100644
index 000000000..5a95ebda9
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Analysis/PtrAnalysis.h
@@ -0,0 +1,271 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_ANALYSIS_PTRANALYSIS_H
+#define TRITON_ANALYSIS_PTRANALYSIS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include
+
+namespace mlir {
+
+class ConversionPatternRewriter;
+
+namespace triton {
+
+struct ModuloState {
+  Value size;
+
+  // offset is used to determine the wraparound point for patterns like:
+  //   offset + (tl.arange(0, 256) % 12)
+  // The current code assumes that the modulo operator always runs last, e.g:
+  //   (offset + tl.arange(0, 256)) % 12
+  // This is not used at the moment as there haven't been enough use cases and
+  // the implementation is quite complex.
+  // OpFoldResult offset;
+
+  static constexpr char const *WraparoundAttr = "ptr.wraparound_type";
+  static constexpr char const *WraparoundStacked = "stacked";
+  static constexpr char const *WraparoundSideBySide = "side_by_side";
+};
+
+// Data structure used to decode pointer arithmetic and potentially translate
+// it into memref. offsets, sizes, and strides are in units of elements in a
+// linearly laid-out memory, which is the same as pointer arithmetic
+// operations in the Triton language. scalar is a shortcut used when the
+// entire state describes a single scalar value. source is the base pointer.
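+//
+// For example (values illustrative): decoding
+//   ptrs = base + rows[:, None] * K + cols[None, :]
+// for a 16x64 block would yield a PtrState with source = base,
+// sizes = [16, 64], strides = [K, 1], and offsets = [0, 0]; createCastOp can
+// then materialize it as a single memref.reinterpret_cast instead of a tensor
+// of pointers.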
+class PtrState { + + OpFoldResult + accumulateTargetOffset(Location loc, + ConversionPatternRewriter &rewriter) const; + +public: + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + + SmallVector> modulos; + + Value source; + Value scalar; + + int64_t getRank() const; + + bool isEmpty() const; + + bool hasModulo() const; + + MemRefType getResultMemrefType(MLIRContext *context, int64_t offset, + ArrayRef resultShape, + bool useDynamicStrides = false) const; + + // Process addition of two PtrStates. + void addState(const PtrState &lhsState, const PtrState &rhsState, + Location loc, ConversionPatternRewriter &rewriter); + + // Process multiplication of two PtrStates + void mulState(const PtrState &lhsState, const PtrState &rhsState, + const Location loc, ConversionPatternRewriter &rewriter); + + // Produce a reinterpret cast based on the current PtrState. Additional + // instructions may be inserted in calculating the final offset. + memref::ReinterpretCastOp + createCastOp(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const; + + SmallVector + createSideBySideCastOps(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const; + + SmallVector + createStackedCastOps(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const; +}; + +class PtrAnalysis { +public: + using IndexMapSet = std::map>; + + // Recursively parse a Value; call the corresponding + // function based on the defining operation and argument type. + static void + visitOperand(Value operand, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.addi. Process both arguments and insert any + // arith.addi instruction as needed. + // Main assumptions: + // Only one of lhsState and rhsState has source field set + // Current PtrState should be empty + // Expected result: + // source = lhsState.source ? lhsState.source : rhsState.source + // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) + // offsets[i] = lhsState.offsets[i] + rhsState.offsets[i] + // strides[i] = lhsState.strides[i] + rhsState.strides[i] + static void + visitOperandAdd(arith::AddIOp addOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.muli. Process both arguments and insert any + // arith.muli instruction as needed. + // Main assumptions: + // Neither lhsState nor rhsState has source field set + // Current PtrState should be empty + // Currently only support one of the operand is a scalar index + // Expected result (scalar and tensorState represent the two operands): + // source = null + // sizes[i] = tensorState.sizes[i] + // offsets[i] = tensorState.offsets[i] * scalar + // strides[i] = tensorState.strides[i] * scalar + static void + visitOperandMul(arith::MulIOp mulOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + static void + visitOperandRem(arith::RemSIOp mulOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + static void visitOperandUnrealizedCast( + UnrealizedConversionCastOp op, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of make_range. 
+ // Main assumptions: + // start, end, and shape are all statically known + // The output of make_range is 1-dimensional + // Does not check validity of inputs (e.g., stride > 0) + // Expected result: + // source = null + // sizes[0] = shape[0] + // offset[0] = start + // strides[0] = ceiling( (end - start) / shape[0] ) + static void + visitOperandMakeRange(triton::MakeRangeOp rangeOp, PtrState &state, + Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of expand_dims + // Main assumptions: + // Only 1 dimension changes for each invocation of reshape + // The changed dimension must have size of 1 + // Expected result: + // Insert a dimension of size 1, stride 0, and offset 0 + static void + visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, PtrState &state, + const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of broadcast + // Main assumptions: + // Rank of soure and result is the same + // Expected result: + // Update sizes[i] only, no changes to other fields + static void + visitOperandBroadcast(triton::BroadcastOp broadcastOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of splat + // Main assumptions: + // Source is a scalar value (i.e., an integer or a pointer, not a tensor) + // Expected result: + // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0 + // if source is an integer, offset[0] = scalar = source + static void + visitOperandSplat(triton::SplatOp splatOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.constant that is a splat + // Main assumptions: + // Source is a constant op that produces a constant dense tensor where all + // elements are the same (i.e.: a constant that is splatted) + // Expected result: + // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = + // splat value if i == 0, otherwise 0 + static void + visitOperandConstSplat(arith::ConstantOp op, PtrState &state, + const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + static void visitOperandMakeTensorPtr( + triton::MakeTensorPtrOp makeTensorPtrOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of addptr. + // Main assumptions: + // The ptr field should populate the source field + // ptr and offset fields should result in same rank + // Expected result: + // The resulting state for ptr and offset wil be added + static void + visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of reinterpret_cast. + // Main assumptions: + // None + // Expected result: + // Directly grab all corresponding fields from reinterpret_cast. + static void + visitOperandReintCast(memref::ReinterpretCastOp reintCastOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of tt.advance. 
+ // Main assumptions: + // The source of the tt.advance has been mapped to a reinterpret_cast + // Expected result: + // Directly grab all corresponding fields from reinterpret_cast. + // Add the offsets multiplied by the strides to the final offsets. + static void rewriteAdvanceOp(triton::AdvanceOp op, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs); + + // Parse the state of AddPtrOp, insert any instruction needed to + // calculate strides and offsets, build PtrState for this operand, and record + // PtrState for knownPtrs. + static void rewriteAddptrOp(triton::AddPtrOp op, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs); + + // Parse the state of YieldOp, insert any instruction needed to calculate + // strides and offsets, build PtrState for this operand, and record PtrState + // in knownPtrs. + static void + rewriteYieldOp(scf::YieldOp op, ConversionPatternRewriter &rewriter, + const IndexMapSet &levelToBlockArgIndex, const int level, + const llvm::SmallDenseMap &knownPtrs); + + static void rewriteForOp(scf::ForOp op, ConversionPatternRewriter &rewriter, + IndexMapSet &levelToBlockArgIndex, const int level, + llvm::SmallDenseMap &knownPtrs); + + static Value getScalarMemRef(Value ptr, Value memRef, const Location loc, + ConversionPatternRewriter &rewriter); +}; + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/UseAnalysis.h b/third_party/tsingmicro/include/triton-shared/Analysis/UseAnalysis.h new file mode 100644 index 000000000..39c3055a5 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Analysis/UseAnalysis.h @@ -0,0 +1,119 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSIS_USEANALYSIS_H +#define TRITON_ANALYSIS_USEANALYSIS_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +std::unique_ptr createTritonUseAnalysisPass(); + +enum class UseType { + Undefined, // Initial state + DataUse, // value used for tensor computation only + MetaUse, // value used for metadata only + MixUse // value used for both tensor computation and metadata +}; + +struct UseInfo : public dataflow::AbstractSparseLattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UseInfo) + using AbstractSparseLattice::AbstractSparseLattice; + + // Lattice state transfer function + ChangeResult meetUseType(const UseType &other) { + if (other == UseType::Undefined) + return ChangeResult::NoChange; + + switch (type) { + case UseType::Undefined: + type = other; + return ChangeResult::Change; + case UseType::DataUse: + case UseType::MetaUse: + if (type == other) { + return ChangeResult::NoChange; + } else { + type = UseType::MixUse; + return ChangeResult::Change; + } + case UseType::MixUse: + return ChangeResult::NoChange; + default: + llvm_unreachable("bad type"); + } + } + + ChangeResult meet(const AbstractSparseLattice &other) override { + auto rhs = reinterpret_cast(&other); + return meetUseType(rhs->type); + } + + void print(raw_ostream &os) const override { + switch (type) { + case UseType::DataUse: + os << "DataUse"; + break; + case UseType::MetaUse: + os << "MetaUse"; + break; + case UseType::MixUse: + os << "MixUse"; + break; + default: + os << "Undefined"; + } + } + + UseType type = UseType::Undefined; +}; + +class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis { +public: + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + LogicalResult visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; + + void visitBranchOperand(OpOperand &operand) override { return; } + + void visitCallOperand(OpOperand &operand) override { return; } + + void setToExitState(UseInfo *lattice) override { + lattice->type = UseType::Undefined; + } + +private: + void propagateUse(UseInfo *lattice, const UseType &type) { + auto changed = lattice->meetUseType(type); + propagateIfChanged(lattice, changed); + } + + void propagateResults(UseInfo *lattice, ArrayRef results) { + auto changed = ChangeResult::NoChange; + for (auto result : results) + changed |= lattice->meet(*result); + propagateIfChanged(lattice, changed); + } +}; + +// Use SparseBackwardDataAnalysis to identify operations whose results are used +// as data tensor operations, meta operations (address calculation, +// broadcasting/splating constant, etc.), or both. For operations used as both +// purposes, clone them so that the remaining pass built on +// ConversionPatternRewriter can replace all tensor producers cleanly and simply +// delete meta data producers. 
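+//
+// A small, purely illustrative example (not taken from a real kernel): given
+//
+//   %range = tt.make_range {start = 0 : i32, end = 128 : i32}
+//   %ptrs  = tt.addptr %baseSplat, %range
+//   %vals  = tt.load %ptrs
+//   %twice = arith.addf %vals, %vals
+//
+// %range and %ptrs feed only address computation, so they meet to MetaUse and
+// can simply be deleted after pointer rewriting; %vals feeds only tensor
+// computation, so it stays DataUse. A value that feeds both kinds of uses
+// meets to MixUse and is cloned, so one copy can serve each purpose.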
+LogicalResult runUseAnalysis(triton::FuncOp &funcOp); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONTOAFFINE_TRITONUSEANALYSIS_H diff --git a/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h new file mode 100644 index 000000000..e17cacd6a --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h @@ -0,0 +1,274 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H +#define TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include +#include + +namespace mlir { + +class OpBuilder; + +namespace tts { + +const extern std::string ptrAnalysisAttr; + +// Data structure used to decode pointer arithmetics. offsets, sizes, and +// strides are in unit of elements in a linearly laid-out memory, which is the +// same as pointer arithmetic operations in Triton language. scalar is a +// shortcut used when the entire state describes a single scalar value. source +// is the base pointer. If order is present, PtrState describes block pointer; +// otherwise it describes non-block pointers. When it describes block pointer, +// shape field means the same field as tt.make_tensor_ptr; when it describes a +// non-block pointer, shape field indicates how address wraps around (i.e., +// modulo); a constant 0 indicates no modulo for the dimension. +struct PtrState { + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + SmallVector shape; + SmallVector order; + + Value source; + Value scalar; + + int32_t getRank() const; + + bool isEmpty() const; + + bool hasModulo() const; + + bool dimHasModulo(uint32_t dim) const; + + bool isBlockPtr() const; + + void dump() const; + + // Process addition of two PtrStates. + LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState, + Operation *op, OpBuilder &builder); + + // Process multiplication of two PtrStates + LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState, + Operation *op, OpBuilder &builder); + + tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder, + Location loc); +}; + +class PtrAnalysis { + // This function is internally used by getLoopIterArgPtrState and + // getLoopResultPtrState to get the correct PtrState for either an iter-arg or + // a loop's result. + // + // A PtrState of an scf.for's iter-arg is the same as its corresponding + // init-arg, except that the strides and offsets have to point to the loop's + // iter-args that were created to carry the offsets and strides. + // + // For instance, for a pointer with index i and rank 2, 4 additional args + // starting at index i + 1 are created. The PtrState's strides and offsets + // value of the pointer's iter-arg must point to these 4 additionally created + // iter-args. 
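+ //
+ // Illustrative layout (the ordering of the four extra iter-args is an
+ // assumption made only for this example): for a rank-2 pointer at iter-arg
+ // index i,
+ //
+ //   iter_args(..., %ptr /* i */, %off0, %off1 /* i+1, i+2 */,
+ //             %str0, %str1 /* i+3, i+4 */, ...)
+ //
+ // and the reconciled PtrState replaces its offsets and strides with %off0,
+ // %off1, %str0, %str1 (or with the matching loop results, see below).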
+ // + // A similar process is used for getting the PtrState of the loop's i'th + // result: its strides and offsets have to point to the corresponding stride + // and offset values returned by the loop. + PtrState reconcileLoopPtrState( + scf::ForOp forOp, size_t ptrArgIndex, const PtrState &state, + llvm::function_ref getReplacementVal); + + DenseSet maybeStructuredArgs; + +public: + void initializeMaybeStructuredArgs(Operation *op); + + llvm::SmallDenseMap knownPtrs; + + IRMapping ptrMap; + + // Recursively parse a Value; call the corresponding + // function based on the defining operation and argument type. + LogicalResult visitOperand(Value operand, PtrState &state, const Location loc, + OpBuilder &builder); + + // Operand is a result of an scf.for. Such cases occur when there are multiple + // levels of nested loops where the results of the inner scf.for (pointer) are + // yielded by the outer loop. + LogicalResult visitOperandForOp(scf::ForOp forOp, Value operand, + PtrState &state, const Location loc, + OpBuilder &builder); + + // Operand is the result of arith.addi. Process both arguments and insert any + // arith.addi instruction as needed. + // Main assumptions: + // Only one of lhsState and rhsState has source field set + // Current PtrState should be empty + // Expected result: + // source = lhsState.source ? lhsState.source : rhsState.source + // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) + // offsets[i] = lhsState.offsets[i] + rhsState.offsets[i] + // strides[i] = lhsState.strides[i] + rhsState.strides[i] + LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrState &state, + const Location loc, OpBuilder &builder); + + // Operand is the result of arith.muli. Process both arguments and insert any + // arith.muli instruction as needed. + // Main assumptions: + // Neither lhsState nor rhsState has source field set + // Current PtrState should be empty + // Currently only support one of the operand is a scalar index + // Expected result (scalar and tensorState represent the two operands): + // source = null + // sizes[i] = tensorState.sizes[i] + // offsets[i] = tensorState.offsets[i] * scalar + // strides[i] = tensorState.strides[i] * scalar + LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrState &state, + const Location loc, OpBuilder &builder); + + LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrState &state, + const Location loc, OpBuilder &builder); + + // Operand is the result of make_range. 
+ // Main assumptions: + // start, end, and shape are all statically known + // The output of make_range is 1-dimensional + // Does not check validity of inputs (e.g., stride > 0) + // Expected result: + // source = null + // sizes[0] = shape[0] + // offset[0] = start + // strides[0] = ceiling( (end - start) / shape[0] ) + LogicalResult visitOperandMakeRange(triton::MakeRangeOp rangeOp, + PtrState &state, Location loc, + OpBuilder &builder); + + // Operand is the result of expand_dims + // Main assumptions: + // Only 1 dimension changes for each invocation of reshape + // The changed dimension must have size of 1 + // Expected result: + // Insert a dimension of size 1, stride 0, and offset 0 + LogicalResult visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, + PtrState &state, const Location loc, + OpBuilder &builder); + + // Operand is the result of broadcast + // Main assumptions: + // Rank of soure and result is the same + // Expected result: + // Update sizes[i] only, no changes to other fields + LogicalResult visitOperandBroadcast(triton::BroadcastOp broadcastOp, + PtrState &state, const Location loc, + OpBuilder &builder); + + // Operand is the result of splat + // Main assumptions: + // Source is a scalar value (i.e., an integer or a pointer, not a tensor) + // Expected result: + // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0 + // if source is an integer, offset[0] = scalar = source + LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrState &state, + const Location loc, OpBuilder &builder); + + // Operand is the result of arith.constant that is a splat + // Main assumptions: + // Source is a constant op that produces a constant dense tensor where all + // elements are the same (i.e.: a constant that is splatted) + // Expected result: + // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = + // splat value if i == 0, otherwise 0 + LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrState &state, + const Location loc, OpBuilder &builder); + + LogicalResult visitOperandExtSI(arith::ExtSIOp, PtrState &state, + const Location loc, OpBuilder &builder); + + // Operand is the result of addptr. + // Main assumptions: + // The ptr field should populate the source field + // ptr and offset fields should result in same rank + // Expected result: + // The resulting state for ptr and offset wil be added + LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state, + const Location loc, OpBuilder &builder); + + // Operand is the result of tts.make_tptr. + // Main assumptions: + // This function is only called when rewriting a loop + // Expected result: + // Directly grab all corresponding fields from tts.make_tptr. + LogicalResult visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, + PtrState &state, const Location loc, + OpBuilder &builder); + + // Operand is the result of tt.make_tensor_ptr. + // Expected result: + // Parse source pointer and grab results + LogicalResult visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, + PtrState &state, const Location loc, + OpBuilder &builder); + + // Get the computed PtrState for the forOp's init-arg at the provided index. + FailureOr getLoopInitArgPtrState(scf::ForOp forOp, size_t index); + + // Get the computed PtrState for the forOp's iter-arg at the provided index. + FailureOr getLoopIterArgPtrState(scf::ForOp forOp, size_t index); + + // Get the computed PtrState for the forOp's result at the provided index. 
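+ // (Illustration: for
+ //    %res = scf.for ... iter_args(%p = %init, ...) -> ... ,
+ //  getLoopInitArgPtrState describes %init, getLoopIterArgPtrState describes
+ //  %p inside the loop body, and getLoopResultPtrState below describes %res;
+ //  all three share the same source and sizes, differing only in which SSA
+ //  values supply the offsets and strides.)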
+ FailureOr<PtrState> getLoopResultPtrState(scf::ForOp forOp, size_t index);
+
+ // After PtrAnalysis finishes, rewrite the GetStructuredStateOp by creating
+ // the correct initialization ops for offsets and strides and passing them to
+ // any loop's init-args.
+ LogicalResult rewriteGetStructuredStateOp(tts::GetStructuredStateOp op);
+
+ // Parse the state of AddPtrOp, insert any instruction needed to
+ // calculate strides and offsets, build PtrState for this operand, and record
+ // PtrState for knownPtrs.
+ LogicalResult rewriteAddptrOp(triton::AddPtrOp op);
+
+ LogicalResult rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op);
+
+ LogicalResult rewriteAdvanceOp(triton::AdvanceOp op);
+
+ // Parse the state of YieldOp, insert any instruction needed to calculate
+ // strides and offsets, build PtrState for this operand, and record PtrState
+ // in knownPtrs.
+ LogicalResult
+ rewriteYieldOp(scf::YieldOp op,
+ llvm::SmallDenseMap &knownPtrsFor);
+
+ // Rewrite eligible tt.addptr in loop init args so the loop can update such
+ // pointers across iterations. Insert any instruction needed to calculate
+ // strides, offsets, and modulos.
+ LogicalResult rewriteForOp(scf::ForOp op);
+
+ LogicalResult rewriteLoadOp(triton::LoadOp op, bool useUnsafeMask = false);
+
+ LogicalResult rewriteStoreOp(triton::StoreOp op, bool useUnsafeMask = false);
+
+ LogicalResult rewriteOp(Operation *op, bool useUnsafeMask = false);
+};
+
+} // namespace tts
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/CMakeLists.txt
new file mode 100644
index 000000000..629c08af6
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Conversion)
+add_subdirectory(Dialect)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/CMakeLists.txt
new file mode 100644
index 000000000..60180abfb
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_subdirectory(TritonToLinalg)
+add_subdirectory(TritonToStructured)
+add_subdirectory(TritonArithToLinalg)
+add_subdirectory(StructuredToMemref)
+add_subdirectory(TritonToCoreDialects)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/CMakeLists.txt
new file mode 100644
index 000000000..83ff64d36
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name StructuredToMemref)
+add_public_tablegen_target(StructuredToMemrefConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.h
new file mode 100644
index 000000000..198675b12
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.h
@@ -0,0 +1,15 @@
+#ifndef TRITON_STRUCTURED_TO_MEMREF_CONVERSION_PASSES_H
+#define TRITON_STRUCTURED_TO_MEMREF_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/StructuredToMemref/Passes.h.inc"
+
+} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.td new file mode 100644 index 000000000..0f2f08a6d --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.td @@ -0,0 +1,10 @@ +#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES +#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def StructuredToMemref : Pass<"structured-to-memref", "mlir::ModuleOp"> { + let summary = "Convert triton structured pointer ops to memref"; +} + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h new file mode 100644 index 000000000..8c67c9ec0 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h @@ -0,0 +1,24 @@ +#ifndef TRITON_CONVERSION_STRUCTUREDTOMEMREF_STRUCTUREDTOMEMREF_H +#define TRITON_CONVERSION_STRUCTUREDTOMEMREF_STRUCTUREDTOMEMREF_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +class TypeConverter; +namespace triton { + +#define GEN_PASS_DECL +#include "triton-shared/Conversion/StructuredToMemref/Passes.h.inc" + +void populateStructuredToMemrefConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); + +std::unique_ptr> createStructuredToMemrefPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_STRUCTUREDTOMEMREF_STRUCTUREDTOMEMREF_H diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/CMakeLists.txt new file mode 100644 index 000000000..85076bd1c --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonArithToLinalg) +add_public_tablegen_target(TritonArithToLinalgConversionPassIncGen) diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h new file mode 100644 index 000000000..82214faf0 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h @@ -0,0 +1,2126 @@ +#ifndef TRITON_CONVERSION_PATTERNS +#define TRITON_CONVERSION_PATTERNS + +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "triton-shared/Analysis/MaskAnalysis.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/Analysis/PtrAnalysis.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" + +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace triton; + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// Extract a scalar value from v. +// If v is a scalar, return that directly. Otherwise, parse through operations +// (currently only support splat, sitofp, and truncf) that produce it to +// extract the underlying scalar value. We then reconstruct the chain of +// operations that can produce this constant with the original type. If no +// scalar value can be extracted, a nullptr is returned. +static Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + rewriter, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +static SmallVector getNParallelLoopsAttrs(unsigned n) { + return SmallVector(n, 
utils::IteratorType::parallel); +} + +static Value getTransposedValue(Value source, const Location loc, + ConversionPatternRewriter &rewriter) { + + auto sourceType = cast(source.getType()); + auto sourceRank = sourceType.getRank(); + + SmallVector perm(sourceRank); + std::iota(std::begin(perm), std::end(perm), 0); + std::swap(perm[sourceRank - 1], perm[sourceRank - 2]); + + SmallVector transposedShape(sourceType.getShape()); + std::swap(transposedShape[sourceRank - 1], transposedShape[sourceRank - 2]); + + Value transposeInit = rewriter.create( + loc, transposedShape, sourceType.getElementType()); + + Value transpose = + rewriter.create(loc, source, transposeInit, perm) + .getResults()[0]; + + return transpose; +} + +// for IntLike and FloatLike types +static std::optional getBitWidth(Type a) { + if (auto type = dyn_cast(a)) { + auto elementType = type.getElementType(); + if (elementType.isIntOrFloat()) { + return type.getElementType().getIntOrFloatBitWidth(); + } + return std::nullopt; + } + + if (a.isIntOrFloat()) + return a.getIntOrFloatBitWidth(); + + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// Op Lowering Patterns +//===----------------------------------------------------------------------===// + +namespace { + +//----------------------------- +// Begin of monolithic only +//----------------------------- +struct AdvanceConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap knownPtrs; + PtrState pointerState; + PtrAnalysis::rewriteAdvanceOp(op, rewriter, knownPtrs); + return success(); + } +}; + +struct MakeTensorPtrConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + void populateVectorAsIndex(SmallVector &vec, + Operation::operand_range ops, + ConversionPatternRewriter &rewriter, + Location loc) const { + for (auto opnd : ops) { + if (isa(opnd.getType())) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), opnd); + vec.push_back(castOp.getResult()); + } else { + assert(isa(opnd.getType())); + vec.push_back(opnd); + } + } + } + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + PtrState pointerState; + + auto orderSize = op.getOrder().size(); + if (orderSize > 1) { + for (auto [first, second] : + llvm::zip(op.getOrder().slice(0, orderSize - 2), + op.getOrder().slice(1, orderSize - 1))) { + assert(first == second + 1 && + "Currently only support default order on block pointers"); + } + } + + pointerState.source = rewriter.getRemappedValue(op.getBase()); + populateVectorAsIndex(pointerState.offsets, op.getOffsets(), rewriter, loc); + populateVectorAsIndex(pointerState.strides, op.getStrides(), rewriter, loc); + + SmallVector newOffsets; + for (auto [offset, stride] : + llvm::zip(pointerState.offsets, pointerState.strides)) { + auto mulOp = rewriter.create(loc, cast(offset), + cast(stride)); + newOffsets.push_back(mulOp.getResult()); + } + + pointerState.offsets.clear(); + + for (auto offset : newOffsets) { + pointerState.offsets.push_back(offset); + } + + ArrayRef resultShape; + auto pointerType = + cast(op.getResult().getType()); + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + for (auto dim_size : 
resultShape) { + pointerState.sizes.push_back( + IntegerAttr::get(IntegerType::get(op.getContext(), 64), dim_size)); + } + } else { + // scalar pointer, should produce a one dimensional memref + SmallVector scalarShape(1, 1); + resultShape = scalarShape; + assert(pointerState.getRank() == 1); + } + + auto castOp = pointerState.createCastOp(resultShape, loc, rewriter); + rewriter.replaceOp(op, castOp.getResult()); + return success(); + } +}; + +struct LegacyAddPtrConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap knownPtrs; + PtrAnalysis::rewriteAddptrOp(op, rewriter, knownPtrs); + return success(); + } +}; + +struct LoadConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + void createSideBySideCopies(Value block1, Value block2, Value dst, + Location loc, + ConversionPatternRewriter &rewriter) const { + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); + } + + void createStackedCopies(Value block1, Value block2, Value dst, Location loc, + ConversionPatternRewriter &rewriter) const { + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); + } + +public: + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ptr = adaptor.getPtr(); + auto mask = op.getMask(); + auto other = op.getOther(); + auto loc = op.getLoc(); + + // 0. Shortcut for scalar loads + if (!isa(op.getResult().getType())) { + auto sMemRef = PtrAnalysis::getScalarMemRef(op.getPtr(), adaptor.getPtr(), + loc, rewriter); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + auto loadOp = rewriter.create( + op.getLoc(), sMemRef, zeroMap, std::nullopt); + rewriter.replaceOp(op, loadOp.getResult()); + return success(); + } + + // 1. Simple case where no mask is used. 
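+ // (Both cases below start from a freshly allocated buffer with the tile's
+ // shape. In the unmasked case the full source view is copied into it; in the
+ // masked case only a subview sized by MaskAnalysis is copied, after
+ // optionally pre-filling the buffer with the `other` value. Either way the
+ // buffer is handed back to later passes as a restrict/writable tensor.)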
+ auto type = dyn_cast(ptr.getType()); + if (!type) { + // Seen when implicit broadcasting is done late in a chain of operations. + // The workaround is to broadcast the pointers early in the address + // calculation. A proper fix is complicated, but at least we can provide a + // better error message. + return rewriter.notifyMatchFailure( + op, "LoadOp expects a memref, not a memref of pointers"); + } + + auto tensorType = + RankedTensorType::get(type.getShape(), type.getElementType()); + auto alloc = rewriter.create( + loc, MemRefType::get(type.getShape(), type.getElementType())); + + if (!mask) { + assert(!other && "other value used in non-masked load"); + if (auto unrealizedCast = + ptr.getDefiningOp()) { + if (auto wrapType = unrealizedCast->getAttrOfType( + ModuloState::WraparoundAttr)) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (wrapType.getValue() == ModuloState::WraparoundSideBySide) { + createSideBySideCopies(block1, block2, alloc, loc, rewriter); + } else if (wrapType.getValue() == ModuloState::WraparoundStacked) { + createStackedCopies(block1, block2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + } else { + llvm_unreachable("unexpected unrealized cast op"); + } + + } else { + rewriter.create(loc, ptr, alloc); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } + + // 2. Continuous masked loads. + // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) { + return rewriter.notifyMatchFailure( + op, "Cannot lower continuous masked loads"); + } + + // fill load destination with other value + if (other) { + auto scalarOther = getScalarValue(other, loc, rewriter); + assert(scalarOther && "other value used in masked load produced by " + "unsupported instruction"); + + // For each dimension check if mstate.dims[i] < shape[i], or-accumulate + // the result + auto shape = type.getShape(); + auto accBase = + rewriter.create(loc, rewriter.getBoolAttr(false)) + .getResult(); + for (size_t i = 0; i < type.getShape().size(); i++) { + auto shapei = rewriter.create( + loc, rewriter.getIndexAttr(shape[i])); + + Value dimi = dyn_cast(mstate.dims[i]); + if (!dimi) { + dimi = rewriter.create( + loc, cast(cast(mstate.dims[i]))); + } + + auto cmpOp = rewriter.create( + loc, arith::CmpIPredicate::slt, dimi, shapei); + accBase = rewriter.create(loc, accBase, cmpOp.getResult()) + .getResult(); + } + + // condition the memset on the or-accumulation + // initialize with padding prior to CopyOp + rewriter.create( + loc, accBase, [&](OpBuilder &builder, Location loc) { + builder.create(loc, ValueRange{scalarOther}, + ValueRange{alloc}); + builder.create(loc); + }); + } + + if (auto unrealizedCast = ptr.getDefiningOp()) { + if (auto wrapType = unrealizedCast->getAttrOfType( + ModuloState::WraparoundAttr)) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (wrapType.getValue() == ModuloState::WraparoundSideBySide) { + auto [subview1, subview2] = + mstate.getSideBySideSubviews(block1, block2, loc, rewriter); + + createSideBySideCopies(subview1, subview2, alloc, loc, rewriter); + } else if (wrapType.getValue() == ModuloState::WraparoundStacked) { + auto [subview1, subview2] = + 
mstate.getStackedSubviews(block1, block2, loc, rewriter); + + createStackedCopies(subview1, subview2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + + } else { + llvm_unreachable("unexpected unrealized cast op"); + } + + } else { + memref::SubViewOp srcSubview = mstate.getSubview(ptr, loc, rewriter); + memref::SubViewOp dstSubview = mstate.getSubview(alloc, loc, rewriter); + rewriter.create(loc, srcSubview, dstSubview); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } +}; + +struct StoreConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ptr = adaptor.getPtr(); + auto val = adaptor.getValue(); + auto mask = op.getMask(); + auto loc = op.getLoc(); + + // 0. Shortcut for scalar stores + if (!isa(val.getType())) { + auto sMemRef = + PtrAnalysis::getScalarMemRef(op.getPtr(), ptr, loc, rewriter); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + rewriter.create(loc, val, sMemRef, zeroMap, + std::nullopt); + rewriter.eraseOp(op); + return success(); + } + + // 1. Simple case where no mask is used. + if (!mask) { + auto storeOp = rewriter.create( + loc, val, ptr); + storeOp.setWritable(true); + rewriter.eraseOp(op); + return success(); + } + + // 2. Continuous masked stores. + // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) + return failure(); + + auto srcSlice = mstate.getExtractSlice(val, loc, rewriter); + auto dstSubview = mstate.getSubview(ptr, loc, rewriter); + + auto storeOp = rewriter.create( + loc, srcSlice, dstSubview); + storeOp.setWritable(true); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct LoopConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap knownPtrs; + PtrAnalysis::IndexMapSet + levelToBlockArgIndex; // level -> set of block arg index to be replaced + + PtrAnalysis::rewriteForOp(op, rewriter, levelToBlockArgIndex, 0, knownPtrs); + return success(); + } +}; + +struct YieldConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +// Remove all Meta ops except for AddPtr which is handled by AddPtrConverter. +// Use benefit == 10 to ensure that this pattern always takes precedence over +// other patterns. +struct MetaOpConverter : public RewritePattern { +private: + // UseAnalysis will tag operations whose results are used only as meta-data + // with "MetaUse" tag. 
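+ // For example, the tt.make_range / tt.expand_dims / tt.splat chain feeding a
+ // tt.addptr is tagged MetaUse: PtrAnalysis has already folded it into the
+ // offsets and strides of a reinterpret_cast, so the tagged ops can simply be
+ // erased here.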
+ bool isMetaUse(Operation *op) const { return op->hasAttr("MetaUse"); } + +public: + MetaOpConverter(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + + if (isa(op)) { + return rewriter.notifyMatchFailure(op, + "AddPtrOp will be handled separately"); + } + + if (isMetaUse(op)) { + rewriter.eraseOp(op); + return success(); + } + + return rewriter.notifyMatchFailure(op, "requires meta ops"); + } +}; + +struct UnrealizedCastConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +//----------------------------- +// End of monolithic only +//----------------------------- + +struct SplatConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto opType = cast(op.getType()); + auto loc = op.getLoc(); + + auto init = rewriter.create(loc, opType.getShape(), + opType.getElementType()); + + auto filledTensor = + rewriter + .create(loc, ValueRange{adaptor.getSrc()}, + ValueRange{init}) + .result(); + + rewriter.replaceOp(op, filledTensor); + return success(); + } +}; + +struct BroadcastConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + SmallVector getBroadcastDims(RankedTensorType src, + RankedTensorType dst) const { + SmallVector broadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); i++) { + if (dstShape[i] != srcShape[i]) { + assert(srcShape[i] == 1); + broadcastDims.push_back(i); + } + } + assert(!broadcastDims.empty() && "cannot identify broadcast dimension"); + return broadcastDims; + } + + // Broadcasts input tensor based on TosaToLinalg's broadcastToShape + AffineMap getBroadcastAffineMap(MLIRContext *context, + ArrayRef inputShape, + ArrayRef broadcastToShape) const { + + assert(broadcastToShape.size() >= inputShape.size()); + + // Create affine map and shapes for tensor initialization. 
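+ // For example, broadcasting tensor<1x128xf32> to tensor<64x128xf32> gives
+ // diff == 0 and the input map (d0, d1) -> (0, d1): size-1 input dims are
+ // pinned to the constant-0 expression, every other dim keeps its identity
+ // dim expression.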
+ SmallVector outExpr; + + size_t diff = broadcastToShape.size() - inputShape.size(); + for (size_t i = 0; i < broadcastToShape.size(); i++) { + if (i < diff) { + continue; + } + size_t j = i - diff; + if (inputShape[j] == 1) { + // Broadcast singleton dimension + outExpr.push_back(mlir::getAffineConstantExpr(0, context)); + continue; + } + // Non-broadcast case + outExpr.push_back(mlir::getAffineDimExpr(i, context)); + } + return AffineMap::get(broadcastToShape.size(), 0, outExpr, context); + } + +public: + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + assert(op->getNumResults() == 1 && "code assumes single result!"); + RankedTensorType sourceType = + cast(adaptor.getSrc().getType()); + RankedTensorType resultType = cast(op.getType()); + auto elementType = resultType.getElementType(); + size_t resultRank = resultType.getRank(); + + SmallVector indexingMaps; + indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); + + indexingMaps.push_back(getBroadcastAffineMap( + op->getContext(), sourceType.getShape(), resultType.getShape())); + indexingMaps.append(op->getNumResults(), + rewriter.getMultiDimIdentityMap(resultRank)); + + assert(op->getNumResults() == 1 && "code assumes single result!"); + auto init = rewriter.create(loc, resultType.getShape(), + elementType); + + auto linalgOp = rewriter.create( + loc, op->getResultTypes(), ValueRange{adaptor.getSrc()}, + ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(resultRank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value opResult = blockArgs[0]; + nestedBuilder.create(loc, opResult); + }); + + linalgOp->setAttr("broadcastDims", + rewriter.getDenseI64ArrayAttr( + getBroadcastDims(sourceType, resultType))); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } +}; + +struct ExpandDimsConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto src = adaptor.getSrc(); + auto srcRank = cast(src.getType()).getRank(); + auto resType = cast(op->getResultTypes()[0]); + SmallVector reassoc; + int64_t c = 0; + for (int64_t i = 0; i < srcRank; i++) { + ReassociationIndices g; + g.push_back(c++); + if (op.getAxis() == i) { + g.push_back(c++); + } else if (op.getAxis() == i + 1 && i == srcRank - 1) { + g.push_back(c++); + } + reassoc.push_back(g); + } + + auto expandShapeOp = rewriter.create( + op.getLoc(), resType, src, reassoc); + + rewriter.replaceOp(op, expandShapeOp.getResult()); + return success(); + } +}; + +struct TransposeConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto src = adaptor.getSrc(); + auto srcRank = cast(src.getType()).getRank(); + assert(srcRank == 2 && "only expect transposing 2D data"); + + auto res = getTransposedValue(src, op.getLoc(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct MakeRangeConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto type = 
cast(op.getResult().getType()); + auto shape = type.getShape(); + auto elementType = type.getElementType(); + auto context = rewriter.getContext(); + + assert(type.getShape().size() == 1 && + type.getElementType().getIntOrFloatBitWidth() == 32 && + "make range can only return 1D int32 tensor"); + + SmallVector indexingMaps{AffineMap::get( + /* dimCount */ 1, /* symbolCount */ 0, + SmallVector{mlir::getAffineDimExpr(0, context)}, context)}; + + auto init = rewriter.create(loc, shape, elementType); + auto linalgOp = rewriter.create( + loc, op->getResultTypes(), /* operands */ ValueRange{}, + ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(1), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value index = nestedBuilder.create(loc, 0); + Value res = nestedBuilder.create( + loc, type.getElementType(), index); + nestedBuilder.create(loc, res); + }); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } +}; + +struct AssertConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value condVal = op.getCondition(); + + if (isa(condVal.getType())) { + auto scalarVal = getScalarValue(op.getCondition(), op.getLoc(), rewriter); + condVal = scalarVal ? scalarVal : condVal; + } + assert(condVal && isa(condVal.getType()) && + "Only asserts on scalars are currently supported"); + + if (!condVal.getType().isInteger(1)) { + auto zero = + rewriter.create(op.getLoc(), 0, 32); + auto newCond = rewriter.create( + op.getLoc(), arith::CmpIPredicate::ne, condVal, zero); + condVal = newCond.getResult(); + } + + auto assertMessage = llvm::formatv("FIXME: assertion!"); + rewriter.create(op.getLoc(), condVal, + assertMessage.str()); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct BitcastConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto arithBitcast = rewriter.create( + op.getLoc(), op.getType(), op.getOperand()); + + rewriter.replaceOp(op, arithBitcast.getResult()); + return success(); + } +}; + +struct CallConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector args = adaptor.getOperands(); + + // We need to pass extra arguments added by addProgramInfo which are + // num_programs and program_ids + if (FuncOp parentFunc = op->getParentOfType()) { + SymbolRefAttr calleeAttr = op.getCalleeAttr(); + StringRef calleeName = calleeAttr.getRootReference(); + + if (ModuleOp module = op->getParentOfType()) { + if (FuncOp calleeFunc = module.lookupSymbol(calleeName)) { + size_t argsNeed = calleeFunc.getFunctionType().getInputs().size(); + Block &entryBlock = parentFunc.front(); + auto parentInputs = entryBlock.getArguments(); + size_t argsParent = parentInputs.size(); + + if (argsNeed > args.size()) { + int missing = argsNeed - args.size(); + for (int i = 0; i < missing; i++) { + args.push_back(parentInputs[args.size()]); + } + } + } + } + } + + auto call = rewriter.create(op.getLoc(), op.getCallee(), + op.getResultTypes(), args); + + if (!call) { + op.emitError("Failed to create func::CallOp"); + return failure(); + } + + 
rewriter.replaceOp(op, call); + return success(); + } +}; + +struct FpToFpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto roundingMode = triton::RoundingMode::RTNE; // default + + auto roundingModeAttr = op.getRounding(); + if (roundingModeAttr.has_value()) { + roundingMode = roundingModeAttr.value(); + } + + assert(roundingMode != triton::RoundingMode::RTZ && + "Rounding Towards Zero is not supported"); + + Type resultType = op.getResult().getType(); + + auto operandWidth = getBitWidth(op.getOperand().getType()); + auto resultWidth = getBitWidth(resultType); + + assert(operandWidth.has_value() && resultWidth.has_value() && + "Not a float-like operand or result"); + + if (operandWidth.value() > resultWidth.value()) { + Value truncatedValue = rewriter.create( + op.getLoc(), resultType, op.getOperand()); + rewriter.replaceOp(op, truncatedValue); + return success(); + } + + Value extendedValue = rewriter.create( + op.getLoc(), resultType, op.getOperand()); + rewriter.replaceOp(op, extendedValue); + + return success(); + } +}; + +struct ClampConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool propagateNan = op.getPropagateNan() == triton::PropagateNan::ALL; + + assert(!propagateNan && "PropagateNan is not supported"); + + Location loc = op.getLoc(); + Value x = adaptor.getOperands()[0]; + Value min = adaptor.getOperands()[1]; + Value max = adaptor.getOperands()[2]; + + Value maxMin = rewriter.create(loc, x, min); + Value clamp = rewriter.create(loc, maxMin, max); + rewriter.replaceOp(op, clamp); + + return success(); + } +}; + +struct PreciseSqrtConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = + rewriter.create(op.getLoc(), adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + +struct PreciseDivConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = + rewriter.create(op.getLoc(), adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + +struct CatConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = rewriter.create( + op.getLoc(), 0 /* concat dimension */, adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + + return success(); + } +}; + +struct SplitConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getOperand(); + auto inputType = cast(input.getType()); + + Type resultType = op.getResults().front().getType(); + auto resultTensor = cast(resultType); + auto shape 
= inputType.getShape(); + + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + SmallVector sizes = llvm::to_vector( + llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + SmallVector results; + + for (int i = 0; i < 2; ++i) { + offsets.pop_back(); + sizes.pop_back(); + + offsets.push_back(rewriter.getIndexAttr(i)); + sizes.push_back(rewriter.getIndexAttr(1)); + Value slice = rewriter.create( + loc, resultTensor, input, offsets, sizes, strides); + results.push_back(slice); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct JoinConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange inputs = op.getOperands(); + + auto resultType = cast(op.getResult().getType()); + + auto loc = op.getLoc(); + Value result = rewriter.create( + loc, resultType.getShape(), resultType.getElementType()); + + auto shape = resultType.getShape(); + + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + SmallVector sizes = llvm::to_vector( + llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + for (int i = 0; i < 2; ++i) { + offsets.pop_back(); + sizes.pop_back(); + + offsets.push_back(rewriter.getIndexAttr(i)); + sizes.push_back(rewriter.getIndexAttr(1)); + result = rewriter.create(loc, inputs[i], result, + offsets, sizes, strides); + } + + rewriter.replaceOp(op, result); + + return success(); + } +}; + +struct MulHiUIOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + auto mulResult = + rewriter.create(loc, adaptor.getOperands()); + rewriter.replaceOp(op, mulResult.getHigh()); + + return success(); + } +}; + +// TODO: Move this MatmulConverter to MK related folder as it converts +// triton::DotOp directly into mk::DotOp which carries more information than +// linalg.matmul. 
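+// The pattern below materializes an all-zero accumulator (an empty tensor of
+// the destination shape filled with the element type's zero) and hands
+// (A, B, C, zeroes) to the mk dot op, so the accumulate-into-C semantics of
+// tt.dot are preserved for the backend to lower.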
+struct MatmulConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // true means tensor elements are zeros + // false means not zero or it cannot be determined + bool isZeroTensor(Value &v, bool integers) const { + if (auto splatOp = v.getDefiningOp()) { + if (auto constOp = splatOp.getSrc().getDefiningOp()) { + if (auto val = dyn_cast(constOp.getValue())) { + return val.getValueAsDouble() == 0.; + } + if (auto val = dyn_cast(constOp.getValue())) { + return val.getValue() == 0; + } + } + return false; + } + + if (auto constOp = v.getDefiningOp()) { + if (auto denseAttr = dyn_cast(constOp.getValue())) { + if (denseAttr.isSplat()) { + if (integers) + return denseAttr.getSplatValue().isZero(); + return denseAttr.getSplatValue().isZero(); + } + } + } + + return false; + } + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto opa = op.getA(); + auto opb = op.getB(); + auto opc = op.getC(); + + auto dstType = cast(op.getType()); + auto elementType = dstType.getElementType(); + bool integers = elementType.isInteger(); + + auto init = + rewriter.create(loc, dstType.getShape(), elementType); + TypedAttr constantAttr = + integers + ? static_cast(rewriter.getIntegerAttr(elementType, 0)) + : static_cast(rewriter.getFloatAttr(elementType, 0)); + + auto zero = rewriter.create( + op.getLoc(), elementType, constantAttr); + + auto zeroes = + rewriter.create(loc, ValueRange{zero}, ValueRange{init}) + .result(); + + auto dotOp = rewriter.create(loc, dstType, + ValueRange{opa, opb, opc, zeroes}); + + rewriter.replaceOp(op, dotOp); + + return success(); + } +}; + +struct ReduceConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + llvm::SmallVector getRedOps(triton::ReduceOp redOp) const { + auto reduceBlock = redOp.getBody(); + return llvm::map_to_vector(reduceBlock->without_terminator(), + [](Operation &op) { return &op; }); + } + + bool isReductionOpSupported(Operation *redOp) const { + return isa( + redOp); + } + + arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, + Operation *redOp, + Type constantType) const { + const int64_t bitWidth = constantType.getIntOrFloatBitWidth(); + + auto attr = + llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + return rewriter.getFloatAttr(constantType, 0.f); + }) + .Case([&](arith::AddIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](auto) { + return rewriter.getFloatAttr( + constantType, -std::numeric_limits::infinity()); + }) + .Case([&](auto) { + return rewriter.getFloatAttr( + constantType, std::numeric_limits::infinity()); + }) + .Case([&](arith::MinSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxIntN(bitWidth)); + }) + .Case([&](arith::MinUIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxUIntN(bitWidth)); + }) + .Case([&](arith::MaxSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::minIntN(bitWidth)); + }) + .Case([&](arith::MaxUIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); + + return rewriter.create(redOp->getLoc(), constantType, + attr); + } + + bool requiresF32Conversion(const Type elemType, Operation *redOp) const { + return isa(elemType) && + elemType.getIntOrFloatBitWidth() < + llvm::cast(Float32Type::get(elemType.getContext())) + 
.getWidth() && + isa(redOp); + } + + Value getRedElement(Value lhs, Value rhs, const Location loc, + Operation *redOp, OpBuilder &b, + const bool convertLhsToF32Precision) const { + return llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + if (convertLhsToF32Precision) { + lhs = b.create(loc, Float32Type::get(b.getContext()), + lhs); + } + return b.create(loc, lhs, rhs); + }) + .Case([&](auto redOp) { + return b.create(loc, lhs, rhs); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); + } + + LogicalResult + convertToLinalgReduce(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto source = adaptor.getOperands().front(); + auto sourceType = cast(source.getType()); + auto elemType = sourceType.getElementType(); + auto resType = op.getResult().front().getType(); + auto loc = op.getLoc(); + auto reductionOps = getRedOps(op); + + // Reduction of arbitrary operations isn't supported because using the first + // element across the reduction dimension requires us to iterate over a + // subview that skips over each first element. + if (reductionOps.size() != 1 || + !isReductionOpSupported(reductionOps.front())) { + return rewriter.notifyMatchFailure( + op, "Only support lowering reduction with body " + "containing 1 max(i/f) or addf."); + } + + auto rop = reductionOps.front(); + auto axis = op.getAxis(); + auto isVectorReduce = sourceType.getRank() == 1; + + if (axis == sourceType.getRank() - 1 && !isVectorReduce) { + source = getTransposedValue(source, op.getLoc(), rewriter); + axis = sourceType.getRank() - 2; + } + + bool convertToF32Precision = requiresF32Conversion(resType, rop); + + auto constantType = convertToF32Precision + ? Float32Type::get(rewriter.getContext()) + : elemType; + + auto accBaseConstOp = getRedBaseConstOp(rewriter, rop, constantType); + Value initTensor; + + if (isVectorReduce) { + // The affine vectorizer cannot vectorize affine loops generated from + // linalg.reduce for the vector reduce case, so we must rewrite the + // linalg.reduce to affine loops manually. Here we lower to AllocTensor + // directly instead of EmptyOp so that the subsequent pass can recognize + // the patterns (EmptyOp is susceptible to being CSE'd away, making it + // harder to match the patterns correctly). 
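+      // As an illustrative sketch (exact op spellings may differ from the
+      // builders below), the scalar accumulator is materialized as:
+      //   %acc  = bufferization.alloc_tensor() : tensor<f32>
+      //   %init = tensor.insert %cst into %acc[] : tensor<f32>
+      // rather than as a tensor.empty + linalg.fill pair.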
+ initTensor = rewriter.create( + loc, RankedTensorType::get({}, constantType), ValueRange{}); + initTensor = rewriter.create(loc, accBaseConstOp, + initTensor, ValueRange{}); + } else { + Value init = rewriter.create( + loc, cast(resType).getShape(), constantType); + initTensor = rewriter + .create(loc, ValueRange{accBaseConstOp}, + ValueRange{init}) + .result(); + } + + Value finalResult = + rewriter + .create( + loc, ValueRange{source}, ValueRange{initTensor}, + SmallVector{axis}, + [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { + assert(inputs.size() == 2); + Value result = + getRedElement(inputs[0], inputs[1], loc, rop, opBuilder, + convertToF32Precision); + opBuilder.create(loc, result); + }) + .getResult(0); + + if (sourceType.getRank() == 1) { + finalResult = + rewriter.create(loc, constantType, finalResult); + } + + if (convertToF32Precision) { + finalResult = rewriter.create(loc, resType, finalResult); + } + + rewriter.replaceOp(op, finalResult); + return success(); + } + +public: + LogicalResult + matchAndRewrite(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = + cast(adaptor.getOperands().front().getType()); + assert(sourceType.hasRank() && "Expected input is " + "ranked"); + + int64_t axis = op.getAxis(); + assert(axis >= 0 && axis < sourceType.getRank() && + "Expected reduction " + "axis is within " + "operand's rank"); + + return convertToLinalgReduce(op, adaptor, rewriter); + } +}; + +template +class ArgMinMaxBaseConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // We're looking for an op that looks like this: + // + // %9:2 = "tt.reduce"(%8, %3) <{axis = 0 : i32}> ({ + // ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + // ------------------------------------------------- + // `matchTieBreakValue` | + // %11 = arith.cmpf oeq, %arg9, %arg11 : f32 | + // %12 = arith.cmpi slt, %arg10, %arg12 : i32 | 1. + // %13 = arith.andi %11, %12 : i1 | + // ------------------------------------------------- |-> `matchShouldUpdate` + // `matchUpdateCondition` | + // %14 = arith.cmpf ogt, %arg9, %arg11 : f32 | 2. + // ------------------------------------------------- | + // %15 = arith.ori %14, %13 : i1 | + // ------------------------------------------------- + // %16 = arith.select %15, %arg9, %arg11 : f32 + // %17 = arith.select %15, %arg10, %arg12 : i32 + // tt.reduce.return %16, %17 : f32, i32 + // }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) + // + // The above mlir code is lowered from this combinator in triton's + // standard.py: + // + // def _argmax_combine(value1, index1, value2, index2, tie_break_left): + // if tie_break_left: + // tie = value1 == value2 and index1 < index2 + // else: + // tie = False + // gt = value1 > value2 or tie + // v_ret = core.where(gt, value1, value2) + // i_ret = core.where(gt, index1, index2) + // return v_ret, i_ret + + LogicalResult matchTieBreakResult(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &tileBreakValue) const { + // Match the following (section 1. 
of the above) + // + // %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + // %12 = arith.cmpi slt, %arg10, %arg12 : i32 + // %13 = arith.andi %11, %12 : i1 + // + // which is equivalent to the following python code + // + // tie = value1 == value2 and index1 < index2 + + // matching: %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto eqCmpOp = dyn_cast(*it++); + if (eqCmpOp) { + if (eqCmpOp.getPredicate() != arith::CmpFPredicate::OEQ) { + return failure(); + } + if (currValue != eqCmpOp.getLhs() || reduceValue != eqCmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + // matching: %12 = arith.cmpi slt, %arg10, %arg12 : i32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto sltCmpOp = dyn_cast(*it++); + if (sltCmpOp) { + if (sltCmpOp.getPredicate() != arith::CmpIPredicate::slt) { + return failure(); + } + if (currIndex != sltCmpOp.getLhs() || reduceIndex != sltCmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + // matching: %13 = arith.andi %11, %12 : i1 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto andOp = dyn_cast(*it++); + if (andOp) { + if (andOp.getLhs() != eqCmpOp || andOp.getRhs() != sltCmpOp) { + return failure(); + } + } else { + return failure(); + } + + tileBreakValue = andOp; + return success(); + } + + LogicalResult matchShouldUpdateValue(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &shouldUpdate) const { + Value tieResult; + if (failed(matchTieBreakResult(currValue, currIndex, reduceValue, + reduceIndex, it, tieResult))) { + LLVM_DEBUG(llvm::dbgs() << "Tie break result match failed\n"); + return failure(); + } + + Value comparisonResult; + if (failed(T::matchComparisonResult(currValue, currIndex, reduceValue, + reduceIndex, it, comparisonResult))) { + LLVM_DEBUG(llvm::dbgs() << "Comparison result match failed\n"); + return failure(); + } + + // matching: %15 = arith.ori %14, %13 : i1 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto orOp = dyn_cast(*it++); + if (orOp) { + if (orOp.getLhs() != comparisonResult || orOp.getRhs() != tieResult) { + return failure(); + } + } else { + return failure(); + } + + shouldUpdate = orOp; + return success(); + } + + Value getInitTensor(ConversionPatternRewriter &rewriter, + ArrayRef shape, Value fillValue, + Location loc) const { + Value initTensor = + rewriter.create(loc, shape, fillValue.getType()); + return rewriter + .create(loc, ValueRange{fillValue}, + ValueRange{initTensor}) + .result(); + } + +public: + ArgMinMaxBaseConverter(MLIRContext *context) : OpConversionPattern(context) {} + + LogicalResult match(ReduceOp op) const override final { + if (op.getBody()->getNumArguments() != 4) { + return failure(); + } + + auto block = op.getBody(); + auto ops = block->without_terminator(); + + Value currValue = block->getArgument(0); + Value currIndex = block->getArgument(1); + Value reduceValue = block->getArgument(2); + Value reduceIndex = block->getArgument(3); + + auto opsIt = ops.begin(); + Value shouldUpdate; + if (failed(matchShouldUpdateValue(currValue, currIndex, reduceValue, + reduceIndex, opsIt, shouldUpdate))) { + return failure(); + } + + // matching: %16 = arith.select %15, %arg9, %arg11 : f32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto valueSelectOp = dyn_cast(*opsIt++); + if (valueSelectOp) { + if (valueSelectOp.getCondition() != shouldUpdate || + currValue != 
valueSelectOp.getTrueValue() || + reduceValue != valueSelectOp.getFalseValue()) { + return failure(); + } + } else { + return failure(); + } + + // matching:%17 = arith.select %15, %arg10, %arg12 : i32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto indexSelectOp = dyn_cast(*opsIt++); + if (indexSelectOp) { + if (indexSelectOp.getCondition() != shouldUpdate || + currIndex != indexSelectOp.getTrueValue() || + reduceIndex != indexSelectOp.getFalseValue()) { + return failure(); + } + } else { + return failure(); + } + + // matching: tt.reduce.return %16, %17 : f32, i32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto termOp = dyn_cast(*opsIt++); + if (termOp && termOp == block->getTerminator()) { + auto opnds = termOp.getOperands(); + if (opnds != ArrayRef{valueSelectOp, indexSelectOp}) { + return failure(); + } + } else { + return failure(); + } + + return success(); + } + + void rewrite(ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override final { + auto loc = op.getLoc(); + + auto elemTypes = op.getElementTypes(); + + // Set the initial value of the rank-0 tensor containing + // the result value to either -inf or +inf depending on + // whether we're dealing with argmax or argmin + auto valueType = elemTypes[0]; + auto valuesAccBaseVal = rewriter.create( + loc, valueType, + rewriter.getFloatAttr(valueType, T::getBaseReductionValue())); + + // Set the initial value of the rank-0 tensor containing the index of the + // min or max value to -1 + auto indexType = elemTypes[1]; + auto indicesAccBaseVal = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, -1)); + + // Get the shape of the resulting tensors (both for values and indices). If + // we are reducing to a single scalar, then the result's type is a tensor of + // rank-0, otherwise we can reuse the original result shape + auto valueResultType = dyn_cast(op.getType(0)); + const auto isScalarReduce = valueResultType == nullptr; + SmallVector reductionResultShape{ + isScalarReduce ? SmallVector{} + : SmallVector(valueResultType.getShape())}; + + SmallVector outputs{ + getInitTensor(rewriter, reductionResultShape, valuesAccBaseVal, loc), + getInitTensor(rewriter, reductionResultShape, indicesAccBaseVal, loc)}; + + auto linalgOp = rewriter.create( + loc, adaptor.getOperands(), outputs, + SmallVector{adaptor.getAxis()}, + [&](OpBuilder &b, Location loc, ValueRange inputs) { + assert(inputs.size() == 4); + + auto tritonReduceBlock = op.getBody(); + IRMapping mapping; + mapping.map(tritonReduceBlock->getArguments(), inputs); + + for (auto &op : tritonReduceBlock->without_terminator()) { + b.clone(op, mapping); + } + + auto tritonYield = tritonReduceBlock->getTerminator(); + auto results = + llvm::map_to_vector(tritonYield->getOperands(), [&](Value val) { + return mapping.lookup(val); + }); + b.create(loc, results); + }); + + if (isScalarReduce) { + SmallVector reduceResults{ + rewriter.create( + loc, valueType, linalgOp.getResults()[0], ValueRange{}), + rewriter.create( + loc, indexType, linalgOp.getResults()[1], ValueRange{})}; + rewriter.replaceOp(op, reduceResults); + } else { + rewriter.replaceOp(op, linalgOp); + } + } +}; + +struct ArgMaxConverter : public ArgMinMaxBaseConverter { + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult) { + // %14 = arith.cmpf ogt, %arg9, %arg11 : f32 + // This corresponds to section 2. 
of the sample snippet in + // ArgMinMaxBaseConverter + auto cmpOp = dyn_cast(*it++); + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OGT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + comparisonResult = cmpOp; + return success(); + } + + static float getBaseReductionValue() { + return -std::numeric_limits::infinity(); + } + + ArgMaxConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +struct ArgMinConverter : public ArgMinMaxBaseConverter { + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult) { + // %14 = arith.cmpf olt, %arg9, %arg11 : f32 + // This corresponds to section 2. of the sample snippet in + // ArgMinMaxBaseConverter + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto cmpOp = dyn_cast(*it++); + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OLT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + comparisonResult = cmpOp; + return success(); + } + + static float getBaseReductionValue() { + return std::numeric_limits::infinity(); + } + + ArgMinConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +// get_program_id and get_num_programs: +// When launching triton kernels, we pass 6 additional arguments to indicate +// num_programs and program_id. Amongst those six, we have 3 arguments +// correspond to each axis for num_programs followed by 3 additional arguments +// for program_id. +// +// For instance, with triton kernel example_kernel(a, b, c), we have: +// example_kernel( +// a, b, c, +// num_programs_axis_0, +// num_programs_axis_1, +// num_programs_axis_2, +// program_id_axis_0, +// program_id_axis_1, +// program_id_axis_2, +// ) +// +struct GetProgramIDConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto axis = (uint32_t)op.getAxis(); + assert(axis < LAUNCH_GRID_RANK && "program_id expects " + "axis to be either 0, " + "1, or 2"); + + auto func = op->getParentOfType(); + auto numArgs = func.getNumArguments(); + auto id = func.getArgument(numArgs - LAUNCH_GRID_RANK + axis); + + rewriter.replaceOp(op, id); + return success(); + } +}; + +struct GetNumProgramsConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + GetNumProgramsConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto axis = (uint32_t)op.getAxis(); + assert(axis < LAUNCH_GRID_RANK && "program_id expects " + "axis to be either 0, " + "1, or 2"); + + auto func = op->getParentOfType(); + auto numArgs = func.getNumArguments(); + auto id = func.getArgument(numArgs - LAUNCH_GRID_RANK * 2 + axis); + + rewriter.replaceOp(op, id); + return success(); + } +}; + +// Convert a pair of cmpf and select to either min or max. 
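+// For example (an illustrative sketch; SSA names and types are made up):
+//
+//   %0 = arith.cmpf ogt, %a, %b : f32
+//   %1 = arith.select %0, %a, %b : f32
+//
+// is collapsed into a single floating-point max op on %a and %b.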
+// Leave the pattern as simple as possible because triton has plans to emit +// min and max directly. +template +struct MinMaxConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + MinMaxConverter(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/10) {} + + LogicalResult matchAndRewrite(CmpOp cmpOp, + PatternRewriter &rewriter) const final { + if (!cmpOp.getResult().hasOneUse()) { + return failure(); + } + auto selectOp = + dyn_cast(*cmpOp.getResult().getUsers().begin()); + if (!selectOp) { + return failure(); + } + + if (!(cmpOp.getResult() == selectOp.getCondition() && + cmpOp.getLhs() == selectOp.getTrueValue() && + cmpOp.getRhs() == selectOp.getFalseValue())) { + return failure(); + } + + rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate()); + rewriter.eraseOp(cmpOp); + + return success(); + } + + void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpFOp cmpOp, + arith::SelectOp selectOp, + arith::CmpFPredicate pred) const { + switch (pred) { + case arith::CmpFPredicate::OGT: + case arith::CmpFPredicate::OGE: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpFPredicate::OLT: + case arith::CmpFPredicate::OLE: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + default: + llvm_unreachable("Unhandled predicate"); + } + } + + void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpIOp cmpOp, + arith::SelectOp selectOp, + arith::CmpIPredicate pred) const { + switch (pred) { + case arith::CmpIPredicate::sgt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::ugt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::slt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::ult: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + default: + llvm_unreachable("Unhandled predicate"); + } + } +}; + +struct DenseConstantConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto attr = cast(op.getValue()); + auto loc = op.getLoc(); + + auto splatConst = arith::ConstantOp::materialize( + rewriter, attr.getSplatValue(), attr.getElementType(), loc); + + auto init = rewriter.create( + loc, cast(op.getResult().getType()).getShape(), + attr.getElementType()); + + rewriter.replaceOpWithNewOp(op, ValueRange{splatConst}, + ValueRange{init}); + + return success(); + } +}; + +class CumSumConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // CumSum is a specific instance of Scan that looks like the following: + // %1 = "tt.scan"(%0) <{axis = 1 : i32}> ({ + // ^bb0(%arg0: f32, %arg1: f32): + // %2 = arith.addf %arg0, %arg1 : f32 + // tt.scan.return %2 : f32 + // }) : (tensor<4x4xf32>) -> tensor<4x4xf32> + bool isCumSum(triton::ScanOp op) const { + auto scanBlock = op.getBody(); + auto ops = llvm::map_to_vector(scanBlock->without_terminator(), + [](Operation &op) { return &op; }); + + if (ops.size() != 1) { + return false; + } + + auto addOp = ops.front(); + if (isa(addOp)) { + if (addOp->getResult(0) != scanBlock->getTerminator()->getOperand(0)) { + return false; + } + + auto blockArgs = + llvm::map_range(scanBlock->getArguments(), 
[](BlockArgument arg) { + return dyn_cast(arg); + }); + + auto addArgs = addOp->getOperands(); + + return DenseSet(blockArgs.begin(), blockArgs.end()) == + DenseSet(addArgs.begin(), addArgs.end()); + } + + return false; + } + +public: + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isCumSum(op)) { + return rewriter.notifyMatchFailure( + op, "Only support cumsum variant of scan op"); + } + + auto input = op.getOperand(0); + auto axis = op.getAxis(); + auto type = dyn_cast(input.getType()); + + if (type.getRank() != 1 && type.getRank() != 2 && + axis != type.getRank() - 1) { + return rewriter.notifyMatchFailure( + op, "Only support lowering scan op to cumsum with rank " + "= {1, 2} and axis = rank - 1"); + } + + Value init = rewriter.create(op.getLoc(), type.getShape(), + type.getElementType()); + + rewriter.replaceOpWithNewOp( + op, input, rewriter.getUI32IntegerAttr(axis), init); + + return success(); + } +}; + +class AddPtrConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resType = op.getResult().getType(); + assert(isa(resType)); + auto rank = cast(resType).getRank(); + SmallVector indexingMaps( + /*numResult + numOperands*/ 3, rewriter.getMultiDimIdentityMap(rank)); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + SmallVector outputs = {op.getPtr()}; + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperands(), outputs, indexingMaps, + iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { + auto resultTypes = llvm::to_vector<6>( + llvm::map_range(op->getResultTypes(), [](Type type) { + return cast(type).getElementType(); + })); + auto *scalarOp = + builder.create(loc, op->getName().getIdentifier(), + regionArgs.take_front(op->getNumOperands()), + resultTypes, op->getAttrs()); + builder.create(loc, scalarOp->getResults()); + }); + return success(); + } +}; + +class ReshapeConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getSrc(); + auto output = op.getResult(); + + auto inputType = input.getType(); + auto outputType = output.getType(); + if (!outputType.hasStaticShape()) { + return failure(); + } + + if (auto maybeReassociationMap = + getReassociationIndicesForReshape(inputType, outputType)) { + auto reassociationMap = *maybeReassociationMap; + if (outputType.getRank() < inputType.getRank()) { + rewriter.replaceOpWithNewOp( + op, outputType, input, reassociationMap); + } else { + rewriter.replaceOpWithNewOp( + op, outputType, input, reassociationMap); + } + return success(); + } + + ArrayRef outputShape = outputType.getShape(); + + auto shape = rewriter.create( + loc, rewriter.getI64TensorAttr(outputShape)); + rewriter.replaceOpWithNewOp(op, outputType, input, + shape); + + return success(); + } +}; + +class ExternElementwiseBinaryOpConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + if (!op.getPure() || 
op.getSrcs().size() != 2)
+      return failure();
+#define POPULATE_BINARY_OP(FUNC_NAME, DST_OP)                                 \
+  if (!op.getSymbol().compare(FUNC_NAME)) {                                   \
+    rewriter.replaceOpWithNewOp<DST_OP>(op, op.getSrcs()[0],                  \
+                                        op.getSrcs()[1]);                     \
+    return success();                                                         \
+  }
+
+    POPULATE_BINARY_OP("__nv_atan2f", math::Atan2Op);
+    POPULATE_BINARY_OP("__nv_atan2", math::Atan2Op);
+    POPULATE_BINARY_OP("__nv_powf", math::PowFOp);
+    POPULATE_BINARY_OP("__nv_pow", math::PowFOp);
+
+#undef POPULATE_BINARY_OP
+    return failure();
+  }
+};
+
+class ExternElementwiseUnaryOpConverter
+    : public OpConversionPattern<triton::ExternElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    if (!op.getPure() || op.getSrcs().size() != 1)
+      return failure();
+#define POPULATE_UNARY_OP(FUNC_NAME, DST_OP)                                   \
+  if (!op.getSymbol().compare(FUNC_NAME)) {                                    \
+    rewriter.replaceOpWithNewOp<DST_OP>(op, op.getSrcs()[0]);                  \
+    return success();                                                          \
+  }
+
+    POPULATE_UNARY_OP("__nv_fabsf", math::AbsFOp);
+    POPULATE_UNARY_OP("__nv_fabs", math::AbsFOp);
+    POPULATE_UNARY_OP("__nv_sinf", math::SinOp);
+    POPULATE_UNARY_OP("__nv_sin", math::SinOp);
+    POPULATE_UNARY_OP("__nv_cosf", math::CosOp);
+    POPULATE_UNARY_OP("__nv_cos", math::CosOp);
+    POPULATE_UNARY_OP("__nv_tanf", math::TanOp);
+    POPULATE_UNARY_OP("__nv_tan", math::TanOp);
+    POPULATE_UNARY_OP("__nv_asinf", math::AsinOp);
+    POPULATE_UNARY_OP("__nv_asin", math::AsinOp);
+    POPULATE_UNARY_OP("__nv_acosf", math::AcosOp);
+    POPULATE_UNARY_OP("__nv_acos", math::AcosOp);
+    POPULATE_UNARY_OP("__nv_atanf", math::AtanOp);
+    POPULATE_UNARY_OP("__nv_atan", math::AtanOp);
+    POPULATE_UNARY_OP("__nv_sinhf", math::SinhOp);
+    POPULATE_UNARY_OP("__nv_sinh", math::SinhOp);
+    POPULATE_UNARY_OP("__nv_coshf", math::CoshOp);
+    POPULATE_UNARY_OP("__nv_cosh", math::CoshOp);
+    POPULATE_UNARY_OP("__nv_tanhf", math::TanhOp);
+    POPULATE_UNARY_OP("__nv_tanh", math::TanhOp);
+    POPULATE_UNARY_OP("__nv_acoshf", math::AcoshOp);
+    POPULATE_UNARY_OP("__nv_acosh", math::AcoshOp);
+    POPULATE_UNARY_OP("__nv_asinhf", math::AsinhOp);
+    POPULATE_UNARY_OP("__nv_asinh", math::AsinhOp);
+    POPULATE_UNARY_OP("__nv_atanhf", math::AtanhOp);
+    POPULATE_UNARY_OP("__nv_atanh", math::AtanhOp);
+    POPULATE_UNARY_OP("__nv_logf", math::LogOp);
+    POPULATE_UNARY_OP("__nv_log", math::LogOp);
+    POPULATE_UNARY_OP("__nv_log10f", math::Log10Op);
+    POPULATE_UNARY_OP("__nv_log10", math::Log10Op);
+    POPULATE_UNARY_OP("__nv_log1pf", math::Log1pOp);
+    POPULATE_UNARY_OP("__nv_log1p", math::Log1pOp);
+    POPULATE_UNARY_OP("__nv_expf", math::ExpOp);
+    POPULATE_UNARY_OP("__nv_exp", math::ExpOp);
+    POPULATE_UNARY_OP("__nv_exp2f", math::Exp2Op);
+    POPULATE_UNARY_OP("__nv_exp2", math::Exp2Op);
+    POPULATE_UNARY_OP("__nv_erff", math::ErfOp);
+    POPULATE_UNARY_OP("__nv_erf", math::ErfOp);
+    POPULATE_UNARY_OP("__nv_sqrtf", math::SqrtOp);
+    POPULATE_UNARY_OP("__nv_sqrt", math::SqrtOp);
+    POPULATE_UNARY_OP("__nv_rsqrtf", math::RsqrtOp);
+    POPULATE_UNARY_OP("__nv_rsqrt", math::RsqrtOp);
+    POPULATE_UNARY_OP("__nv_ceilf", math::CeilOp);
+    POPULATE_UNARY_OP("__nv_ceil", math::CeilOp);
+    POPULATE_UNARY_OP("__nv_floorf", math::FloorOp);
+    POPULATE_UNARY_OP("__nv_floor", math::FloorOp);
+    POPULATE_UNARY_OP("__nv_truncf", math::TruncOp);
+    POPULATE_UNARY_OP("__nv_trunc", math::TruncOp);
+
+#undef POPULATE_UNARY_OP
+    return failure();
+  }
+};
+
+static void populateExternElementwiseOpToMLIROps(RewritePatternSet &patterns) {
+  
patterns.add(patterns.getContext()); +} + +} // namespace + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.h new file mode 100644 index 000000000..b95cbde73 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES_H +#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.td new file mode 100644 index 000000000..590c02de7 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.td @@ -0,0 +1,20 @@ +#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES +#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonArithToLinalg : Pass<"triton-arith-to-linalg", "mlir::ModuleOp"> { + let summary = "Convert Triton arithmetic operations into linalg"; + let options = [ + Option<"pidsToFuncArgs", "pids-to-func-args", "bool", /*default*/"true", + "Convert tt.get_program_id and tt.get_num_programs to reference to function arguments">, + Option<"ttToFuncFunc", "tt-to-func-func", "bool", /*default*/"true", + "Convert tt.func to func.func">, + Option<"addptrToLinalg", "addptr-to-linalg", "bool", /*default*/"true", + "Convert tt.addptr on tensors to linalg">, + Option<"assertToCf", "assert-to-cf", "bool", /*default*/"true", + "Convert tt.assert to cf.assert">, + ]; +} + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h new file mode 100644 index 000000000..8e5e5822a --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h @@ -0,0 +1,28 @@ +#ifndef TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H +#define TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_DECL +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc" + +void populateTritonArithToLinalgCanonicalizationPatterns( + RewritePatternSet &patterns); + +void populateTritonArithToLinalgConversionPatterns(bool pidsToFuncArgs, + bool addptrToLinalg, + bool assertToCf, + RewritePatternSet &patterns); + +std::unique_ptr> createTritonArithToLinalgPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/CMakeLists.txt new file mode 100644 index 000000000..3cc51fcb2 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/CMakeLists.txt @@ -0,0 
+1,9 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToCoreDialects) +add_public_tablegen_target(TritonToCoreDialectsConversionPassIncGen) diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.h new file mode 100644 index 000000000..32fc0104d --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.h @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES_H +#define TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.td new file mode 100644 index 000000000..6d10cfb6f --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.td @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES +#define TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToCoreDialects : Pass<"triton-to-core-dialects", "mlir::ModuleOp"> { + let summary = "Convert Triton to core dialects including Linalg, Memref etc"; + let constructor = "triton::createTritonToCoreDialectsPass()"; +} + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h new file mode 100644 index 000000000..d968cc055 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// +// +// This pass is the wrapall pass that populates all the conversion patterns from +// triton to core dialects such as linalg, memref, buf etc. 
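+// A typical use is to add this pass as a single entry point in a pipeline,
+// e.g. (a sketch; the surrounding pipeline setup is omitted):
+//
+//   pm.addPass(mlir::triton::createTritonToCoreDialectsPass());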
+// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_CONVERSION_TRITON_TO_CORE_DIALECTS_H +#define TRITON_CONVERSION_TRITON_TO_CORE_DIALECTS_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToCoreDialectsPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_TO_CORE_DIALECTS_H diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt new file mode 100644 index 000000000..74ccdd390 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt @@ -0,0 +1,9 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg) +add_public_tablegen_target(TritonToLinalgConversionPassIncGen) diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.h new file mode 100644 index 000000000..404af0802 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.h @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_TO_LINALG_CONVERSION_PASSES_H +#define TRITON_TO_LINALG_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.td new file mode 100644 index 000000000..627077e3a --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.td @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_TO_LINALG_CONVERSION_PASSES +#define TRITON_TO_LINALG_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> { + let summary = "Convert Triton to Linalg dialect"; + let constructor = "triton::createTritonToLinalgPass()"; +} + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h new file mode 100644 index 000000000..4c58e9921 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. 
+// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H +#define TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToLinalgPass(); + +void populateTritonToLinalgCanonicalizationPatterns( + RewritePatternSet &patterns); + +void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + unsigned int launchGridRank); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt new file mode 100644 index 000000000..5762c1f69 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToStructured) +add_public_tablegen_target(TritonToStructuredConversionPassIncGen) diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.h new file mode 100644 index 000000000..3c3b81ca4 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES_H +#define TRITON_TO_STRUCTURED_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.td new file mode 100644 index 000000000..e2702464b --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.td @@ -0,0 +1,19 @@ +#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES +#define TRITON_TO_STRUCTURED_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> { + let summary = "Convert Triton non-block pointer to TritonStructured dialect"; + let constructor = "triton::createTritonToStructuredPass()"; + let options = [ + Option<"runPrepassOnly", "run-prepass-only", "bool", /*default*/"false", + "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">, + Option<"skipPrepass", "skip-prepass", "bool", /*default*/"false", + "Skip the prepass">, + Option<"useUnsafeMask", "use-unsafe-mask", "bool", /*default*/"false", + "Assume that the mask bounds are never less than starting offsets. 
May produce incorrect results."> + ]; +} + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h new file mode 100644 index 000000000..0ee1a6d53 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h @@ -0,0 +1,17 @@ +#ifndef TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H +#define TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToStructuredPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/CMakeLists.txt new file mode 100644 index 000000000..68066ab63 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonTilingExt) +add_subdirectory(TritonStructured) diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt new file mode 100644 index 000000000..9c32c97c8 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS TritonStructuredDialect.td) +mlir_tablegen(TritonStructuredDialect.h.inc -gen-dialect-decls -dialect=tts) +mlir_tablegen(TritonStructuredDialect.cpp.inc -gen-dialect-defs -dialect=tts) +mlir_tablegen(TritonStructuredOps.h.inc -gen-op-decls) +mlir_tablegen(TritonStructuredOps.cpp.inc -gen-op-defs) + + +add_public_tablegen_target(TritonStructuredTableGen) diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h new file mode 100644 index 000000000..bd01afd05 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h @@ -0,0 +1,27 @@ +#ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_ +#define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/Dialect.h" + +//===----------------------------------------------------------------------===// +// TritonStructured Operations +//===----------------------------------------------------------------------===// +#include 
"triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h.inc" + +// Include the auto-generated header file containing the declarations of the +// TritonStructured operations. +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.h.inc" + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td new file mode 100644 index 000000000..c0f89bfc1 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -0,0 +1,213 @@ +#ifndef TRITON_STRUCTURED_DIALECT +#define TRITON_STRUCTURED_DIALECT + +include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def Triton_Structured_Dialect : Dialect { + let name = "tts"; + + let cppNamespace = "::mlir::tts"; + + let summary = "Structured Triton operations"; + + let description = [{ + Triton Structured Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect" + ]; + + let usePropertiesForAttributes = 1; +} + +// +// Op Base +// +class TTS_Op traits = []> : + Op { +} + +def TTS_MakeTensorPtrOp + : TTS_Op<"make_tptr", [AttrSizedOperandSegments, Pure]> { + let summary = "create a pointer that points to a tensor in memory"; + + // base: Base pointer used to contruct the tensor of pointers or pointer to tensor. + // sizes: Size of the data being loaded or stored. + // strides: The strides of the parent tensor, which means how much to increase the pointer + // by when moving by 1 element in a specific axis. + // order: The order of the block, which means how the block is laid out in memory. + // It contains the same info as order in tt.make_tensor_ptr. + // shape: If order is present, this field signifies the shape of the parent tensor in + // memory; if order is not present, it signifies the boundary by which addresses + // wraps around (constant zero indicates no wrap-around in the corresponding dimension). + // offsets: Offset of the block along each dimension from base. + // result: If order is present, this op produces a pointer to a tensor; otherwise, + // it produces a tensor of pointers. + + let arguments = (ins TT_Ptr:$base, + DenseI64ArrayAttr:$sizes, + Variadic:$strides, + Variadic:$offsets, + Variadic:$shape, + DenseI64ArrayAttr:$static_strides, + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_shape, + DenseI32ArrayAttr:$order); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = [{ + $base `to` `sizes` `` `:` $sizes + `` `,` `strides` `` `:` + custom($strides, $static_strides) + `` `,` `offsets` `` `:` + custom($offsets, $static_offsets) + `` `,` `shape` `` `:` + custom($shape, $static_shape) + `` `,` `order` `` `:` $order + attr-dict `:` type($base) `to` type($result) + }]; + + + let builders = [ + // Build with mixed static and dynamic entries. 
+ OpBuilder<(ins + "Value":$base, + "ArrayRef":$sizes, + "ArrayRef":$strides, + "ArrayRef":$offsets, + "ArrayRef":$shape, + "ArrayRef":$order)>, + ]; + + let extraClassDeclaration = [{ + /// Return a vector of all the static or dynamic fields + SmallVector getMixedSizes() { + Builder b(getContext()); + SmallVector dynSizes; // sizes are always static + return ::mlir::getMixedValues(getSizes(), dynSizes, b); + } + SmallVector getMixedStrides() { + Builder b(getContext()); + return ::mlir::getMixedValues(getStaticStrides(), getStrides(), b); + } + SmallVector getMixedOffsets() { + Builder b(getContext()); + return ::mlir::getMixedValues(getStaticOffsets(), getOffsets(), b); + } + SmallVector getMixedShape() { + Builder b(getContext()); + return ::mlir::getMixedValues(getStaticShape(), getShape(), b); + } + bool isBlockPtr() { + return !getOrder().empty(); + } + bool isStructuredPtr() { + return !isBlockPtr() && + llvm::all_of(getStaticShape(), [](auto shape) { return shape == 0; }); + } + bool isSplitPtr() { + return !isBlockPtr() && + !isStructuredPtr(); + } + }]; + + // TODO + //let hasVerifier = 1; + //let hasCanonicalizer = 1; +} + +// SameVariadicResultSize +// AttrSizedResultSegments +def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> { + let summary = "Placeholder for the structured pointer states computed during PtrAnalysis."; + let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites."; + + let arguments = (ins AnyTypeOf<[TT_PtrLike, I32Tensor]>:$input); + let results = (outs AnyTypeOf<[TT_PtrLike, I32Tensor]>:$structured, Variadic:$offsets, Variadic:$strides); + + let builders = [ + OpBuilder<(ins "Value":$input)>, + ]; + + let extraClassDeclaration = [{ + static std::optional, SmallVector>> + getOffsetAndStrideTypes(MLIRContext *context, Type ptrLikeType); + + static std::optional> + getOffsetAndStrideSegmentSizes(Type ptrLikeType); + }]; + + let hasFolder = 0; + let hasVerifier = 1; +} + +def TTS_LoadOp : TTS_Op<"load", [ + MemoryEffects<[MemRead]>, + AttrSizedOperandSegments +]> { + let summary = "optionally load data from in memory to fill a portion of the tensor"; + + let arguments = (ins TT_PtrLike:$ptr, + Variadic:$mask_dims, + DenseI64ArrayAttr:$static_mask_dims, + Optional>:$other); + + let results = (outs TT_Tensor:$result); + + let builders = [ + OpBuilder<(ins "Value":$ptr, "ArrayRef":$mask_dims, "Value":$other)>, + ]; + + let extraClassDeclaration = [{ + /// Return a vector of all the static or dynamic fields + SmallVector getMixedMaskDims() { + Builder b(getContext()); + return ::mlir::getMixedValues(getStaticMaskDims(), getMaskDims(), b); + } + + bool hasMask() { + return !getStaticMaskDims().empty(); + } + }]; + + // TODO + //let hasCustomAssemblyFormat = 1; + //let hasVerifier = 1; +} + +def TTS_StoreOp : TTS_Op<"store", [ + MemoryEffects<[MemWrite]> +]> { + let summary = "optionally load data from in memory to fill a portion of the tensor"; + + let arguments = (ins TT_PtrLike:$ptr, + TT_Tensor:$value, + Variadic:$mask_dims, + DenseI64ArrayAttr:$static_mask_dims); + + let builders = [ + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$dims)>, + ]; + + let extraClassDeclaration = [{ + /// Return a vector of all the static or dynamic fields + SmallVector getMixedMaskDims() { + Builder b(getContext()); + return ::mlir::getMixedValues(getStaticMaskDims(), getMaskDims(), b); + } + + bool hasMask() { + return !getStaticMaskDims().empty(); + } + }]; + + // TODO + //let 
hasCustomAssemblyFormat = 1; + //let hasVerifier = 1; +} + +#endif // TRITON_STRUCTURED_DIALECT diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt new file mode 100644 index 000000000..ba67b25a7 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +set(LLVM_TARGET_DEFINITIONS TritonTilingExtOps.td) +mlir_tablegen(TritonTilingExtOpsDialect.h.inc -gen-dialect-decls -dialect=ttx) +mlir_tablegen(TritonTilingExtOpsDialect.cpp.inc -gen-dialect-defs -dialect=ttx) +mlir_tablegen(TritonTilingExtOps.h.inc -gen-op-decls) +mlir_tablegen(TritonTilingExtOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(TritonTilingExtOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonTilingExtInterfaces.td) +mlir_tablegen(TritonTilingExtInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TritonTilingExtInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(TritonTilingExtInterfacesIncGen) diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h new file mode 100644 index 000000000..53e031db3 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_ +#define MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_ + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" + +//===----------------------------------------------------------------------===// +// TritonTilingExt Operations +//===----------------------------------------------------------------------===// + +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOpsDialect.h.inc" + +// Include the generated interface declarations. +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.h.inc" + +// Include the auto-generated header file containing the declarations of the +// TritonTilingExt operations. 
+#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.h.inc" + +namespace mlir { + +namespace ttx { + +// ----------------------------------------------------------------------------- +// BufferizableOpInterface +// ----------------------------------------------------------------------------- +// All TritonTilingExtOps need to support bufferization: the process of +// allocating buffers for tensors, thereby converting inputs and outputs of +// tensor type to memref. This process is done by implementing the +// "BufferizableOpInterface". We implement the interface for TritonTilingExtOps +// through an external model instead of directly in TritonTilingExtOps.td to be +// consistent with other ops in the mlir project. See some examples here: +// - mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +// - mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); + +// ----------------------------------------------------------------------------- +// TilingInterface +// ----------------------------------------------------------------------------- +// The three methods `getTiledImplementation`, `getResultTilePosition`, and +// `generateResultTileValue` are implemented as part of the TilingInterface. +// (see TilingInterface.td). These three methods are re-used across +// all TritonTilingExtOps, while others method are implemented individually by +// each operator depending on their use cases. +template +FailureOr getTiledImplementation(TritonTilingExtOpTy op, + OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes); + +template +LogicalResult getResultTilePosition(TritonTilingExtOpTy op, OpBuilder &b, + unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes); + +template +FailureOr +generateResultTileValue(TritonTilingExtOpTy op, OpBuilder &b, + unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes); + +// ----------------------------------------------------------------------------- +// MemoryEffectsOpInterface +// ----------------------------------------------------------------------------- +// Implementation of the MemoryEffectsOpInterface for TritonTilingExtOps. +// This allows DCE pass to determine if a TritonTilingExtOp is safe to be +// removed. see TritonTilingExtOps.td for more details. 
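+// For instance, a ttx op whose results are unused and whose operands are all
+// tensors reports no memory effects here, so DCE may erase it (a sketch of
+// the intent rather than a guarantee for every op).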
+template +void getEffects( + TritonTilingExtOpTy op, + SmallVectorImpl> + &effects); + +// ----------------------------------------------------------------------------- +// Utilities +// ----------------------------------------------------------------------------- +// Utility method to extract a slice from the input source using either +// tensor::ExtractSlice or memref::SubView +Value getSlice(OpBuilder &b, Location loc, Value source, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides); + +} // namespace ttx +} // namespace mlir + +#endif // MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_ diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td new file mode 100644 index 000000000..e74fbb6cc --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td @@ -0,0 +1,102 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES +#define MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES + +include "mlir/IR/OpBase.td" + +// +// Linalg operators require providing affine maps that define how input / output +// buffers are accessed together with a region that defines how each output +// element is computed; this requirement doesn't work well for operations such as +// `scan`. +// +// Fortunately, the introduction of the TilingInterface allows us to add tiling +// and fusion support to operations that don't fit into the linalg dialect. +// This fits our purpose perfectly: our `scan` operators can be treated as an +// "opaque" / "completely abstract" operation that can be tiled on the batch +// dimensions -- we don't need to provide any associated body together with it. +// +// However, this doesn't mean that we entirely forgo the "indexing map" concept. +// For example, consider the following: +// +// - ttx.scan ins(%1 : tensor<128x768xbf16>) +// outs(%2 : tensor<128x768xbf16>) -> tensor<128x768xbf16> +// +// Tiling the batch dimension gives us: +// +// for (i = 0 to 128) { +// %sliceIn = extract slice from input: tensor<1x768xbf16> +// %sliceOut = extract slice from output: tensor<1x768xbf16> +// %res = ttx.scan ins(slice : tensor<1x768xbf16>) +// outs(%2 : tensor<1x768xbf16>) -> tensor<1x768xbf16> +// insert %res into output +// } +// +// Now our `scan` op has the semantic of running `scan` on a rank-1 tensor and +// can be lowered further to other hardware-specific ops or external library +// calls. +// +// This tiling pattern is essentially the same as tiling a linalg.generic op +// with an identity map. The only difference is we don't need a body associated +// with our `scan` op. +// +// With this idea in mind, the TritonTilingExtInterface exposes methods +// that will be implemented individually by each TritonTilingExtOp, providing +// the indexing map for each input / output that can then be used to generate +// the correct slices during tiling and fusion. +// +// There might be other ops in the future that won't fit in this "indexing map" +// approach; we will consider making TritonTilingExtInterface an optional +// interface for such ops. 
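+// For the ttx.scan example above, both getInputIndexingMap and
+// getOutputIndexingMap would return the identity map (d0, d1) -> (d0, d1),
+// which is what lets tiling extract matching 1x768 slices from the input and
+// the output (a sketch of the intended contract, not normative).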
+// + +def TritonTilingExtInterface : OpInterface<"TritonTilingExtInterface"> { + let cppNamespace = "::mlir::ttx"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the indexing map for the input operand with the given `index`. + The `tileSizes` input indicates the requested tile size during tiling + in case the indexing map for the operator is dependent on it. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getInputIndexingMap", + /*args=*/(ins "MLIRContext*":$context, + "unsigned int":$index, + "ArrayRef":$tileSizes) + >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing map for the output operand with the given `index`. + The `tileSizes` input indicates the requested tile size during tiling + in case the indexing map for the operator is dependent on it. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getOutputIndexingMap", + /*args=*/(ins "MLIRContext*":$context, + "unsigned int":$index, + "ArrayRef":$tileSizes) + >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing map for the operand with the given `index`. + This method returns the operand in order of inputs followed by outputs. + The `tileSizes` input indicates the requested tile size during tiling + in case the indexing map for the operator is dependent on it. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getIndexingMap", + /*args=*/(ins "MLIRContext*":$context, + "unsigned int":$index, + "ArrayRef":$tileSizes) + > + ]; +} + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.td b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.td new file mode 100644 index 000000000..d3a4268a5 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.td @@ -0,0 +1,242 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_TILING_EXT_BASE +#define TRITON_TILING_EXT_BASE + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" + +include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td" + + +//===----------------------------------------------------------------------===// +// TritonTilingExt dialect definition +//===----------------------------------------------------------------------===// + +def TritonTilingExt_Dialect : Dialect { + let name = "ttx"; + let cppNamespace = "::mlir::ttx"; +} + +//===----------------------------------------------------------------------===// +// TritonTilingExt op definitions +//===----------------------------------------------------------------------===// + +// Base class for TritonTilingExt dialect ops. +class TritonTilingExt_Op traits = []> + : Op { +} + +class TritonTilingExt_TilingOp : Op, + // All TritonTilingExtOps implement TritonTilingExtInterface, which provides a standardized + // way of providing indexing maps for input and output operands. 
+ DeclareOpInterfaceMethods, + + // MemoryEffectsOpInterface provides analysis passes such as DCE to determine + // whether an operation has no memory side effects and therefore is safe to + // be deleted. This interface is important during tile and fuse where we + // create copies of TilingInterface ops with smaller tile sizes but leave the + // original ops intact. + DeclareOpInterfaceMethods, + + // DestinationStyleOpInterface describes ops that have similar semantics to + // linalg ops, with a separate ins (input) and outs (output) operand groups. + // Implementing this op gives us access to a wide variety of useful methods + // to query the inputs and outputs of an op. + DestinationStyleOpInterface, + + // AttrSizedOperandSegments supports having multiple groups of operands. + // For example, linalg ops (as well as TritonTilingExtOps) all look like this: + // ttx.some_op ins(%1) outs(%2) -> resultType + AttrSizedOperandSegments +]> +{ + let results = (outs Variadic:$result_tensors); + + let hasCustomAssemblyFormat = 1; + + code baseClassDecls = [{ + // Implemented as part of DestinationStyleOpInterface + MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } + }]; + + // Custom print() and parse() methods to make the TritonTilingExt ops have similar looks + // to the linalg ops. + // Borrowed from llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp + let extraClassDefinition = [{ + void $cppClass::print(OpAsmPrinter &p) { + p.printOptionalAttrDict(this->getOperation()->getAttrs(), + /*elidedAttrs=*/{"operand_segment_sizes"}); + + if (!getInputs().empty()) + p << " ins(" << getInputs() << " : " << getInputs().getTypes() << ")"; + if (!getOutputs().empty()) + p << " outs(" << getOutputs() << " : " << getOutputs().getTypes() << ")"; + + if (!getResultTypes().empty()) + p.printOptionalArrowTypeList(getResultTypes()); + } + + ParseResult $cppClass::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector inputTypes; + SmallVector outputTypes; + SMLoc inputsOperandsLoc, outputsOperandsLoc; + SmallVector inputsOperands, + outputsOperands; + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("ins"))) { + if (parser.parseLParen()) + return failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands) || + parser.parseColonTypeList(inputTypes) || parser.parseRParen()) + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("outs"))) { + outputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || + parser.parseColonTypeList(outputTypes) || parser.parseRParen()) + return failure(); + } + + if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, + result.operands) || + parser.resolveOperands(outputsOperands, outputTypes, + outputsOperandsLoc, result.operands)) + return failure(); + + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getDenseI32ArrayAttr( + {static_cast(inputsOperands.size()), + static_cast(outputsOperands.size())})); + + SmallVector resultTypes; + if (parser.parseOptionalArrowTypeList(resultTypes)) + return failure(); + result.addTypes(resultTypes); + + return success(); + } + + AffineMap $cppClass::getIndexingMap(MLIRContext *context, + unsigned int index, + ArrayRef sizes) { + assert(index < this->getNumOperands()); + if (index < getNumDpsInputs()) { + return getInputIndexingMap(context, index, sizes); + } + return 
getOutputIndexingMap(context, index - getNumDpsInputs(), sizes); + } + + // Forward each of the implementation to the shared implementation + FailureOr $cppClass::getTiledImplementation( + OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes + ) { + return mlir::ttx::getTiledImplementation<$cppClass>( + *this, b, offsets, sizes + ); + } + + // Forward each of the implementation to the shared implementation + LogicalResult $cppClass::getResultTilePosition( + OpBuilder &b, + unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes + ) { + return mlir::ttx::getResultTilePosition<$cppClass>( + *this, b, resultNumber, offsets, sizes, resultOffsets, resultSizes + ); + } + + // Forward each of the implementation to the shared implementation + FailureOr $cppClass::generateResultTileValue( + OpBuilder &b, + unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes + ) { + return mlir::ttx::generateResultTileValue<$cppClass>( + *this, b, resultNumber, offsets, sizes + ); + } + + // Implemented as part of MemoryEffectsOpInterface + void $cppClass::getEffects( + SmallVectorImpl> + &effects + ) { + return mlir::ttx::getEffects<$cppClass>(*this, effects); + } + }]; +} + +def TritonTilingExt_CumSumOp : TritonTilingExt_TilingOp<"cumsum"> { + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, + UI32Attr:$axis + ); + + let hasVerifier = 1; + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<(ins + "Value":$input, + "IntegerAttr":$axis, + "Value":$output, + CArg<"ArrayRef", "{}">:$attributes + )> + ]; + + let extraClassDeclaration = baseClassDecls # [{ + int64_t getRank() { + return cast(getInput().getType()).getRank(); + } + + Value getInput() { + return getInputs()[0]; + } + + Value getOutput() { + return getOutputs()[0]; + } + + static StringRef getAxisAttrStrName() { return "axis"; } + }]; +} + +#endif // TRITON_TILING_EXT_BASE diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/CMakeLists.txt new file mode 100644 index 000000000..629c08af6 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/CMakeLists.txt new file mode 100644 index 000000000..923f5b7e7 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(MKToTx81) +add_subdirectory(Tx81ToLLVM) +add_subdirectory(Tx81MemrefToLLVM) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/CMakeLists.txt new file mode 100644 index 000000000..a69d0ceb2 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name MKToTx81) +add_public_tablegen_target(MKToTx81ConversionPassIncGen) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h new file mode 100644 index 000000000..29218895f --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h @@ -0,0 +1,36 @@ +//===------------------- 
MKToTx81.h ---------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Lowering magic kernel ops to TsingMicro Tx81 target. +// +//===----------------------------------------------------------------------===// + +#ifndef ZTC_CONVERSION_MK_TO_TX81_H +#define ZTC_CONVERSION_MK_TO_TX81_H + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_DECL +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc" + +void populateMKToTx81CanonicalizationPatterns(RewritePatternSet &patterns); + +void populateMKToTx81ConversionPatterns(RewritePatternSet &patterns); + +std::unique_ptr> createMKToTx81Pass(); + +} // namespace triton +} // namespace mlir + +#endif // ZTC_CONVERSION_MK_TO_TX81_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.h new file mode 100644 index 000000000..c9a8f51c0 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.h @@ -0,0 +1,22 @@ +//===------------------- Passes.h -----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MK_TO_TX81_CONVERSION_PASSES_H +#define MK_TO_TX81_CONVERSION_PASSES_H + +#include "tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // MK_TO_TX81_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.td new file mode 100644 index 000000000..295fc05bd --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.td @@ -0,0 +1,18 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MK_TO_TX81_CONVERSION_PASSES +#define MK_TO_TX81_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def MKToTx81 : Pass<"mk-to-tx81", "mlir::ModuleOp"> { + let summary = "Convert magic kernel operations into TsingMicro Tx81 operations"; + let constructor = "triton::createMKToTx81Pass()"; +} + +#endif diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/CMakeLists.txt new file mode 100644 index 000000000..fbc6e31df --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name Tx81MemrefToLLVM) +add_public_tablegen_target(Tx81MemrefToLLVMConversionPassIncGen) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h new file mode 100644 index 000000000..54079c976 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h @@ -0,0 +1,22 @@ +//===------------------- Passes.h -----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MEMREF_TO_MK_CONVERSION_PASSES_H +#define MEMREF_TO_MK_CONVERSION_PASSES_H + +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // MEMREF_TO_MK_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.td new file mode 100644 index 000000000..b2a0d2c9b --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.td @@ -0,0 +1,19 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MEMREF_TO_MK_CONVERSION_PASSES +#define MEMREF_TO_MK_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def Tx81MemrefToLLVM : Pass<"tx81-memref-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert memref and bufferization operations into custom llvm function call."; + let constructor = "triton::createTx81MemrefToLLVMPass()"; + let options = []; +} + +#endif diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h new file mode 100644 index 000000000..96173716c --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h @@ -0,0 +1,43 @@ +//===------------------- Tx81MemrefToLLVM.h -------------------------*- C++ +//-*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// +// +// Lowering memref.copy, memref.alloc to mk.load, mk.alloc etc. +// +//===----------------------------------------------------------------------===// + +#ifndef ZTC_CONVERSION_MEMREF_TO_MK_H +#define ZTC_CONVERSION_MEMREF_TO_MK_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +// Declear spmPointer. +extern uint64_t spmPointer; + +namespace mlir { +namespace triton { + +#define GEN_PASS_DECL +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc" + +void populateTx81MemrefToLLVMCanonicalizationPatterns( + RewritePatternSet &patterns); + +void populateTx81MemrefToLLVMConversionPatterns(RewritePatternSet &patterns, + LLVMTypeConverter &converter); + +std::unique_ptr> createTx81MemrefToLLVMPass(); + +} // namespace triton +} // namespace mlir + +#endif // ZTC_CONVERSION_MEMREF_TO_MAGICKERNEL_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt new file mode 100644 index 000000000..626484155 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt @@ -0,0 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name Tx81ToLLVM) +add_public_tablegen_target(Tx81ToLLVMConversionPassIncGen) + +set(LLVM_TARGET_DEFINITIONS KernelArgBufferPass.td) +mlir_tablegen(KernelArgBufferPass.h.inc -gen-pass-decls --name KernelArgBufferPass) +add_public_tablegen_target(KernelArgBufferPassIncGen) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h new file mode 100644 index 000000000..f4de9dcaf --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h @@ -0,0 +1,35 @@ +//===- KernelArgBufferPass.h ----------------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This pass transforms kernel function signatures by converting multiple +// arguments into a single void* buffer containing all the arguments. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_KERNEL_ARG_BUFFER_PASS_H +#define MLIR_KERNEL_ARG_BUFFER_PASS_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +class ModuleOp; +class Pass; + +namespace triton { +/// Creates a pass that transforms kernel functions by replacing multiple +/// arguments with a single void* buffer argument. 
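+///
+/// As a non-normative illustration of the resulting calling convention (the
+/// authoritative layout is whatever this pass emits, as documented in
+/// KernelArgBufferPass.td; the launcher name below is hypothetical), a host
+/// could pack each argument into an 8-byte slot and hand the kernel a single
+/// pointer:
+///
+///   // Hypothetical host-side packing, assuming 8-byte slots per argument.
+///   std::vector<uint64_t> slots;
+///   slots.push_back(reinterpret_cast<uint64_t>(arg1));   // pointer argument
+///   slots.push_back(reinterpret_cast<uint64_t>(arg2));   // pointer argument
+///   slots.push_back(static_cast<uint64_t>(size));        // scalar argument
+///   launch_add_kernel(static_cast<void *>(slots.data()));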
+std::unique_ptr createKernelArgBufferPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // MLIR_KERNEL_ARG_BUFFER_PASS_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td new file mode 100644 index 000000000..a47c45d07 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td @@ -0,0 +1,32 @@ +//===- KernelArgBufferPass.td ---------------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef KERNEL_ARG_BUFFER_PASS +#define KERNEL_ARG_BUFFER_PASS + +include "mlir/Pass/PassBase.td" + +def KernelArgBufferPass : Pass<"kernel-arg-buffer", "ModuleOp"> { + let summary = "Convert kernel arguments to a single buffer argument"; + let description = [{ + This pass transforms kernel function signatures by converting multiple + arguments into a single void* buffer containing all the arguments. + + For example, a function like: + add_kernel(uint64_t* arg1, uint64_t* arg2, int64_t size, int gridX, int x) + + Will be converted to: + add_kernel(void* args) + + Where the args buffer contains pointers to arg1 and arg2, followed by the scalar + values size, gridX, and x. Each scalar value occupies 8 bytes in the buffer. + }]; + let constructor = "mlir::triton::createKernelArgBufferPass()"; + let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::func::FuncDialect"]; +} + +#endif // KERNEL_ARG_BUFFER_PASS diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h new file mode 100644 index 000000000..f0f0138b8 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h @@ -0,0 +1,22 @@ +//===------------------- Passes.h -----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef TX81_TO_LLVM_CONVERSION_PASSES_H +#define TX81_TO_LLVM_CONVERSION_PASSES_H + +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TX81_TO_LLVM_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.td new file mode 100644 index 000000000..2ed0159cc --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.td @@ -0,0 +1,38 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
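+//
+// For orientation only: the pass ordering below is an assumption (the real
+// pipeline is assembled by the tsingmicro backend driver), but the pass names
+// are the ones registered by this backend and could be exercised individually
+// with an mlir-opt-style tool:
+//
+//   <opt-tool> kernel.mlir \
+//     --mk-to-tx81 \
+//     --tx81-memref-to-llvm \
+//     --kernel-arg-buffer \
+//     --tx81-to-llvm
+//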
+// +//===----------------------------------------------------------------------===// + +#ifndef TX81_TO_LLVM_CONVERSION_PASSES +#define TX81_TO_LLVM_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + + +def Tx81ToLLVM : Pass<"tx81-to-llvm", "ModuleOp"> { + let summary = "Convert Tx81 dialect to LLVM dialect"; + let description = [{ + This pass converts operations in the Tx81 dialect to the LLVM IR dialect. + + It handles the conversion of Tx81-specific operations like tx.rdma, tx.wdma, + tx.gemm etc to appropriate LLVM calls to the Tx81 runtime library. + + The pass also relies on existing conversion patterns for standard dialects + like arith, func, memref, etc. + }]; + + let constructor = "triton::createTx81ToLLVMPass()"; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "mlir::arith::ArithDialect", + "mlir::func::FuncDialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect", + "tx::Tx81Dialect" + ]; +} + +#endif diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h new file mode 100644 index 000000000..9d3ac7ffc --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h @@ -0,0 +1,33 @@ +//===------------------- Tx81ToLLVM.h -------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_CONVERSION_TX81_TO_LLVM_H +#define TRITON_CONVERSION_TX81_TO_LLVM_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_DECL +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc" + +void populateTx81ToLLVMConversionPatterns(RewritePatternSet &patterns, + ConversionTarget &target, + LLVMTypeConverter &converter); + +std::unique_ptr> createTx81ToLLVMPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TX81_TO_LLVM_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/CMakeLists.txt new file mode 100644 index 000000000..6b74f8677 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +set(LLVM_TARGET_DEFINITIONS Tx81Ops.td) +mlir_tablegen(Tx81Dialect.h.inc -gen-dialect-decls -dialect=tx) +mlir_tablegen(Tx81Dialect.cpp.inc -gen-dialect-defs -dialect=tx) +mlir_tablegen(Tx81Ops.h.inc -gen-op-decls) +mlir_tablegen(Tx81Ops.cpp.inc -gen-op-defs) + +mlir_tablegen(Tx81Enums.h.inc -gen-enum-decls) +mlir_tablegen(Tx81Enums.cpp.inc -gen-enum-defs) + +set(LLVM_TARGET_DEFINITIONS Tx81Types.td) +mlir_tablegen(Tx81Types.h.inc -gen-typedef-decls) +mlir_tablegen(Tx81Types.cpp.inc -gen-typedef-defs) + +add_public_tablegen_target(Tx81TableGen) diff --git 
a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td new file mode 100644 index 000000000..07ca7549f --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td @@ -0,0 +1,24 @@ +//===---------------------- Tx81AttrDefs.td -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef TSINGMICRO_TX81_ATTR_DEFS +#define TSINGMICRO_TX81_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Round mode, aligned with RND_MODE in instr_def.h +def RoundModeAttr : I32EnumAttr<"RoundMode", "Round mode", [ + I32EnumAttrCase<"RND_NEAREST_EVEN", 0, "nearest">, + I32EnumAttrCase<"RND_ZERO", 1, "zero">, + I32EnumAttrCase<"RND_POS_INF", 2, "pos">, + I32EnumAttrCase<"RND_NEG_INF", 3, "neg">, + I32EnumAttrCase<"RND_STOCHASTIC", 4, "stochastic"> +]> { + let cppNamespace = "::mlir::tx"; +} + +#endif // TSINGMICRO_TX81_ATTR_DEFS diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h new file mode 100644 index 000000000..955cbdb1e --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h @@ -0,0 +1,33 @@ +//===-------------------------- Tx81Dialect.h -----------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H +#define MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +//===----------------------------------------------------------------------===// +// TsingMicro Tx81 Operations +//===----------------------------------------------------------------------===// +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h.inc" + +// Include the auto-generated header file containing the declarations of the +// TritonStructured operations. +#define GET_OP_CLASSES +#include "tsingmicro-tx81/Dialect/IR/Tx81Enums.h.inc" +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h.inc" + +#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.td new file mode 100644 index 000000000..172d2a6ee --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.td @@ -0,0 +1,43 @@ +//===----------------------- Tx81Dialect.td -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TSINGMICRO_TX81_DIALECT +#define TSINGMICRO_TX81_DIALECT + +include "mlir/IR/OpBase.td" + +def Tx81Dialect : Dialect { + let name = "tx"; + + let cppNamespace = "::mlir::tx"; + + let summary = "The TsingMicro Tx81 IR in MLIR"; + + let description = [{ + TsingMicro Tx81 Dialect. + + Dependent Dialects: + * MK + * Memref + * Bufferization + }]; + + let dependentDialects = [ + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + // let hasConstantMaterializer = 1; + // let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +include "tsingmicro-tx81/Dialect/IR/Tx81Types.td" + +#endif // TSINGMICRO_TX81_DIALECT diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h new file mode 100644 index 000000000..dc27e2388 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h @@ -0,0 +1,26 @@ +//===-------------------------- Tx81Ops.h ---------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TSINGMICRO_TX81_IR_OPS_H +#define MLIR_DIALECT_TSINGMICRO_TX81_IR_OPS_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "tsingmicro-tx81/Dialect/IR/Tx81Enums.h.inc" +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h.inc" + +#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td new file mode 100644 index 000000000..71899adf6 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td @@ -0,0 +1,864 @@ + +//===---------------------- Tx81Ops.td ------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Definition of TsingMicro's Tx81 ML accelerator operations. +// +// Data format supported by Tx81 ML accelerator are: +// f16,fp16,tf32,fp32 +// +// For Tx81 accelerator unsupported data type, we can either convert it by +// using `TsmConvert`, or lower the operations to run on RISC-V controller +// instead. +// +// NOTE: CHANGING THE ARGUMENTS AND RETURNS OF ANY OPS RESULT IN THE CHANGE OF +// THEIR RUNTIME INTERFACE AND IMPLEMENTATION IN crt/Target/Tx81. 
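+//
+// For a rough feel of how these ops compose (illustrative only; written in
+// MLIR's generic op form with types and trailing operands elided), a lowered
+// kernel typically stages data through SPM with rdma/wdma around the compute
+// ops defined below:
+//
+//   %a   = "tx.rdma"(%ddr_a, %spm_a, ...)             // DDR -> SPM
+//   %b   = "tx.rdma"(%ddr_b, %spm_b, ...)             // DDR -> SPM
+//   %sum = "tx.addvv"(%a, %b, %spm_out, %n)
+//            {rnd_mode = 0 : i16, fmt = 0 : i16}      // compute in SPM
+//   %out = "tx.wdma"(%spm_out, %ddr_out, ...)         // SPM -> DDR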
+// +//===----------------------------------------------------------------------===// + +#ifndef TSINGMICRO_TX81_OPS +#define TSINGMICRO_TX81_OPS + +include "tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td" +include "tsingmicro-tx81/Dialect/IR/Tx81Types.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/IR/OpBase.td" + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +class Tx81Op traits = []> : + Op { +} + +def MemRefOrInt + : AnyTypeOf<[AnyMemRef, AnySignlessIntegerOrIndex], + "MemRef or Int as address type.", "::mlir::Type">; + +// ============================================================================= +// 4.8/4.9 DDR and SPM transfer ops +// ============================================================================= + +def RdmaOp : Tx81Op<"rdma", [ + AttrSizedOperandSegments, + PredOpTrait<"Constrain shape to 4d.", + CPred<"cast($_op).getShape().size() == 4">>, + PredOpTrait<"Constrain strides to 3d.", + CPred<"cast($_op).getStrides().size() == 3">> + ]> { + + let summary = "Copy data from global memory DDR(dram) to per thread local SPM(sram)"; + + let description = [{ + Copy data from global memory DDR(dram) to per thread local SPM(sram). + }]; + + let arguments = ( + ins + MemRefOrInt:$source, // The source address in DDR + MemRefOrInt:$target, // The target address in SPM + Variadic:$shape, // NHWC shape + Variadic:$strides, // 3 dim strides + I32Attr:$fmt + ); + + let builders = [ + OpBuilder<(ins "MemRefType":$resultType, + "Value":$source, "Value":$target, + "ArrayRef":$shape, + "ArrayRef":$strides, + "IntegerAttr":$fmt + )> + ]; + + let results = (outs I64:$dst); // The dest address in SPM +} + +def WdmaOp : Tx81Op<"wdma", [ + AttrSizedOperandSegments, + PredOpTrait<"Constrain shape to 4d.", + CPred<"cast($_op).getShape().size() == 4">>, + PredOpTrait<"Constrain strides to 3d.", + CPred<"cast($_op).getStrides().size() == 3">> + ]> { + let summary = "Copy data from per thread local SPM(sram) to global memory DDR(dram)"; + + let description = [{ + Copy data from per thread local SPM(sram) to global memory DDR(dram). + }]; + + let arguments = ( + ins + MemRefOrInt:$source, // The source address in DDR + MemRefOrInt:$target, // The target address in SPM + Variadic:$shape, // NHWC shape + Variadic:$strides, // 3 dim strides + I32Attr:$fmt + ); + + let builders = [ + OpBuilder<(ins "MemRefType":$resultType, + "Value":$source, "Value":$target, + "ArrayRef":$shape, + "ArrayRef":$strides, + "IntegerAttr":$fmt + )> + ]; + + let results = (outs I64:$dst); // The dest address in DDR +} + +// ============================================================================= +// 4.4~6 TsmConv, TsmDepthwiseConv, TsmBackwardConv +// ============================================================================= + +def ConvOp : Tx81Op<"conv", [Pure]> { + let summary = "Convolution engine intrinsic runtime API"; + + let description = [{ + A common convolution op for TsmConv, TsmDepthwiseConv, TsmBackwardConv. + This TsmConv is not a 1 to 1 map to TsingMicro's TsmConv intrinsic, it is + the wrap of all APIs related to TsmConv. This op wraps the following APIs: + TsmNewConv, TsmDeleteConv, AddInput, AddWeight, AddBias, AddOutput, + SetOpType, SetNegativeAxisScale, SetPositiveAxisScale, SetSparse, SetPsum, + SetPads, SetUnPads, SetKernelStrides, SetDilations, EnableRelu, + EnableLeakyRelu, DisableRelu, DisableLeakyRelu, SetQuant. 
+ }]; + + let arguments = ( + ins + I64Attr:$op_type, // 0: conv, 1: depthwise conv, 2: backward conv, + // 3: gemm + MemRefOrInt:$src_activation, // Input activation addr in SPM + I32ArrayAttr:$src_dims, // dims of src activation in NHWC format + MemRefOrInt:$weight, // Input weight addr in SPM + I16Attr:$weight_dims, // dims of weight(conv kernel) in Kx, Ky, Sx, Sy + // Where K and S is short for size(K) and step(S) + BoolAttr:$en_bias, // Enable bias add + MemRefOrInt:$src_bias, // The address of bias in SPM + BoolAttr:$en_neg_scale, // Enable negative axis scale + MemRefOrInt:$src_neg_scale, // The address of negative scale data in SPM + BoolAttr:$en_pos_scale, // Enable positive axis scale + MemRefOrInt:$src_pos_scale, // The address of positive scale data in SPM + BoolAttr:$en_sparse, // Enable sparse + MemRefOrInt:$src_sparse, // The sparse matrix addr in SPM + BoolAttr:$en_psum, // Enable psum? TODO: Production sum? + MemRefOrInt:$src_psum, // psum addr in SPM? + I32ArrayAttr:$pads, // Pad in top, bottom, left, right order + I32ArrayAttr:$unpads, // Unpad in top, bottom, left, right order + I32ArrayAttr:$strides, // Kernel strids in Kx, Ky, Sx, Sy + I32ArrayAttr:$dilations, // dialation d0, d1 for conv/backwardconv + BoolAttr:$en_leaky_relu, // Enable LeakyRelu or normal Relu + I32ArrayAttr:$out_dims, // dims of output in NHWC format + I64Attr:$src_fmt, // Data format of src activation + I64Attr:$weight_fmt, // Data format of weight + I64Attr:$out_fmt // Data format of output + // The param of SetQuant() is unused + ); + + // Output matrix C addr in SPM + let results = (outs I64:$dst); +} + +// ============================================================================= +// 4.7. TsmGemm +// ============================================================================= + +def GemmOp : Tx81Op<"gemm", []> { + let summary = "Gemm engine intrinsic runtime API"; + + let description = [{ + This TsmGemm is not a 1 to 1 map to TsingMicro's TsmGemm intrinsic, it is + the wrap of all APIs related to TsmGemm. This op wraps the following APIs: + TsmNewGemm, TsmDeleteGemm, AddInput, ConfigMKN, AddOutput, SetPsum, + SetTransflag, SetQuant, ConfigBatch, EnableRelu, EnableLeakyRelu, + DisableRelu, DisableLeakyRelu, AddBias, SetNegativeAxisScale, + SetPositiveAxisScale. + }]; + + let arguments = ( + ins + MemRefOrInt:$src_a, // Input matrix A addr in SPM + MemRefOrInt:$src_b, // Input matrix B addr in SPM + MemRefOrInt:$src_bias, // The address of bias in SPM + // Output and initial zeroes buffer + // FIXME: Whether need add side effect to source operands? + Arg:$dst, + I32ArrayAttr:$dims, // The dimensions of M, K, N + BoolAttr:$en_psum, // Enable psum? TODO: Production sum? + MemRefOrInt:$psum_addr, // The address of psum in SPM, TODO: psum? 
+ BoolAttr:$trans_src_a, // Should matrix A be transposed + BoolAttr:$trans_src_b, // Should matrix B be transposed + I32Attr:$batch_src_a, // The batch of matrix A + I32Attr:$batch_src_b, // The batch of matrix B + I32Attr:$relu_mode, // Enable LeakyRelu or normal Relu or none + BoolAttr:$en_bias, // Enable bias add + BoolAttr:$en_neg_scale, // Enable negative axis scale + MemRefOrInt:$src_neg_scale, // The address of negative scale data in SPM + BoolAttr:$en_pos_scale, // Enable positive axis scale + MemRefOrInt:$src_pos_scale, // The address of positive scale data in SPM + I32Attr:$src_fmt, // Input matrix data format + I32Attr:$dst_fmt // Output matrix data format + // The param of SetQuant() is unused + ); + + // Output matrix C addr in SPM + let results = (outs Variadic:$output); +} + +// ============================================================================= +// 4.10. TsmArith +// ============================================================================= + +class UnaryOp traits = []> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input, // Input vector address + Arg:$out, // Out vector address + MemRefOrInt:$elem_count, // Number of input elements + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs Variadic:$dst); +} + +def AbsVVOp : UnaryOp<"absvv"> { + let summary = "Absolute value of input vector"; +} +def SqrtVVOp : UnaryOp<"sqrtvv", [Pure, Elementwise]> {} +def RsqrtVVOp : UnaryOp<"rsqrtvv", [Pure, Elementwise]> {} +def NegVVOp : UnaryOp<"negvv", [Pure, Elementwise]> {} +def RecipVVOp : Tx81Op<"recipvv", [Pure, Elementwise]> {} +def SquareVVOp : Tx81Op<"squarevv", [Pure, Elementwise]> {} + +class BinaryVVOp traits = []> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input0, // First input vector address + MemRefOrInt:$input1, // Second vector address + Arg:$out, // Out vector address + MemRefOrInt:$elem_count, // Number of input elements + I16Attr:$rnd_mode, // round mode + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs Variadic:$dst); +} + +def AddVVOp : BinaryVVOp<"addvv"> { + let summary = "Add two vectors element-wise"; +} +def SubVVOp : BinaryVVOp<"subvv">; +def MulVVOp : BinaryVVOp<"mulvv">; +def DivVVOp : BinaryVVOp<"divvv">; +def MaxVVOp : BinaryVVOp<"maxvv">; +def MinVVOp : BinaryVVOp<"minvv">; + +class BinaryVSOp traits = []> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input0, // First input vector address + I32:$value, // Const value + Arg:$out, // Out vector address + MemRefOrInt:$elem_count, // Number of input elements + I16Attr:$rnd_mode, // round mode + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs Variadic:$dst); +} + +def AddVSOp : BinaryVSOp<"addvs"> { + let summary = "Add input vector and constant value"; +} +def SubVSOp : BinaryVSOp<"subvs">; +def MulVSOp : BinaryVSOp<"mulvs">; +def DivVSOp : BinaryVSOp<"divvs">; + +// ... + +// ============================================================================= +// 4.11. 
TsmRelation +// ============================================================================= + +class BoolRelationVVOp traits = []> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input0, // First input vector address + MemRefOrInt:$input1, // Second vector address + Arg:$out, // Out vector address + MemRefOrInt:$elem_count, // Number of input elements + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs Variadic:$dst); +} + +def BoolEqualVV : BoolRelationVVOp<"boolequalvv"> { + let summary = "compare two input value, if equal, return true"; +} + +def BoolUnEqualVV : BoolRelationVVOp<"boolunequalvv"> { + let summary = "compare two input value, if unequal, return true"; +} + +def BoolGreaterEqualVV : BoolRelationVVOp<"boolgreatrequalvv"> { + let summary = "compare two input value, if src0 >= src1, return true"; +} + +def BoolGreaterVV : BoolRelationVVOp<"boolgreatervv"> { + let summary = "compare two input value, if src0 > src1, return true"; +} + +def BoolLessEqualVV : BoolRelationVVOp<"boollessequalvv"> { + let summary = "compare two input value, if src0 <= src1, return true"; +} + +def BoolLessThenVV : BoolRelationVVOp<"boollessthenvv"> { + let summary = "compare two input value, if src0 < src1, return true"; +} + +// ... +// ============================================================================= +// 4.12. TsmLogic +// ============================================================================= + +class BinaryLogicVVOp traits = []> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input0, // First input vector address + MemRefOrInt:$input1, // Second vector address + Arg:$out, // Out vector address + MemRefOrInt:$elem_count, // Number of input elements + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs Variadic:$dst); +} + +def AndVV : BinaryLogicVVOp<"andvv"> { + let summary = "And operation on elements at the same position. If the element is not 0, it is represented as 1."; +} + +def OrVV : BinaryLogicVVOp<"orvv"> { + let summary = "OR operation on elements at the same position. If the element is not 0, it is represented as 1."; +} + +def XorVV : BinaryLogicVVOp<"xorvv"> { + let summary = "XOR operation on elements at the same position. If the element is not 0, it is represented as 1."; +} + +// ============================================================================= +// 4.13. TsmTranscendental +// ============================================================================= + +def Log2Op : UnaryOp<"log2", []> { + let summary = "Logarithm based 2"; +} +def LnOp : UnaryOp<"ln", []> { + let summary = "Logarithm based e"; +} +def Pow2Op : UnaryOp<"pow2", []> { + let summary = "2 ** x"; +} +def ExpOp : UnaryOp<"exp", []> { + let summary = "Exponential with high precision"; +} +def ExplpOp : UnaryOp<"explp", []> { + let summary = "Exponential with low precision"; +} +def SinOp : UnaryOp<"sin", []> { + let summary = "Sine"; +} +def CosOp : UnaryOp<"cos", []> { + let summary = "Cosine"; +} + +// ============================================================================= +// 4.13. 
TsmActivation +// ============================================================================= + +class ActivationOp traits> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$src, // Input vector address + UI32Attr:$elem_count, // Number of input elements + UI16Attr:$fmt // The data format of src & dst + ); + let results = (outs UI64:$dst); +} + +def Tanh : ActivationOp<"tanh", []> { + let summary = "Hyperbolic tangent"; +} +def Sigmoid : ActivationOp<"sigmoid", []> { + let summary = "Logistic sigmoid"; +} +def Relu : ActivationOp<"relu", []> { + let summary = "Rectified linear unit"; +} +def Satrelu : ActivationOp<"satrelu", []> { + let summary = "Saturated ReLU"; +} +def Leakyrelu : ActivationOp<"leakyrelu", []> { + let summary = "Leaky rectified linear unit"; +} +def Softplus : ActivationOp<"softplus", []> { + let summary = "Smooth approximation of ReLU"; +} + +// ============================================================================= +// 4.15. TsmReduce +// ============================================================================= + +class Reduce : Tx81Op { + let summary = "Reduction engine intrinsic runtime API"; + + let description = [{ + Includes ReduceSum, ReduceAvg, ReduceMin and ReduceMax interfaces. + Mapping between `dim` and NCHW: + Reduction on C: dim=0 + Reduction on W: dim=1 + Reduction on H: dim=2 + Reduction on HW: dim=4 + }]; + + let arguments = ( + ins + AnyType:$src, // Input tensor address in SPM + Arg:$dst, // Output tensor address in SPM + UI32Attr:$dim, // Which dimension to be reduced + I64ArrayAttr:$shape, // The shape info of src + I16Attr:$fmt // The data format of src & dst + ); + + // Output tensor address in SPM + let results = (outs Variadic); +} + +def ReduceSumOp : Reduce<"reduce_sum">; +def ReduceAvgOp : Reduce<"reduce_avg">; +def ReduceMaxOp : Reduce<"reduce_max">; +def ReduceMinOp : Reduce<"reduce_min">; + +// ============================================================================= +// 4.15. TsmMaskDataMove +// ============================================================================= + +def MaskMoveOp : Tx81Op<"mask_move", []> { + let summary = "Mask data move engine intrinsic runtime API"; + + let description = [{ When mask is 1, extract the data from src and write it to dst. +When mask=0, the corresponding elements of dst remain unchanged. + }]; + + let arguments = ( + ins + MemRefOrInt:$source, // The source address in SPM + // The target address in SPM + Arg:$target, + AnySignlessIntegerOrIndex:$elem_count, // Number of elements to be copied + I32ArrayAttr:$mask, // 3 dim masks + I32Attr:$fmt + ); + + // The dst address is not used, use target in arguments instead. + let results = (outs Variadic:$dst); +} + +// ============================================================================= +// 4.19. 
TsmConvert instructions +// ============================================================================= + +class ZeroPointConvertOp traits> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$src, + UI32Attr:$zp, + UI32Attr:$elem_count + ); + let results = (outs UI64:$dst); +} + +def INT8ToFP16Op : ZeroPointConvertOp<"int8_fp16", []> { + let summary = "Data format from int8 to fp16"; +} +def INT8ToBF16Op : ZeroPointConvertOp<"int8_bf16", []> { + let summary = "Data format from int8 to bf16"; +} +def INT8ToFP32Op : ZeroPointConvertOp<"int8_fp32", []> { + let summary = "Data format from int8 to fp32"; +} +def INT8ToTF32Op : ZeroPointConvertOp<"int8_tf32", []> { + let summary = "Data format from int8 to tf32"; +} + +class RoundConvertOp traits> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input, + Arg:$output, + AnySignlessIntegerOrIndex:$elem_count, + I16Attr:$rnd_mode + ); + let results = (outs I64:$dst); +} + +def INT16ToBF16Op : RoundConvertOp<"int16_bf16", []> { + let summary = "Data format from int16 to bf16"; +} +def INT16ToFP32Op : RoundConvertOp<"int16_fp32", []> { + let summary = "Data format from int16 to fp32"; +} +def INT16ToTF32Op : RoundConvertOp<"int16_tf32", []> { + let summary = "Data format from int16 to tf32"; +} +def INT32ToFP16Op : RoundConvertOp<"int32_fp16", []> { + let summary = "Data format from int32 to fp16"; +} +def INT32ToBF16Op : RoundConvertOp<"int32_bf16", []> { + let summary = "Data format from int32 to bf16"; +} +def INT32ToFP32Op : RoundConvertOp<"int32_fp32", []> { + let summary = "Data format from int32 to fp32"; +} +def INT32ToTF32Op : RoundConvertOp<"int32_tf32", []> { + let summary = "Data format from int32 to tf32"; +} +def BF16ToINT16Op : RoundConvertOp<"bf16_int16", []> { + let summary = "Data format from bf16 to int16"; +} +def BF16ToINT32Op : RoundConvertOp<"bf16_int32", []> { + let summary = "Data format from bf16 to int32"; +} +def FP16ToINT8Op : RoundConvertOp<"fp16_int8", []> { + let summary = "Data format from fp16 to int8"; +} +def FP16ToINT16Op : RoundConvertOp<"fp16_int16", []> { + let summary = "Data format from fp16 to int16"; +} +def FP16ToINT32Op : RoundConvertOp<"fp16_int32", []> { + let summary = "Data format from fp16 to int32"; +} +def FP16ToBF16Op : RoundConvertOp<"fp16_bf16", []> { + let summary = "Data format from fp16 to bf16"; +} +def FP32ToINT8Op : RoundConvertOp<"fp32_int8", []> { + let summary = "Data format from fp32 to int8"; +} +def FP32ToINT16Op : RoundConvertOp<"fp32_int16", []> { + let summary = "Data format from fp32 to int16"; +} +def FP32ToINT32Op : RoundConvertOp<"fp32_int32", []> { + let summary = "Data format from fp32 to int32"; +} +def FP32ToFP16Op : RoundConvertOp<"fp32_fp16", []> { + let summary = "Data format from fp32 to fp16"; +} +def FP32ToBF16Op : RoundConvertOp<"fp32_bf16", []> { + let summary = "Data format from fp32 to bf16"; +} +def FP32ToTF32Op : RoundConvertOp<"fp32_tf32", []> { + let summary = "Data format from fp32 to tf32"; +} +def TF32ToINT8Op : RoundConvertOp<"tf32_int8", []> { + let summary = "Data format from tf32 to int8"; +} +def TF32ToINT16Op : RoundConvertOp<"tf32_int16", []> { + let summary = "Data format from tf32 to int16"; +} +def TF32ToINT32Op : RoundConvertOp<"tf32_int32", []> { + let summary = "Data format from tf32 to int32"; +} +def TF32ToFP32Op : RoundConvertOp<"tf32_fp32", []> { + let summary = "Data format from tf32 to fp32"; +} + +class NormalConvertOp traits> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input, + Arg:$output, + 
AnySignlessIntegerOrIndex:$elem_count + ); + let results = (outs I64:$dst); +} + +def INT16ToFP16Op : NormalConvertOp<"int16_fp16", []> { + let summary = "Data format from int16 to fp16"; +} +def BF16ToINT8Op : NormalConvertOp<"bf16_int8", []> { + let summary = "Data format from bf16 to int8"; +} +def BF16ToFP16Op : NormalConvertOp<"bf16_fp16", []> { + let summary = "Data format from bf16 to fp16"; +} +def BF16ToFP32Op : NormalConvertOp<"bf16_fp32", []> { + let summary = "Data format from bf16 to fp32"; +} +def BF16ToTF32Op : NormalConvertOp<"bf16_tf32", []> { + let summary = "Data format from bf16 to tf32"; +} +def FP16ToFP32Op : NormalConvertOp<"fp16_fp32", []> { + let summary = "Data format from fp16 to fp32"; +} +def FP16ToTF32Op : NormalConvertOp<"fp16_tf32", []> { + let summary = "Data format from fp16 to tf32"; +} +def TF32ToFP16Op : NormalConvertOp<"tf32_fp16", []> { + let summary = "Data format from tf32 to fp16"; +} +def TF32ToBF16Op : NormalConvertOp<"tf32_bf16", []> { + let summary = "Data format from tf32 to bf16"; +} + +// ============================================================================= +// 4.20. TsmPeripheral instructions +// ============================================================================= + +def CountOp : Tx81Op<"count", [Pure]> { + let summary = "Count the non-zero elements from given tensor"; + + let arguments = ( + ins + MemRefOrInt:$src, // Input tensor address in SPM + I32Attr:$elem_count, // TODO: Ask TsingMicro for explain. + //I64Attr:$p_wb_data0, // TODO: Ask TsingMicro for explain. + //I64Attr:$p_wb_data1, // TODO: Ask TsingMicro for explain. + I16Attr:$fmt + ); + + // The output tensor address in SPM + let results = (outs MemRefOrInt:$dst); +} + +def MemsetOp : Tx81Op<"memset", [ + AttrSizedOperandSegments, + PredOpTrait<"Constrain shape to 4d.", + CPred<"cast($_op).getShape().size() == 4">>, + PredOpTrait<"Constrain strides to 3d.", + CPred<"cast($_op).getStrides().size() == 3">> + ]> { + let summary = "Write given `value` to range of address on SPM(sram)"; + + let arguments = ( + ins + MemRefOrInt:$src, // SPM address to be memset + I32:$value, // Value to be written + Variadic:$shape, // NHWC shape + Variadic:$strides, // 3 dim strides + I16Attr:$fmt + ); + + let builders = [ + OpBuilder<(ins "MemRefType":$resultType, + "Value":$source, "Value":$value, + "ArrayRef":$shape, + "ArrayRef":$strides, + "IntegerAttr":$fmt + )> + ]; + + // The address updated by memset in SPM + let results = (outs MemRefOrInt:$dst); +} + +def Bit2FpOp : Tx81Op<"bit2fp", []> { + let summary = "Convert a vector of the bitwise into the fp vector"; + + let arguments = (ins + UI64:$src, // Input tensor + I32Attr:$elem_count, // Number of input elements + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs UI64:$dst); +} + +def ArgMaxOp : Tx81Op<"argmax", []> { + let summary = "Return a max value inner a vector and its corresponding index"; + + let arguments = (ins + UI64:$src, // Input vector + I32Attr:$elem_count, // Number of input elements + I16Attr:$fmt // The data format of src & dst + ); + + let results = (outs + AnyType:$max, // Max value inner a vector + UI64:$index // Corresponding index + ); +} + +def ArgMinOp : Tx81Op<"argmin", []> { + let summary = "Return a min value inner a vector and its corresponding index"; + + let arguments = (ins + UI64:$src, // Input vector + I32Attr:$elem_count, // Number of input elements + I16Attr:$fmt // The data format of src & dst + ); + + let results = (outs + AnyType:$min, // Min value inner a 
vector
+    UI64:$index                // Corresponding index
+  );
+}
+
+def BilinearOp : Tx81Op<"bilinear", []> {
+  let summary = "Bilinear interpolation";
+
+  let arguments = (ins
+    UI64:$src,                 // Input tensor with the NHWC format
+    I32ArrayAttr:$src_shape,   // Input tensor shape
+    I32ArrayAttr:$dst_shape,   // Output tensor shape
+    F32:$scale_w,              // Input tensor "w" divided by output tensor "w"
+    F32:$scale_h,              // Input tensor "h" divided by output tensor "h"
+    I16Attr:$fmt               // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Lut16Op : Tx81Op<"lut16", []> {
+  let summary = "16-bit lookup table";
+
+  let arguments = (ins
+    MemRefOrInt:$src,          // Vector offset with respect to LUT
+    UI64:$lut16,
+    I32Attr:$src_elem_count,   // Number of elements in vector offset
+    I32Attr:$lut_elem_count    // Number of elements in LUT
+  );
+
+  let results = (outs UI64:$dst);
+}
+
+def Lut32Op : Tx81Op<"lut32", []> {
+  let summary = "32-bit lookup table";
+
+  let arguments = (ins
+    MemRefOrInt:$src,          // Vector offset with respect to LUT
+    UI64:$lut32,
+    I32Attr:$src_elem_count,   // Number of elements in vector offset
+    I32Attr:$lut_elem_count    // Number of elements in LUT
+  );
+
+  let results = (outs UI64:$dst);
+}
+
+def RandGenOp : Tx81Op<"randgen", []> {
+  let summary = "Generate random numbers using two 64-bit seeds";
+
+  let arguments = (ins
+    UI64:$src0,                // The first random seed
+    UI64:$src1,                // The second random seed
+    UI64:$dst0,                // Store the first random seed
+    UI64:$dst1,                // Store the second random seed
+    UI64:$dst2,                // Random value
+    I32Attr:$elem_num,         // Number of random values
+    I16Attr:$fmt               // The data format of the random values
+  );
+}
+
+
+//
+// 4.21. TsmDataMove
+//
+
+class TransformOp<string mnemonic, list<Trait> traits = []> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    UI64:$src,                 // Input matrix or tensor address
+    I32ArrayAttr:$src_shape,   // Input shape
+    I32ArrayAttr:$dst_shape,   // Output shape
+    I16Attr:$fmt               // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Mirror : TransformOp<"mirror", []> {
+  let summary = "Horizontally mirror a matrix";
+}
+def Transpose : TransformOp<"transpose", []> {
+  let summary = "Transpose a matrix";
+}
+def Rotate90 : TransformOp<"rotate90", []> {
+  let summary = "Rotate a matrix 90 degrees clockwise";
+}
+def Rotate180 : TransformOp<"rotate180", []> {
+  let summary = "Rotate a matrix 180 degrees clockwise";
+}
+def Rotate270 : TransformOp<"rotate270", []> {
+  let summary = "Rotate a matrix 270 degrees clockwise";
+}
+def Nchw2nhwc : TransformOp<"nchw2nhwc", []> {
+  let summary = "Transform a tensor from nchw to nhwc";
+}
+def Nhwc2nchw : TransformOp<"nhwc2nchw", []> {
+  let summary = "Transform a tensor from nhwc to nchw";
+}
+def TensorNorm : TransformOp<"tensornorm", []> {
+  let summary = "Align a contiguous tensor to the standard format in the channel direction";
+}
+
+def Concat : Tx81Op<"concat", []> {
+  let summary = "Concatenation along the given dimension";
+
+  let arguments = (ins
+    UI64:$src1,                // The first input tensor
+    I32ArrayAttr:$src1_shape,  // The first input tensor shape
+    UI64:$src2,                // The second input tensor
+    I32ArrayAttr:$src2_shape,  // The second input tensor shape
+    I32ArrayAttr:$dst_shape,   // Output tensor shape
+    I16Attr:$dim,              // The concat direction, where:
+                               // 0 is channel, 1 is width, and 2 is height
+    I16Attr:$fmt               // The data format of input & output tensor
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Pad : Tx81Op<"pad", []> {
+  let summary = "Tensor padding";
+
+  let arguments = (ins
+    UI64:$src,                 // Input tensor
+    I32ArrayAttr:$src_shape,   // Input
tensor shape + I32ArrayAttr:$dst_shape, // Output tensor shape + I16Attr:$pad, // Padding mode: top, bottom, left, and right + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs UI64:$dst); +} + +def Img2col : Tx81Op<"img2col", []> { + let summary = "Transform a feature-map tensor into a matrix"; + + let arguments = (ins + UI64:$src, // Input tensor + I32ArrayAttr:$src_shape, // Input tensor shape + I32ArrayAttr:$dst_shape, // Output tensor shape + I32Attr:$src_elem_num, // Number of elements in input tensor + I32Attr:$dst_elem_num, // Number of elements in output tensor + I32ArrayAttr:$swr, // Horizontal stride of convolution + I32ArrayAttr:$pdr, // Vertical stride of convolution + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs UI64:$dst); +} + +def GatherScatter : Tx81Op<"gatherscatter", []> { + let summary = "Transfer data in strides and iterations"; + + let arguments = (ins + UI64:$src, // Input tensor + I32Attr:$size, // Transfer data size in bytes + I32ArrayAttr:$src_strides, // 3 dim strides for input + I32ArrayAttr:$src_iterations, // 3 dim iterations for input + I32ArrayAttr:$dst_strides, // 3 dim strides for output + I32ArrayAttr:$dst_iterations // 3 dim iterations for output + ); + let results = (outs UI64:$dst); +} + +#endif // TSINGMICRO_TX81_OPS diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td new file mode 100644 index 000000000..cce13dfe9 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td @@ -0,0 +1,107 @@ +//===-------------------------- Tx81Types.td ------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// TODO: Update this file to define the customized type used by Tx81 dialect, +// it is now copy-and-pasted from MagicKernelTypes.td. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TSINGMICRO_TX81_TYPES_TD +#define TSINGMICRO_TX81_TYPES_TD + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.td" + +// +// Types +// +class MKTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def MKFloat : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def MKFloatTensor : RankedTensorOf<[MKFloat]>; +def MKFloatLike : AnyTypeOf<[MKFloat, MKFloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def MKBoolTensor : RankedTensorOf<[I1]>; +def MKBoolLike : AnyTypeOf<[I1, MKBoolTensor]>; + +// Integer Type +def I4 : I<4>; +def MKInt : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">; +def MKIntTensor : RankedTensorOf<[MKInt]>; +def MKIntLike : AnyTypeOf<[MKInt, MKIntTensor]>; + +// I32 Type +// MKI32 -> I32 +// MKI32Tensor -> I32Tensor +def MKI32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// MKI64 -> I64 +// MKI64Tensor -> I64Tensor +def MKI64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class MKPtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `MKPtrOf`) +def MKPtrType : MKTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. 
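+    For example (illustrative), `!tt.ptr<f32>` is a pointer to a scalar value,
+    while a block pointer such as `!tt.ptr<tensor<128x64xf32>>` points to an
+    entire tensor.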
+ }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def MKPtr : MKPtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def MKPtrTensor : RankedTensorOf<[MKPtr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def MKPtrLike : AnyTypeOf<[MKPtr, MKPtrTensor]>; + +// Tensor Type +def MKFpIntTensor : RankedTensorOf<[MKFloat, MKInt]>; +def MKTensor : RankedTensorOf<[MKFloat, MKInt, MKPtr]>; + +// Pointer Type to Tensor Type: `ptr>` +def MKTensorPtr : MKPtrOf<[MKTensor]>; + +// Any Type in Magic Kernel IR +def MKType : AnyTypeOf<[MKFloatLike, MKIntLike, MKPtrLike, MKTensorPtr]>; + +#endif // TSINGMICRO_TX81_TYPES_TD diff --git a/third_party/tsingmicro/lib/Analysis/CMakeLists.txt b/third_party/tsingmicro/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..643c6834f --- /dev/null +++ b/third_party/tsingmicro/lib/Analysis/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(ZTCAnalysis + MaskAnalysis.cpp + OpFoldResultUtils.cpp + PtrAnalysis.cpp + UseAnalysis.cpp + + DEPENDS + TritonAnalysis + TritonTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRAnalysis +) diff --git a/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp new file mode 100644 index 000000000..ede7a7fd5 --- /dev/null +++ b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp @@ -0,0 +1,559 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
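+//
+// Illustrative sketch (not part of the original pass): for a Triton-level
+// mask such as
+//   offs = tl.arange(0, 16); mask = offs < n
+// MaskState::parse records start = 0, end = 16 and a single dimension
+// dims[0] = max(min(16, n), 0), which later sizes the extract_slice /
+// subview used for the masked load or store.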
+// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Analysis/MaskAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton-shared/Analysis/OpFoldResultUtils.h" + +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include + +namespace mlir { + +namespace triton { + +LogicalResult MaskState::parse(Value operand, const Location loc, + OpBuilder &builder) { + if (auto op = operand.getDefiningOp()) { + return this->parseConstant(op, loc, builder); + } else if (isa(operand.getType())) { + return this->parseIntScalar(operand, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseAdd(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseAnd(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseCmp(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseMakeRange(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseBroadcast(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseSplat(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseExpandDims(op, loc, builder); + } else if (!operand.getDefiningOp()) { + return this->parseLoopIterArg(operand, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseExtSI(op, loc, builder); + } else { + return failure(); + } +} + +tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, + const Location loc, + OpBuilder &builder) const { + auto sourceType = cast(source.getType()); + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + + auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, + dims, strides); + + return builder.create(loc, dstType, source, offsets, + dims, strides); +} + +memref::SubViewOp MaskState::getSubview(Value source, const Location loc, + OpBuilder &builder) const { + auto sourceType = cast(source.getType()); + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); + + return builder.create(loc, cast(dstType), + source, offsets, dims, strides); +} + +static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return b.create(loc, cast(dstType), src, + offsets, sizes, strides); +} + +// Assume block1 wraps around and the remainder is block2. +// +// |----------------------| +// | | | +// | block2 | block1 | +// | | | +// |----------------------| +// +// Once we copy the chunks in order, the end result is block1 followed by +// block2. 
+// +// buffer_tmp: +// +// |----------------------| +// | | | +// | block1 | block2 | +// | | | +// |----------------------| +// +// Assume we have the following subview: +// +// +++++++++++++++++------- +// + + | +// + subview + | +// + + | +// +++++++++++++++++------- +// +// If we simply take the subview of `buffer_tmp`, this requires an extra +// buffer to just hold the temporary result. +// +// So we can subview into block1 and block2 directly. There are 2 cases: +// + subview only spans block1 +// + subview spans both block1 and block2, creating sv1 and sv2 (illustrated +// below for case when we wrap around side-by-side) +// +// |----------------------------------------| +// | | +// | col2 col1 | +// |++++++--------| |+++++++++++++++ +// | sv2 + block2 | | block1 & sv1 + +// |++++++--------| |+++++++++++++++ +// | | +// |----------------------------------------| +// +// For simplicity, assume we only wrap around side-by-side. +// +// Let (row, col1) and (row, col2) be the dimensions of block1 and block2, +// respectively. +// +// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be +// the dimensions of the full subview, sv1, and sv2, respectively. +// +// + colView1 = min(colFull, col1) +// + colView2 = colFull - colView1 +// + rowView1 = rowView2 = row = rowFull +std::pair +MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc, + OpBuilder &builder) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult col1 = builder.create(loc, block1, 1).getResult(); + OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, builder); + OpFoldResult subviewCol2 = subOFRs(subviewColFull, subviewCol1, loc, builder); + + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto sv1 = createSubview(block1, loc, builder, offsets, + {subviewRowFull, subviewCol1}, strides); + auto sv2 = createSubview(block2, loc, builder, offsets, + {subviewRowFull, subviewCol2}, strides); + + return {sv1, sv2}; +} + +std::pair +MaskState::getStackedSubviews(Value block1, Value block2, const Location loc, + OpBuilder &builder) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult row1 = builder.create(loc, block1, 0).getResult(); + OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, builder); + OpFoldResult subviewRow2 = subOFRs(subviewRowFull, subviewRow1, loc, builder); + + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto sv1 = createSubview(block1, loc, builder, offsets, + {subviewRow1, subviewColFull}, strides); + auto sv2 = createSubview(block2, loc, builder, offsets, + {subviewRow2, subviewColFull}, strides); + return {sv1, sv2}; +} + +LogicalResult MaskState::addStateScalar(const MaskState &state, + const OpFoldResult scalar, Location loc, + OpBuilder &builder) { + start = addOFRs(state.start, scalar, loc, builder); + end = addOFRs(state.end, scalar, loc, builder); + dims = state.dims; + return success(); +} + +LogicalResult MaskState::addStates(const MaskState &lhsState, + const MaskState &rhsState, Location loc, + OpBuilder &builder) { + if (lhsState.scalar && rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; + return failure(); + } + + if (!lhsState.scalar && !rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) + << "Unsupported 
scenario where neither lhs nor rhs is a scalar"; + return failure(); + } + + if (lhsState.scalar) + return addStateScalar(rhsState, lhsState.scalar, loc, builder); + else + return addStateScalar(lhsState, rhsState.scalar, loc, builder); +} + +LogicalResult MaskState::minStates(const MaskState &lhsState, + const MaskState &rhsState, Location loc, + OpBuilder &builder) { + if (lhsState.getRank() != rhsState.getRank()) { + InFlightDiagnostic diag = + emitError(loc) + << "Unexpected case where lhs and rhs have different ranks"; + return failure(); + } + + for (uint32_t i = 0; i < lhsState.getRank(); i++) { + auto lhsDim = lhsState.dims[i]; + auto rhsDim = rhsState.dims[i]; + dims.push_back(minOFRs(lhsDim, rhsDim, loc, builder)); + } + return success(); +} + +LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, + const Location loc, OpBuilder &builder) { + assert(this->isEmpty()); + + if (isa(constOp.getValue())) { + auto attr = cast(constOp.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType) && + "All elements must share a single integer constant value"); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = builder.getIndexAttr(value.getSExtValue()); + auto op = arith::ConstantOp::materialize(builder, constAttr, + builder.getIndexType(), loc); + this->scalar = op.getValue(); + } else { + auto value = cast(constOp.getValue()).getInt(); + this->scalar = builder.getIndexAttr(value); + } + + return success(); +} + +LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto castOp = + builder.create(loc, builder.getIndexType(), scalar); + this->scalar = castOp.getResult(); + return success(); +} + +void MaskState::dump() const { + llvm::dbgs() << "start: " << start << "\n"; + llvm::dbgs() << "end: " << end << "\n"; + llvm::dbgs() << "scalar: " << scalar << "\n"; + llvm::dbgs() << "useUnsafeMask: " << useUnsafeMask << "\n"; + llvm::dbgs() << "dims: "; + for (auto dim : dims) + llvm::dbgs() << "\t" << dim << "\n"; + llvm::dbgs() << "\n"; +} + +LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + MaskState lhsState; + if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) + return failure(); + + MaskState rhsState; + if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) + return failure(); + + return this->addStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + MaskState lhsState; + if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || + !lhsState.isMask()) + return failure(); + + MaskState rhsState; + if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || + !rhsState.isMask()) + return failure(); + + return this->minStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseExtSI(arith::ExtSIOp op, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + return parse(op.getIn(), loc, builder); +} + +LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (cmpOp.getPredicate() != arith::CmpIPredicate::slt && + cmpOp.getPredicate() != arith::CmpIPredicate::ult && + cmpOp.getPredicate() != arith::CmpIPredicate::sge) { + InFlightDiagnostic diag = emitError(loc) << "Unsupported cmpi"; + return failure(); + } + 
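+  // Note (illustrative): at most one dimension of the comparison may have a
+  // size greater than 1. For example, a mask of shape [1, 64] compares along
+  // dimension 1, so cmpDim below resolves to 1; masks that vary along two or
+  // more dimensions at once are rejected.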
+ MaskState lhsState; + if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) + return failure(); + + MaskState rhsState; + if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) + return failure(); + + // We only support sge against 0 for lower bounds. Dims already has an + // implicit assumption that the lower bound is 0, so if we see this, assume + // the comparison evaluates to true. + if (cmpOp.getPredicate() == arith::CmpIPredicate::sge && + !(rhsState.scalar && hasConstZero(rhsState.scalar))) { + InFlightDiagnostic diag = emitError(loc) + << "Unsupported cmpi with rhs not equal to 0"; + return failure(); + } + + int32_t cmpDim = lhsState.scalar && rhsState.scalar ? 0 : -1; + for (int32_t i = 0; i < lhsState.getRank(); i++) { + auto dimIntAttr = getIntAttr(lhsState.dims[i]); + if (!dimIntAttr || dimIntAttr.value() != 1) { + if (cmpDim != -1) { + InFlightDiagnostic diag = emitError(loc) + << "Unsupported cmpi with more than one " + "dimension with size larger than 1"; + return failure(); + } + cmpDim = i; + } + } + assert(cmpDim != -1 && + "Unexpected case where no dimension has size larger than 1"); + + OpFoldResult newDim; + if (lhsState.scalar) { + assert(rhsState.scalar && "Unexpected case where rhs is not a scalar"); + // If both lhs and rhs are scalars, we can't just derive the dimension of + // the mask as the minimum value: lhs/rhs could be 0 and then we don't + // load/store anything. + // + // Instead treat the comparison as a scalar that determines if anything + // should be loaded/stored by inserting a comparison + select: + // dim = lhs < rhs ? lhs.dim : 0 + newDim = compareOFRs(lhsState.scalar, rhsState.scalar, cmpOp.getPredicate(), + lhsState.dims[cmpDim], builder.getIndexAttr(0), loc, + builder); + } else if (cmpOp.getPredicate() == arith::CmpIPredicate::slt || + cmpOp.getPredicate() == arith::CmpIPredicate::ult) { + // Important: + // In the case where the values we are loading are entirely masked off like + // the following: + // + // ---|-------|-----------| + // ^ ^ ^ + // scalar start end + // + // newEnd = min(end, scalar) = scalar + // Now scalar < start, so simply doing dim = newEnd - start is incorrect. + // + // The correct formula is to optionally move `newDim` back to `start` using + // max(newEnd, start). 
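+  // Worked example (illustrative): start = 8, end = 16, scalar = 4:
+  // newEnd = min(16, 4) = 4, then max(4, 8) = 8, so newDim = 8 - 8 = 0 and
+  // the access is fully masked off, as expected.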
+ auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder); + newEnd = maxOFRs(newEnd, lhsState.start, loc, builder); + newDim = subOFRs(newEnd, lhsState.start, loc, builder); + } else { + assert(cmpOp.getPredicate() == arith::CmpIPredicate::sge && + rhsState.scalar && hasConstZero(rhsState.scalar)); + newDim = lhsState.dims[cmpDim]; + } + + for (int32_t i = 0; i < lhsState.getRank(); i++) { + if (i == cmpDim) + this->dims.push_back(newDim); + else + this->dims.push_back(lhsState.dims[i]); + } + + return success(); +} + +LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc, + OpBuilder &builder) { + assert(!v.getDefiningOp()); + + auto forOp = llvm::dyn_cast(v.getParentRegion()->getParentOp()); + + if (!forOp) { + return failure(); + } + + // TODO: This implementation does not work with nested loops + if (forOp->getParentOfType()) { + return failure(); + } + + auto it = llvm::find(forOp.getRegionIterArgs(), v); + if (it == forOp.getRegionIterArgs().end()) { + return failure(); + } + + auto argIndex = std::distance(forOp.getRegionIterArgs().begin(), it); + auto initArg = forOp.getInitArgs()[argIndex]; + if (auto getStateOp = initArg.getDefiningOp()) { + auto tritonValue = getStateOp->getOperand(0); + MaskState lhsState; + if (failed(lhsState.parse(tritonValue, loc, builder))) { + return failure(); + } + + // This is a bit of a hack!! + // + // The offsets and dimensions of a MaskState can now depend on a loop's + // iter-arg. + // + // Because the PtrAnalysis's pre-pass already sets up the offsets, + // we can create a new MaskState for each loop iteration by adding the + // original MaskState with the current iter-arg, which is at `argIndex + + // 1`. + // + // This will not work for nested loop scenarios, which would need a + // more robust implementation. 
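+    // Illustrative sketch: for a loop like
+    //   for i in range(K): mask = offs < N; ...; offs += BLOCK
+    // the init arg holds the original offsets while the iter-arg at
+    // argIndex + 1 (set up by the PtrAnalysis pre-pass) carries the running
+    // offset, so each iteration's MaskState is the initial state shifted by
+    // that value.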
+ if (failed(this->addStateScalar( + lhsState, forOp.getRegionIterArgs()[argIndex + 1], loc, builder))) { + return failure(); + } + + return success(); + } + + return failure(); +} + +LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, + const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + auto shape = cast(rangeOp.getType()).getShape(); + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + + if (stride != 1) { + InFlightDiagnostic diag = + emitError(loc) + << "stride must be 1 for make_range whose result is used " + "as load or store masks"; + return failure(); + } + + this->start = builder.getIndexAttr(start); + this->end = builder.getIndexAttr(end); + this->dims.push_back(builder.getIndexAttr(shape[0])); + + return success(); +} + +LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, + const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (failed(parse(src, loc, builder))) + return failure(); + + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + this->dims[i] = builder.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast"); + } + + return success(); +} + +LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + if (!isa(src.getType())) { + InFlightDiagnostic diag = + emitError(loc) + << "splat source must be an integer scalar for load/store masks"; + return failure(); + } + + if (failed(this->parse(src, loc, builder))) + return failure(); + + for (auto s : dstShape) + this->dims.push_back(builder.getIndexAttr(s)); + + return success(); +} + +LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) + return failure(); + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); + + return success(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp new file mode 100644 index 000000000..aa4904bdf --- /dev/null +++ b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp @@ -0,0 +1,289 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
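+//
+// Usage sketch (illustrative): these helpers fold arithmetic on OpFoldResults
+// at build time when both operands are constant index attributes, and only
+// materialize arith ops otherwise, e.g.
+//   addOFRs(IndexAttr(3), IndexAttr(4))  -> IndexAttr(7), no IR emitted
+//   addOFRs(%v, IndexAttr(0))            -> %v unchanged
+//   addOFRs(%a, %b)                      -> result of a new arith.addi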
+// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Analysis/OpFoldResultUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +std::optional getIntAttr(const OpFoldResult ofr) { + if (isa(ofr) && isa(cast(ofr))) + return dyn_cast(cast(ofr)).getInt(); + + return std::nullopt; +} + +bool hasConstZero(const OpFoldResult ofr) { + auto intAttr = getIntAttr(ofr); + if (intAttr.has_value()) { + if (intAttr.value() == 0) { + return true; + } + return false; + } + + auto val = dyn_cast(ofr); + assert(val); + auto constOp = val.getDefiningOp(); + if (!constOp) + return false; + + intAttr = getIntAttr(constOp.getValue()); + if (intAttr.has_value()) { + if (intAttr.value() == 0) { + return true; + } + return false; + } + + return false; +} + +Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, + OpBuilder &b) { + if (Value val = dyn_cast(ofr)) { + assert(val.getType().isIntOrIndex()); + if (!val.getType().isIndex()) { + val = b.create(loc, b.getIndexType(), val); + } + return val; + } + + auto intVal = getIntAttr(ofr); + if (intVal.has_value()) { + return b.create(loc, b.getIndexAttr(intVal.value())); + } + llvm_unreachable("Unexpected OpFoldResult state"); + return nullptr; +} + +SmallVector ofrsToIndexValues(ArrayRef ofrs, + const Location loc, OpBuilder &b) { + return llvm::to_vector<4>( + llvm::map_range(ofrs, [&](OpFoldResult ofr) -> Value { + return ofrToIndexValue(ofr, loc, b); + })); +} + +OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs, + const Location loc, OpBuilder &b) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // shortcut for special cases + if (!lhsIntAttr && rhsIntAttr && rhsIntAttr.value() == 0) + return lhs; + if (!rhsIntAttr && lhsIntAttr && lhsIntAttr.value() == 0) + return rhs; + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return b.getIndexAttr(lhsIntAttr.value() + rhsIntAttr.value()); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } else { + assert(isa(lhsValue.getType())); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = + b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } else { + assert(isa(lhsValue.getType())); + } + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, + const Location loc, OpBuilder &b) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // shortcut for special cases + if (!lhsIntAttr && rhsIntAttr && rhsIntAttr.value() == 0) + return lhs; + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return b.getIndexAttr(lhsIntAttr.value() - rhsIntAttr.value()); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = + b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + 
rhsValue = rhsOp.getResult(); + } + + auto sumOp = b.create(loc, lhsValue, rhsValue); + return sumOp.getResult(); +} + +OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs, + const Location loc, OpBuilder &b) { + auto lhsIntAttr = getIntAttr(lhs); + + auto rhsIsConst = false; + // if rhs is not a const, use max value since min is used to represent + // dynamic size or stride + auto rhsConstValue = std::numeric_limits::max(); + auto rhsOp = rhs.getDefiningOp(); + if (rhsOp) { + rhsIsConst = true; + rhsConstValue = cast(rhsOp.getValue()).getInt(); + } + + // shortcuts for special cases + if (lhsIntAttr) { + if (lhsIntAttr.value() == 0) + return lhs; + if (lhsIntAttr.value() == 1) + return rhs; + } + if (rhsIsConst) { + if (rhsConstValue == 0) + return rhsOp.getResult(); + if (rhsConstValue == 1) + return lhs; + } + + // 0. both lhs and rhs are constants + if (lhsIntAttr && rhsIsConst) + return b.getIndexAttr(lhsIntAttr.value() * rhsConstValue); + + // 1. if lhs is constant but rhs is not + if (lhsIntAttr && !rhsIsConst) { + auto lhsConstOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + auto mulOp = b.create(loc, lhsConstOp.getResult(), rhs); + return mulOp.getResult(); + } + + // 2. if lhs is not constant + assert(!lhsIntAttr); + auto mulOp = b.create(loc, cast(lhs), rhs); + return mulOp.getResult(); +} + +OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs, + const Location loc, OpBuilder &b) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return b.getIndexAttr(std::min(lhsIntAttr.value(), rhsIntAttr.value())); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = + b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } + + auto minOp = b.create(loc, lhsValue, rhsValue); + return minOp.getResult(); +} + +OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs, + const Location loc, OpBuilder &b) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return b.getIndexAttr(std::max(lhsIntAttr.value(), rhsIntAttr.value())); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = + b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } + + auto maxOp = b.create(loc, lhsValue, rhsValue); + return maxOp.getResult(); +} + +OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs, + const arith::CmpIPredicate pred, + const OpFoldResult trueOFR, + const OpFoldResult falseOFR, const Location loc, + OpBuilder &b) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // both lhs and rhs are constants, return the result directly + if (lhsIntAttr && rhsIntAttr) { + switch (pred) { + case arith::CmpIPredicate::eq: + return *lhsIntAttr == *rhsIntAttr ? 
trueOFR : falseOFR; + case arith::CmpIPredicate::ne: + return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR; + case arith::CmpIPredicate::slt: + case arith::CmpIPredicate::ult: + return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR; + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::ule: + return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR; + case arith::CmpIPredicate::sgt: + case arith::CmpIPredicate::ugt: + return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR; + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::uge: + return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR; + default: + llvm_unreachable("Unsupported predicate"); + } + } + + auto lhsValue = ofrToIndexValue(lhs, loc, b); + auto rhsValue = ofrToIndexValue(rhs, loc, b); + auto trueValue = ofrToIndexValue(trueOFR, loc, b); + auto falseValue = ofrToIndexValue(falseOFR, loc, b); + + auto cmpOp = b.create(loc, pred, lhsValue, rhsValue); + auto selectOp = b.create(loc, cmpOp, trueValue, falseValue); + return selectOp.getResult(); +} +} // namespace mlir diff --git a/third_party/tsingmicro/lib/Analysis/PtrAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/PtrAnalysis.cpp new file mode 100644 index 000000000..00715a9d3 --- /dev/null +++ b/third_party/tsingmicro/lib/Analysis/PtrAnalysis.cpp @@ -0,0 +1,1375 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Analysis/PtrAnalysis.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" + +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "triton-ptr-analysis" + +namespace mlir { + +namespace triton { + +static void assertValidUnrealizedCast(UnrealizedConversionCastOp op) { + assert(op && op->hasAttr(ModuloState::WraparoundAttr) && + op.getInputs().size() == 3 && + op.getInputs()[0].getDefiningOp() && + op.getInputs()[1].getDefiningOp() && + op.getInputs()[2].getDefiningOp()); +} + +MemRefType PtrState::getResultMemrefType(MLIRContext *context, int64_t offset, + ArrayRef resultShape, + bool useDynamicStrides) const { + + SmallVector staticStrides; + if (useDynamicStrides) { + staticStrides.append(strides.size(), ShapedType::kDynamic); + } else { + SmallVector dynamicStrides; + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + } + + auto elementType = cast(source.getType()).getElementType(); + auto layout = + StridedLayoutAttr::get(source.getContext(), offset, staticStrides); + + return MemRefType::get(resultShape, elementType, layout); +} + +OpFoldResult +PtrState::accumulateTargetOffset(Location loc, + ConversionPatternRewriter &rewriter) const { + OpFoldResult targetOffset = rewriter.getIndexAttr(0); + for (auto o : offsets) { + targetOffset = addOFRs(targetOffset, o, loc, rewriter); + } + return targetOffset; +} + +int64_t PtrState::getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size() && + modulos.size() == offsets.size()); + return offsets.size(); +} + +bool PtrState::isEmpty() const { + return (getRank() == 0 && !source && !scalar); +} + +bool PtrState::hasModulo() const { + return llvm::any_of(modulos, [](auto mod) { return mod.has_value(); }); +} + +void PtrState::addState(const PtrState &lhsState, const PtrState &rhsState, + Location loc, ConversionPatternRewriter 
&rewriter) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + // at most one of lhs and rhs should have valid source, since otherwise we + // will be losing information + assert(!(lhsState.source && rhsState.source)); + source = lhsState.source ? lhsState.source : rhsState.source; + + if (lhsState.scalar && rhsState.scalar) { + auto addOp = + rewriter.create(loc, lhsState.scalar, rhsState.scalar); + scalar = addOp.getResult(); + } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars + scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; + } + + for (uint64_t i = 0; i < lhsState.sizes.size(); i++) { + auto newOffset = + addOFRs(lhsState.offsets[i], rhsState.offsets[i], loc, rewriter); + offsets.push_back(newOffset); + + auto newStride = + addOFRs(lhsState.strides[i], rhsState.strides[i], loc, rewriter); + strides.push_back(newStride); + + sizes.push_back(lhsState.sizes[i]); + + assert(!lhsState.hasModulo() || + !rhsState.hasModulo() && "AddPtr where both lhs and rhs containing " + "modulo operators not supported"); + + modulos.push_back(lhsState.modulos[i].has_value() ? lhsState.modulos[i] + : rhsState.modulos[i]); + } +} + +void PtrState::mulState(const PtrState &lhsState, const PtrState &rhsState, + const Location loc, + ConversionPatternRewriter &rewriter) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + // neither lhs nor rhs should have source, since multiplying base pointer + // does not make sense + assert(!(lhsState.source && rhsState.source)); + + assert((lhsState.scalar || rhsState.scalar) && + !(lhsState.scalar && rhsState.scalar) && + "currently does not support both tensors are effectively non-scalar"); + + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; + + if (!rhs->scalar && lhs->scalar) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->sizes.size(); i++) { + OpFoldResult newOffset = + mulOFRValue(lhs->offsets[i], rhs->scalar, loc, rewriter); + OpFoldResult newStride = + mulOFRValue(lhs->strides[i], rhs->scalar, loc, rewriter); + offsets.push_back(newOffset); + strides.push_back(newStride); + sizes.push_back(lhs->sizes[i]); + } + + assert(llvm::all_of(rhsState.modulos, + [](auto rhs) { return !rhs.has_value(); })); + + modulos = lhs->modulos; +} + +SmallVector +PtrState::createStackedCastOps(ArrayRef resultShape, + const Location loc, + ConversionPatternRewriter &rewriter) const { + + assert(resultShape.size() == 2); + assert(getRank() == 2); + assert(modulos[0].has_value() && !modulos[1].has_value()); + + Value targetOffset = + ofrToIndexValue(accumulateTargetOffset(loc, rewriter), loc, rewriter); + + ////////////////////////////////////////////////////////////////////////////// + // + // Handling stacked wraparound + // + // We do not support cases where the target offset has already overflown the + // number of rows. See side-by-side wraparound for details. 
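+  // Numeric sketch (illustrative), using the quantities defined below: with
+  // modRow = 4 rows, strideRow = cols = 8, targetOffset = 18 and a tile of
+  // rowSize = 3 rows:
+  //   wrappedAroundOff = 18 % 8 = 2
+  //   clampedOff       = 4 * 8 + 2 = 34
+  //   d1 = (34 - 18) / 8 = 2 rows in the first chunk (offsets 18 and 26)
+  //   d2 = 3 - 2 = 1 row in the second chunk, starting at offset 2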
+ // + ////////////////////////////////////////////////////////////////////////////// + // We're loading a tensor of dim (rowSize, colSize) + // d1 + d2 = rowSize + // d2 is the number of rows that overflow + // + // cols + // + // wrappedAroundOff + // --------------*------------*-------- + // | d2 | | | + // | |------------| | + // rows| | + // | | + // | targetOffset | + // | *------------| | + // | | | | + // | d1 | | | + // | | clampedOff | | + // --------------*--------------------- + // | overflow | + // *------------- + // nextOff + // + // wrappedAroundOff = targetOffset % cols + // clampedOff = (rows * strideRows) + wrappedAroundOff + // + // clampedOff - targetOffset + // d1 = -------------------- + // strideRows + + auto resultType = getResultMemrefType( + rewriter.getContext(), /* offset */ ShapedType::kDynamic, + /* result shape */ + SmallVector{ + ShapedType::kDynamic, // Row is dynamic, in most cases, this should be + // the same as the original row. The last chunk + // may be smaller due to wrapping around. + resultShape[1], // Col stays the same. + }, + true /*useDynamicStrides*/); + + Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); + Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); + + Value strideRow = ofrToIndexValue(strides[0], loc, rewriter); + Value strideCol = ofrToIndexValue(strides[1], loc, rewriter); + + Value modRow = rewriter.create( + loc, rewriter.getIndexType(), modulos[0]->size); + + // First chunk + Value wrappedAroundOff = + rewriter.create(loc, targetOffset, strideRow); + Value clampedOff = rewriter.create(loc, modRow, strideRow); + clampedOff = + rewriter.create(loc, clampedOff, wrappedAroundOff); + Value d1 = rewriter.create(loc, clampedOff, targetOffset); + d1 = rewriter.create(loc, d1, strideRow); + + SmallVector sizes1{d1, colSize}; + memref::ReinterpretCastOp cast1 = rewriter.create( + loc, resultType, source, targetOffset, sizes1, + ValueRange{strideRow, strideCol}); + + // Second chunk + Value d2 = rewriter.create(loc, rowSize, d1); + SmallVector sizes2{d2, colSize}; + memref::ReinterpretCastOp cast2 = rewriter.create( + loc, resultType, source, wrappedAroundOff, sizes2, + ValueRange{strideRow, strideCol}); + + return {cast1, cast2}; +} + +SmallVector +PtrState::createSideBySideCastOps(ArrayRef resultShape, + const Location loc, + ConversionPatternRewriter &rewriter) const { + + assert(resultShape.size() == 2); + assert(getRank() == 2 && !modulos[0].has_value() && modulos[1].has_value()); + + // Accumulate final offset + Value targetOffset = + ofrToIndexValue(accumulateTargetOffset(loc, rewriter), loc, rewriter); + + ////////////////////////////////////////////////////////////////////////////// + // + // Handling side-by-side wraparound + // + // Note: We do not support cases where the target has already overflown the + // number of columns! This is because in PtrAnalysis, the offset has already + // been collapsed into a single dimension, so it is ambiguous to determine + // whether the offset actually overflows or just refers to an element on the + // subsequent rows. + // + // Same limitations apply to the stacked wraparound case. 
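+  // Numeric sketch (illustrative), using the quantities defined below: with
+  // N = 8 columns, targetOffset = 21 and a tile of colSize = 6 columns:
+  //   x = 21 % 8 = 5,  y = 21 - 5 = 16
+  //   nextOffset    = 5 + 6 = 11
+  //   clampedOffset = min(11, 8) = 8
+  //   d1 = 8 - 5 = 3 columns in the first chunk (at offset 21)
+  //   d2 = 6 - 3 = 3 columns in the second chunk (at offset y = 16)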
+ // + ////////////////////////////////////////////////////////////////////////////// + // + // nextOffset - targetOffset = colSize + // d1 + d2 = colSize + // N + // x clampedOffset + // --------------------------*----------------*-----* + // | | nextOffset (might + // | targetOffset | overflow) + // y *----- *----------------| + // | | | | + // M |----- -----------------| + // | d2 d1 | + // -------------------------------------------- + // + // x = targetOffset % N + // nextOffset = x + colSize + // clampedOffset = min(nextOffset, N) + // d1 = clampedOffset - x + // + ////////////////////////////////////////////////////////////////////////////// + + SmallVector casts; + + auto resultType = getResultMemrefType( + rewriter.getContext(), /* offset */ ShapedType::kDynamic, + /* result shape */ + SmallVector{ + resultShape[0], // Row stays the same + ShapedType::kDynamic // Column is dynamic, in most cases, this should + // be the same as the original column. The last + // chunk may be smaller due to wrapping around. + }, + true /*useDynamicStrides*/); + + Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); + Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); + + Value modN = rewriter.create(loc, rewriter.getIndexType(), + modulos[1]->size); + + Value x = rewriter.create(loc, targetOffset, modN); + Value y = rewriter.create(loc, targetOffset, x); + + SmallVector strideVals = ofrsToIndexValues(strides, loc, rewriter); + + // First chunk + Value nextOffset = rewriter.create(loc, x, colSize); + Value clampedOffset = rewriter.create(loc, nextOffset, modN); + Value d1 = rewriter.create(loc, clampedOffset, x); + SmallVector sizes1{rowSize, d1}; + + auto cast1 = rewriter.create( + loc, resultType, source, targetOffset, sizes1, strideVals); + + // Second chunk + Value d2 = rewriter.create(loc, colSize, d1); + SmallVector sizes2{rowSize, d2}; + + auto cast2 = rewriter.create( + loc, resultType, source, y, sizes2, strideVals); + + return {cast1, cast2}; +} + +memref::ReinterpretCastOp +PtrState::createCastOp(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const { + // Accumulate final offset + OpFoldResult targetOffset = accumulateTargetOffset(loc, rewriter); + + // Create result MemRefType + SmallVector staticOffset; + SmallVector dynamicOffset; + dispatchIndexOpFoldResult(targetOffset, dynamicOffset, staticOffset); + + auto resultType = + getResultMemrefType(rewriter.getContext(), staticOffset[0], resultShape); + + // Create reinterpret cast + return rewriter.create( + loc, resultType, source, targetOffset, sizes, strides); +} + +void PtrAnalysis::visitOperandAdd( + arith::AddIOp addOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(addOp.getLhs(), lhsState, loc, rewriter, knownPtrs); + + PtrState rhsState; + visitOperand(addOp.getRhs(), rhsState, loc, rewriter, knownPtrs); + + if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || + (rhsState.getRank() == 1 && rhsState.hasModulo())) { + assert(0 && "Current do not support this pattern: a + arange(0, K) % M"); + } + + state.addState(lhsState, rhsState, loc, rewriter); +} + +void PtrAnalysis::visitOperandMul( + arith::MulIOp mulOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(mulOp.getLhs(), lhsState, loc, rewriter, knownPtrs); + + PtrState rhsState; + visitOperand(mulOp.getRhs(), rhsState, 
loc, rewriter, knownPtrs); + + state.mulState(lhsState, rhsState, loc, rewriter); +} + +void PtrAnalysis::visitOperandRem( + arith::RemSIOp remOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + PtrState rhsState; + visitOperand(remOp.getRhs(), rhsState, loc, rewriter, knownPtrs); + assert(rhsState.scalar); + + visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs); + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. + assert(llvm::all_of(state.modulos, + [](auto modState) { return !modState.has_value(); }) && + "No support for multiple modulo within an expression"); + + if (state.getRank() == 1) { + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.modulos.back() = ModuloState{rhsState.scalar}; + } else if (state.getRank() == 2) { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. + auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.modulos[1] = ModuloState{rhsState.scalar}; + } else if (shape[1] == 1) { + state.modulos[0] = ModuloState{rhsState.scalar}; + } else { + assert(false && "Taking modulo on a 2D tensor with no singleton " + "dimension not supported"); + } + } else { + assert(false && "Unsupported modulo pattern"); + } +} + +void PtrAnalysis::visitOperandMakeRange( + triton::MakeRangeOp rangeOp, PtrState &state, Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + auto shape = cast(rangeOp.getType()).getShape(); + + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + assert(stride == 1 && + "Expect make_range op to always return tensor of stride 1"); + + state.offsets.push_back(rewriter.getIndexAttr(start)); + state.sizes.push_back(rewriter.getIndexAttr(shape[0])); + state.strides.push_back(rewriter.getIndexAttr(stride)); + state.modulos.push_back(std::nullopt); +} + +void PtrAnalysis::visitOperandExpandDims( + triton::ExpandDimsOp expandDimsOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + // `getSrc` now returns a TypedValue of RankedTensorType. We modify these + // operands in-place and turn them into memrefs in loops, so we have to bypass + // the cast by using getSrcMutable. These are temporary fix only since + // we will be moving over to StructuredPtrAnalysis soon which separate out the + // memref conversion. 
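+  // Illustrative example: expanding a 1-D state with sizes = [128] at axis 0
+  // yields sizes = [1, 128]; an offset of 0, a stride of 0 and an empty
+  // modulo entry are inserted at the new axis below.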
+ visitOperand(expandDimsOp.getSrcMutable().get(), state, loc, rewriter, + knownPtrs); + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + + // insert dimension info + state.offsets.insert(state.offsets.begin() + axis, rewriter.getIndexAttr(0)); + state.sizes.insert(state.sizes.begin() + axis, rewriter.getIndexAttr(1)); + state.strides.insert(state.strides.begin() + axis, rewriter.getIndexAttr(0)); + state.modulos.insert(state.modulos.begin() + axis, std::nullopt); +} + +void PtrAnalysis::visitOperandBroadcast( + triton::BroadcastOp broadcastOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + // `getSrc` now returns a TypedValue of RankedTensorType. We modify these + // operands in-place and turn them into memrefs in loops, so we have to bypass + // the cast by using getSrcMutable. These are temporary fix only since + // we will be moving over to StructuredPtrAnalysis soon which separate out the + // memref conversion. + auto src = broadcastOp.getSrcMutable().get(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + visitOperand(src, state, loc, rewriter, knownPtrs); + + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + state.sizes[i] = rewriter.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast"); + } +} + +void PtrAnalysis::visitOperandSplat( + triton::SplatOp splatOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + visitOperand(src, state, loc, rewriter, knownPtrs); + + if (isa(src.getType())) { + for (auto s : dstShape) { + state.offsets.push_back(rewriter.getIndexAttr(0)); + state.sizes.push_back(rewriter.getIndexAttr(s)); + state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); + } + } else { + // src is a memref that represent a scalar pointer; it should have + // one dimension of size 1. This happens inside a for loop that + // originally has an init arg that is a tensor of pointers; this arg + // would have been replaced by rewriteForOp. + auto srcType = cast(src.getType()); + assert(srcType.getRank() == 1 && state.getRank() == 1 && + "splat MemRef source should have rank 1"); + assert(srcType.getShape()[0] == 1 && + getIntAttr(state.sizes[0]).value() == 1 && + "splat MemRef source should have size 1"); + + // Stride[0] will have value of 1 set in visitOperandAddPtr. This + // value will be represented by a constOp. Clear this value. 
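+    // Illustrative example: splatting a scalar-pointer memref (shape [1]) to
+    // a tensor of shape [4, 8] produces sizes = [4, 8] with strides [0, 0]
+    // after the loop below.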
+ state.strides[0] = rewriter.getIndexAttr(0); + + for (auto [i, s] : llvm::enumerate(dstShape)) { + if (i == 0) { + state.sizes[i] = rewriter.getIndexAttr(s); + continue; + } + state.offsets.push_back(rewriter.getIndexAttr(0)); + state.sizes.push_back(rewriter.getIndexAttr(s)); + state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); + } + } + + // If we splat a integer value, scalar should become the offset of the outer + // most dimension + if (state.scalar) + state.offsets[0] = state.scalar; +} + +void PtrAnalysis::visitOperandMakeTensorPtr( + triton::MakeTensorPtrOp makeTensorPtrOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + auto remappedValue = rewriter.getRemappedValue(makeTensorPtrOp); + if (auto castOp = remappedValue.getDefiningOp()) { + visitOperandReintCast(castOp, state, loc, rewriter, knownPtrs); + } else { + llvm_unreachable("Expect value to me mapped to a memref.reinterpret_cast"); + } +} + +void PtrAnalysis::visitOperandAddptr( + triton::AddPtrOp addptrOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + PtrState ptrState; + visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), rewriter, + knownPtrs); + + PtrState offsetState; + visitOperand(addptrOp.getOffset(), offsetState, addptrOp.getLoc(), rewriter, + knownPtrs); + + assert(ptrState.source && "ptr field should provide source / base pointer"); + + // Handle the special case when we are in a for loop, ptr is originally a + // scalar pointer but replaced with a memref. In this case, ptrState will have + // rank 1 and offsetState will have rank 0. + // TODO: + // Passing a block argument pointer directly into a for loop not supported + if (ptrState.getRank() == 1 && offsetState.getRank() == 0) { + offsetState.sizes.push_back(rewriter.getIndexAttr(1)); + offsetState.offsets.push_back(offsetState.scalar); + offsetState.strides.push_back(rewriter.getIndexAttr(0)); + offsetState.modulos.push_back(std::nullopt); + } + + assert(ptrState.getRank() == offsetState.getRank() && + "ptr and offset field should have the same rank"); + + state.addState(ptrState, offsetState, addptrOp.getLoc(), rewriter); +} + +void PtrAnalysis::visitOperandReintCast( + memref::ReinterpretCastOp reintCastOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + state.offsets = reintCastOp.getMixedOffsets(); + state.sizes = reintCastOp.getMixedSizes(); + state.strides = reintCastOp.getMixedStrides(); + state.source = reintCastOp.getSource(); + state.modulos.append(state.sizes.size(), std::nullopt); + + // getMixedOffsets produces staticOffsets (which is the result of collapsing + // multiple dimensions). Populate the rest of the dimensions with zeroes. + assert(state.offsets.size() == 1); + for (size_t i = 1; i < state.sizes.size(); i++) { + state.offsets.push_back(rewriter.getIndexAttr(0)); + } + + // Regular Triton programs cannot express patterns of size 1 and non-zero + // stride; we only set it that way to make memrefs work. Set stride back to + // zero if this scenario detected. 
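+  // Illustrative example: a reinterpret_cast with sizes [1, 128] and strides
+  // [128, 1] is normalized to strides [0, 1] below, since a unit-sized
+  // dimension cannot carry a meaningful stride in the original Triton
+  // program.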
+ for (size_t i = 0; i < state.strides.size(); i++) { + auto strideIntAttr = getIntAttr(state.strides[i]); + auto sizeIntAttr = getIntAttr(state.sizes[i]); + + assert(sizeIntAttr); + if (sizeIntAttr.value() == 1 && strideIntAttr) { + state.strides[i] = rewriter.getIndexAttr(0); + } + } +} + +void PtrAnalysis::visitOperand( + Value operand, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + + if (knownPtrs.find(operand) != knownPtrs.end()) { + state = knownPtrs.lookup(operand); + return; + } + + if (isa(operand.getType())) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), operand); + state.scalar = castOp.getResult(); + return; + } + + if (isa(operand.getType())) { + auto remappedPtr = rewriter.getRemappedValue(operand); + assert(remappedPtr); + + // A scalar pointer can either be produced by AddPtrOp or a block + // argument + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + visitOperandAddptr(cast(op), state, loc, rewriter, + knownPtrs); + } else if (auto makeTensorOp = dyn_cast(op)) { + visitOperandMakeTensorPtr(makeTensorOp, state, loc, rewriter, + knownPtrs); + } else { + llvm_unreachable("Unexpected operand defining operation"); + } + } else { + state.source = remappedPtr; + } + return; + } + + if (auto op = operand.getDefiningOp()) { + visitOperandAdd(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandMul(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandMakeRange(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandBroadcast(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandSplat(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandExpandDims(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandAddptr(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandConstSplat(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandRem(op, state, loc, rewriter, knownPtrs); + } else { + operand.dump(); + llvm_unreachable("encountered addptr operand produced by an " + "unsupported operation"); + } +} + +void PtrAnalysis::visitOperandConstSplat( + arith::ConstantOp op, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + // this condition is to handle cases where tt.broadcast and tt.splat are + // folded + auto attr = cast(op.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType)); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = rewriter.getIndexAttr(value.getSExtValue()); + auto constOp = arith::ConstantOp::materialize(rewriter, constAttr, + rewriter.getIndexType(), loc); + + state.scalar = constOp; + + auto resultType = cast(op.getResult().getType()); + for (size_t i = 0; i < resultType.getShape().size(); i++) { + if (i == 0) { + state.offsets.push_back(constOp.getResult()); + } else { + state.offsets.push_back(rewriter.getIndexAttr(0)); + } + + state.sizes.push_back(rewriter.getIndexAttr(resultType.getShape()[i])); + state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); + } 
+} + +void PtrAnalysis::rewriteAddptrOp( + triton::AddPtrOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs) { + // any inserted instruction should be before this addptr + auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + PtrState state; + visitOperandAddptr(op, state, op.getLoc(), rewriter, knownPtrs); + + // If the result is a scalar pointer, visitOperandAddptr will not populate + // sizes, strides, and offsets. We need to do it here. + if (state.sizes.size() == 0) { + state.sizes.push_back(rewriter.getIndexAttr(1)); + state.strides.push_back(rewriter.getIndexAttr(0)); + state.offsets.push_back(state.scalar); + state.modulos.push_back(std::nullopt); + } + + SmallVector scalarShape(1, 1); + ArrayRef resultShape; + if (auto shapedType = dyn_cast(op.getResult().getType())) { + resultShape = shapedType.getShape(); + } else { + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(state.getRank() == 1); + } + + knownPtrs[op.getResult()] = state; + + // If there are dimensions with size 1 and stride 0, replace 0 stride with the + // product of sizes of all lower dimensions. This avoids creating memref with + // zero stride. Note that we store the unmodified state into knownPtrs, since + // any following pointer arithmetic operations should use the original 0 + // stride. + auto accum_size = 1; + for (int i = state.sizes.size() - 1; i >= 0; i--) { + auto strideIntAttr = getIntAttr(state.strides[i]); + auto sizeIntAttr = getIntAttr(state.sizes[i]); + + assert(sizeIntAttr); + if (sizeIntAttr.value() == 1 && strideIntAttr && strideIntAttr.value() == 0) + state.strides[i] = rewriter.getIndexAttr(accum_size); + + accum_size *= sizeIntAttr.value(); + } + + Value src; + + if (llvm::any_of(state.modulos, [](auto mod) { return mod.has_value(); })) { + assert(state.modulos.size() == 2); + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); + + SmallVector casts; + StringRef type; + + if (!state.modulos[0].has_value() && state.modulos[1].has_value()) { + casts = state.createSideBySideCastOps(resultShape, op.getLoc(), rewriter); + type = ModuloState::WraparoundSideBySide; + } else if (state.modulos[0].has_value() && !state.modulos[1].has_value()) { + casts = state.createStackedCastOps(resultShape, op.getLoc(), rewriter); + type = ModuloState::WraparoundStacked; + } else { + assert(false && "not supported"); + } + + auto resultType = state.getResultMemrefType( + rewriter.getContext(), ShapedType::kDynamic, resultShape); + + UnrealizedConversionCastOp combinedCast = + rewriter.create( + op.getLoc(), resultType, + ValueRange{casts[0].getResult(), casts[1].getResult(), + op.getResult()}); + + combinedCast->setAttr(ModuloState::WraparoundAttr, + rewriter.getStringAttr(type)); + + src = combinedCast.getResult(0); + + LLVM_DEBUG({ + llvm::dbgs() << "combine cast for split pointers:\n"; + combinedCast.getOperation()->print( + llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + } else { + memref::ReinterpretCastOp castOp = + state.createCastOp(resultShape, op.getLoc(), rewriter); + + src = castOp.getResult(); + + LLVM_DEBUG({ + llvm::dbgs() << "cast MemRefType:\n"; + castOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + } + + state.source = src; + rewriter.replaceOp(op, src); + rewriter.restoreInsertionPoint(origIp); +} + +void PtrAnalysis::rewriteAdvanceOp( + 
triton::AdvanceOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs) { + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + auto loc = op.getLoc(); + + PtrState ptrState; + visitOperand(op.getOperand(0), ptrState, loc, rewriter, knownPtrs); + + auto incrementOffsets = op.getOffsets(); + + SmallVector newOffsets; + for (auto [increment, offset, stride] : + llvm::zip(incrementOffsets, ptrState.offsets, ptrState.strides)) { + Value offsetValue; + if (auto offsetIntAttr = getIntAttr(offset)) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + offsetValue = constOp.getResult(); + } else { + offsetValue = cast(offset); + } + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), increment); + auto mulOp = rewriter.create(loc, castOp.getResult(), + cast(stride)); + auto addOp = + rewriter.create(loc, mulOp.getResult(), offsetValue); + newOffsets.push_back(addOp.getResult()); + } + + ptrState.offsets.clear(); + + for (auto offset : newOffsets) { + ptrState.offsets.push_back(offset); + } + + SmallVector scalarShape(1, 1); + ArrayRef resultShape; + auto pointerType = cast(op.getResult().getType()); + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + } else { + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(ptrState.getRank() == 1); + } + + auto newOp = ptrState.createCastOp(resultShape, loc, rewriter); + + rewriter.replaceOp(op, newOp.getResult()); + + knownPtrs[newOp.getResult()] = ptrState; +} + +void PtrAnalysis::rewriteYieldOp( + scf::YieldOp op, ConversionPatternRewriter &rewriter, + const IndexMapSet &levelToBlockArgIndex, const int level, + const llvm::SmallDenseMap &knownPtrs) { + // any inserted instruction should be before this yield + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + + auto adaptor = scf::YieldOp::Adaptor(op); + + SmallVector initArgState; + SmallVector operands(adaptor.getOperands()); + // Track the second chunks of modulo pointers so that we can append them to + // the yield results + SmallVector moduloSecondChunks; + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below gathers + // PtrState for those values. + for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { + if (auto mappedV = rewriter.getRemappedValue(v)) { + // If this value is a tensor of pointers produced by AddPtrOp, + // we should have already converted to a ReinterpretCastOp without + // layout information for the normal cases, or to an + // UnrealizedConversionCastOp for the split pointer case. + if (v.getDefiningOp() || + v.getDefiningOp() || + v.getDefiningOp()) { + if (auto castOp = mappedV.getDefiningOp()) { + assertValidUnrealizedCast(castOp); + auto castInputs = castOp.getInputs(); + v = castOp.getResult(0); + operands[i] = castInputs[0]; + moduloSecondChunks.push_back(castInputs[1]); + } else if (auto castOp = + mappedV.getDefiningOp()) { + v = castOp; + } else { + llvm_unreachable("mapped value defined by an unexpected op"); + } + } else { + // If this value is not a tensor of pointers, we will use the + // mapped value, and rely on the conversion will happen later + // automatically when we legalize loop body. 
+ + // TODO: + // The scenario where a value is a tensor of pointers but not + // produced by AddPtrOp is not supported + if (isa(mappedV.getType()) && + isa( + dyn_cast(mappedV.getType()).getElementType())) + llvm_unreachable("unsupported scenario where a value is a tensor of " + "pointers but not produced by AddPtrOp"); + v = mappedV; + } + } + + if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) + continue; + auto thisSet = levelToBlockArgIndex.find(level)->second; + if (thisSet.find(i) == thisSet.end()) + continue; + + auto reintCastOp = v.getDefiningOp(); + auto unrealizedCastOp = v.getDefiningOp(); + + assert( + reintCastOp || + (unrealizedCastOp && + unrealizedCastOp->hasAttr(ModuloState::WraparoundAttr)) || + (isa(v.getType()) && + isa(dyn_cast(v.getType()).getElementType()))); + + PtrState state; + if (reintCastOp) { + visitOperandReintCast(reintCastOp, state, op.getLoc(), rewriter, + knownPtrs); + } else if (unrealizedCastOp) { + assertValidUnrealizedCast(unrealizedCastOp); + visitOperandUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, + knownPtrs); + } else { + visitOperand(v, state, op.getLoc(), rewriter, knownPtrs); + } + initArgState.push_back(state); + } + + // For each of the PtrState recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. + for (auto state : initArgState) { + for (auto s : state.offsets) { + // offsets can be IntAttr zeroes, since reinterpret_cast collapses + // them for the input memref, and the for loop may not update + // offsets other than offsets[0]. Create constants Values for those + // zeroes. + if (auto sIntAttr = getIntAttr(s)) { + assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + operands.push_back(constOp.getResult()); + } else { + operands.push_back(cast(s)); + } + } + + for (auto s : state.strides) { + assert(!getIntAttr(s) && "PtrState strides for yield within for " + "loop not expected to be " + "attribute."); + operands.push_back(cast(s)); + } + } + + for (auto chunk : moduloSecondChunks) { + operands.push_back(chunk); + } + + // Yield is a terminator op that must be at the end of the function + rewriter.setInsertionPointAfter(op); + auto newOp = rewriter.replaceOpWithNewOp(op, operands); + assert(op->getNumResults() == 0); + + LLVM_DEBUG({ + llvm::dbgs() << "new yield:"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +// From an unrealized_conversion_cast which takes in two reinterpret_casts +// representing two chunks, we need to get back the full pointer state. We +// cannot rebuild the original state from the two reinterpret_casts similarly to +// the normal case. To solve this, we attach the original addptr as the third +// operand to the unrealized_cast so that we can manually rebuild the state. 
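+//
+// Rough sketch of the IR shape this relies on (illustrative only; types and
+// the wraparound attribute are simplified):
+//   %chunk0 = memref.reinterpret_cast %base ...   // first wraparound chunk
+//   %chunk1 = memref.reinterpret_cast %base ...   // second wraparound chunk
+//   %cast   = builtin.unrealized_conversion_cast %chunk0, %chunk1, %orig_addptr
+//             : (...) -> memref<...>
+// Only the third operand (%orig_addptr) is inspected below to rebuild the
+// PtrState; the two chunk memrefs are consumed when lowering the load.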
+void PtrAnalysis::visitOperandUnrealizedCast( + UnrealizedConversionCastOp op, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assertValidUnrealizedCast(op); + + auto origPtr = op.getInputs()[2]; + if (knownPtrs.contains(origPtr)) { + state = knownPtrs.at(origPtr); + } else { + visitOperandAddptr(origPtr.getDefiningOp(), state, loc, + rewriter, knownPtrs); + } +} + +struct ModuloChunkInitArg { + Value reinterpretCast = nullptr; + // where in the init args is the first chunk placed + size_t initArgIndex = -1; +}; + +void PtrAnalysis::rewriteForOp( + scf::ForOp op, ConversionPatternRewriter &rewriter, + IndexMapSet &levelToBlockArgIndex, const int level, + llvm::SmallDenseMap &knownPtrs) { + SmallVector newInitArgs; + + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; + + // If we have a load op that uses a modulo pointer, we need to insert both of + // the memref chunks to the init args. We reuse the sizes from the original + // memrefs. This data structure keeps track of where these additional init + // args should be inserted. + // + // As an example, if we have a 2D memrefs being split, we first put the first + // chunk in the order as it appears. Then, once all of the original init args + // are processed, we insert their offsets and strides, and finally the second + // chunk. + SmallVector, PtrState>, + 6> + moduloStates; + + // Amongst the init args, track the indices that map to the first chunk of a + // modulo pair. This is used to distinguish between the normal + // reinterpret_casts whose return types need to be rewritten to match what the + // for loop is yielding. + DenseSet moduloInitArgIndices; + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto mappedV = rewriter.getRemappedValue(arg); + memref::ReinterpretCastOp reintCastOp; + UnrealizedConversionCastOp unrealizedCastOp; + + // If this init arg is supposed to be remapped, use the remapped + // value instead. In addition, if this init arg is a memref created + // by a reinterpret_cast or a tensor of index, there is a chance that + // it will be used in addptr. Create PtrState for each such init arg. + if (mappedV) { + // TODO: + // Passing a block argument pointer directly into a for loop not + // supported. 
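+      // e.g. (illustrative): with a kernel argument %arg0 : !tt.ptr<f32>,
+      // writing scf.for ... iter_args(%p = %arg0) is rejected by the assert
+      // below; the pointer has to be produced by an op such as tt.addptr so
+      // that a PtrState can be rebuilt for it.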
+ assert(!(dyn_cast(mappedV) && + isa(mappedV.getType())) && + "cannot take pointer block argument as init arg for for loop"); + if (auto op = mappedV.getDefiningOp()) { + reintCastOp = op; + newInitArgs.push_back(mappedV); + } else if (auto op = + mappedV.getDefiningOp()) { + assertValidUnrealizedCast(op); + unrealizedCastOp = op; + auto inputs = unrealizedCastOp.getInputs(); + + SmallVector initArgData{ + ModuloChunkInitArg{inputs[0], i}, + ModuloChunkInitArg{inputs[1]}, + }; + + moduloInitArgIndices.insert(i); + moduloStates.push_back( + std::make_tuple(unrealizedCastOp, initArgData, PtrState{})); + + newInitArgs.push_back(inputs[0]); + } else { + newInitArgs.push_back(mappedV); + } + + } else { + newInitArgs.push_back(arg); + } + + auto indexTensor = + isa(arg.getType()) && + isa(dyn_cast(arg.getType()).getElementType()); + + if (!unrealizedCastOp && !reintCastOp && !indexTensor) + continue; + + PtrState state; + if (reintCastOp) { + visitOperandReintCast(reintCastOp, state, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } else if (unrealizedCastOp) { + visitOperandUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + std::get<2>(moduloStates.back()) = state; + } else { + visitOperand(arg, state, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } + + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + } + + // Set insertion point to be before the for loop for new variables passed + // into the new loop. + auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + // For each of the PtrState recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto [i, state] : initArgIndexState) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list. + for (auto [j, s] : llvm::enumerate(state.offsets)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.offsets[j] = constOp.getResult(); + } else { + newInitArgs.push_back(cast(s)); + } + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.strides[j] = constOp.getResult(); + } else { + newInitArgs.push_back(cast(s)); + } + } + + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the state we record is the init + // arg, but want to to use newly created block arg. These block args + // are not created yet. We will translate this mapping later. + knownPtrsTmp.push_back(std::make_pair(i, state)); + levelToBlockArgIndex[level].insert(i); + + // If the original init arg is a memref produced by reinterpret_cast, + // create a new memref using new strides and offsets created above. + // This produces a canonicalized memref, which will match what the + // for loop generates if it modifies the memref. 
E.g., original + // reinterpret_cast can produce a memref with const stride: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + + // s0 + d1 + // * s1)>> + // The new reinterpret_cast will always have dynamic stride and + // offset: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + // + s0 + d1 * s2)>> + // + // For init args that are the first chunk of a modulo pair, there is + // no need for the type to be rewritten because the strides and + // offsets are already dynamic. + if (!moduloInitArgIndices.contains(i) && + newInitArgs[i].getDefiningOp()) { + SmallVector resultShape; + for (auto s : state.sizes) { + auto sIntAttr = getIntAttr(s); + assert(sIntAttr && "expected constant size"); + resultShape.push_back(sIntAttr.value()); + } + auto castOp = state.createCastOp(resultShape, op.getLoc(), rewriter); + + LLVM_DEBUG({ + llvm::dbgs() << "new reinterpret_cast with dynamic sizes " + "and offsets:"; + castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + newInitArgs[i] = castOp.getResult(); + } + } + + // Pass in the second chunk of each modulo pair + for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { + chunkData[1].initArgIndex = newInitArgs.size(); + newInitArgs.push_back(chunkData[1].reinterpretCast); + } + + rewriter.restoreInsertionPoint(origIp); + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = rewriter.create( + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + IRMapping mapping; + mapping.map(op.getInductionVar(), iv); + mapping.map(op.getInitArgs(), newInitArgs); + mapping.map(op.getRegionIterArgs(), args); + + for (auto &bodyOp : op.getRegion().getOps()) { + b.clone(bodyOp, mapping); + } + + // Load op is lowered independent of the pointer, if we have a split + // pointer due to modulo, we need to "logically combine" these two + // memrefs into a single one using unrealized_cast_op. This way, when + // lowering the load, the pattern can detect if additional copies are + // inserted. When we are in a loop, it is more complicated because we + // have to insert a new unrealized_cast_op that combines the two memrefs + // in the init arg list. In addition, because init args hold no offset + // and size information, we have to manually insert two additional + // reinterpret_cast ops as input to this unrealized_cast_op so that the + // load have enough information to generate the corresponding copy. + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(b.getBlock()); + + Value zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { + SmallVector newReinterpretCasts; + for (auto &chunk : chunkData) { + newReinterpretCasts.push_back(args[chunk.initArgIndex]); + } + + auto combinedCast = b.create( + loc, unrealizedCastOp.getResult(0).getType(), newReinterpretCasts, + unrealizedCastOp->getAttrs()); + + args[chunkData[0].initArgIndex].replaceUsesWithIf( + combinedCast.getResult(0), [](OpOperand &operand) { + assert(!isa(operand.getOwner()) && + "Storing to split pointers not supported"); + return isa(operand.getOwner()); + }); + } + }); + + // Convert the book-keeping data structure to use the correct key and value. 
+ // Key is converted from init arg index to newly created block arg, and + // Value's PtrState fields are converted from init arg to newly created block + // arg + int cnt = op.getRegionIterArgs().size(); + for (auto [i, state] : knownPtrsTmp) { + for (auto it = state.offsets.begin(); it != state.offsets.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = state.strides.begin(); it != state.strides.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + auto key = newOp.getRegionIterArgs()[i]; + knownPtrs.insert(std::make_pair(key, state)); + } + assert(static_cast(cnt + moduloStates.size()) == + newOp.getRegionIterArgs().size() && + "expect to remap all new block args"); + + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange( + newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + rewriter.replaceOp(op, resultsToReplaceWith); + + // Update the loop body. Manually invoke the rewrite logic on addptr and yield + // in the loop body, so we can take advantage of the states we built up + for (auto &bodyOp : newOp.getRegion().getOps()) { + if (auto addptrOp = dyn_cast(bodyOp)) { + rewriteAddptrOp(addptrOp, rewriter, knownPtrs); + } else if (auto advanceOp = dyn_cast(bodyOp)) { + rewriteAdvanceOp(advanceOp, rewriter, knownPtrs); + } else if (auto forOp = dyn_cast(bodyOp)) { + // TODO: + // Nested for loops are not supported at the moment + assert(0 && "nested loops currently not supported"); + // rewriteForOp(forOp, rewriter, levelToBlockArgIndex, level+1, + // knownPtrs); levelToBlockArgIndex.erase(level+1); + } + } + + if (op.getNumRegionIterArgs()) { + auto yieldOp = cast(newOp.getBody()->getTerminator()); + rewriteYieldOp(yieldOp, rewriter, levelToBlockArgIndex, level, knownPtrs); + } + + LLVM_DEBUG({ + llvm::dbgs() << "new for\n"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +Value PtrAnalysis::getScalarMemRef(Value ptr, Value memRef, const Location loc, + ConversionPatternRewriter &rewriter) { + assert(cast(ptr.getType()) && "expected scalar pointer"); + + // If the pointer is generated by tt.addptr, we will have already inserted an + // ReinterpretCastOp to cast its type from tt.ptr to unranked memref. Return + // the result. + if (ptr.getDefiningOp()) { + if (auto castOp = memRef.getDefiningOp()) { + return castOp.getResult(); + } else { + llvm_unreachable("pointer value is defined by an unexpected op"); + } + } + + assert(isa(ptr) && + "pointer is neither produced by addptr nor a block argument"); + PtrState state; + state.source = memRef; + state.offsets.push_back(rewriter.getIndexAttr(0)); + state.sizes.push_back(rewriter.getIndexAttr(1)); + state.strides.push_back(rewriter.getIndexAttr(1)); + state.modulos.push_back(std::nullopt); + auto castOp = state.createCastOp(SmallVector(1, 1), loc, rewriter); + return castOp.getResult(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/tsingmicro/lib/Analysis/UseAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/UseAnalysis.cpp new file mode 100644 index 000000000..62e450808 --- /dev/null +++ b/third_party/tsingmicro/lib/Analysis/UseAnalysis.cpp @@ -0,0 +1,220 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Analysis/UseAnalysis.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace triton; +using namespace dataflow; + +#define DEBUG_TYPE "triton-use-analysis" + +//===----------------------------------------------------------------------===// +// Use Analysis +// Note that logic below should evolve with triton-to-affine pass +//===----------------------------------------------------------------------===// +LogicalResult +triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) { + // If an op only produces pointer, all its operands are used as meta data. + // This accounts for scenarios such as addptr in a loop whose result is + // yielded. In this case, if the loop returns data tensors, addptr will be + // marked correctly as meta use. + if (op->getResults().size() == 1) { + auto resultType = dyn_cast(op->getResult(0).getType()); + if (resultType && isa(resultType.getElementType())) { + for (auto opnd : operands) + propagateUse(opnd, UseType::MetaUse); + } + } + + TypeSwitch(op) + .Case([&](auto load) { + propagateUse(operands[0], UseType::MetaUse); + auto mask = load.getMask(); + auto other = load.getOther(); + if (mask) { + assert(mask != other && "mask and other cannot be the same"); + propagateUse(operands[1], UseType::MetaUse); + } + if (other) { + // TODO: + // More complicated patterns that generate other is unsupported. + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto store) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = store.getValue(); + auto mask = store.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto dot) { + propagateResults(operands[0], results); + propagateResults(operands[1], results); + + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) + splat = opc.template getDefiningOp(); + + if (opc && splat && splat.getSrc().getDefiningOp()) + propagateUse(operands[2], UseType::MetaUse); + else + propagateUse(operands[2], UseType::DataUse); + }) + .Default([&](Operation *op) { + // this condition account for tt.addptr + for (auto operand : operands) { + propagateResults(operand, results); + } + }); + return success(); +} + +LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) { + MLIRContext *context = funcOp.getContext(); + SymbolTableCollection symbolTable; + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(symbolTable); + if (failed(solver.initializeAndRun(funcOp))) + return failure(); + + // Walk the func op, convert tags on operands to tags on operations + funcOp.walk([&](Operation *op) { + UseType useType = UseType::Undefined; + for (auto result : op->getResults()) { + auto use = solver.lookupState(result); + assert(use && "Lattice value not found"); + auto thisUseType = use->type; + if (thisUseType == UseType::Undefined) + continue; + if (useType == UseType::Undefined) + useType = thisUseType; + if (thisUseType == UseType::MixUse || thisUseType != useType) { + useType = UseType::MixUse; + break; + } + } + + if (useType == UseType::Undefined) { + LLVM_DEBUG({ 
op->setAttr("Undefined", UnitAttr::get(context)); }); + return; + } else if (useType == UseType::MetaUse) { + assert(op->getNumResults() == 1 && + "Ops used for meta computation are expected to have one result"); + // Only set the tag if the operation uses tensors + if (isa(op->getResult(0).getType())) { + // Setting tag for erasing op later + op->setAttr("MetaUse", UnitAttr::get(context)); + } + return; + } else if (useType == UseType::DataUse) { + LLVM_DEBUG({ op->setAttr("DataUse", UnitAttr::get(context)); }); + return; + } + + assert(useType == UseType::MixUse); + + // If the operation only produces scalars, no need to clone it + bool shapedResult = true; + for (auto result : op->getResults()) + shapedResult &= isa(result.getType()); + if (!shapedResult) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + // Value has MixUse. However, the operation may or may not have direct + // MetaUse. E.g., it may only have MixUse, or only have MixUse and + // DataUse. + // - If the operation has direct MetaUse, clone it, tag the clone as + // MetaUse only and point meta users to use the clone. + // - If not, do nothing; this operation will still be materlized. + llvm::SetVector metaUsers; + for (auto result : op->getResults()) { + for (auto user : result.getUsers()) { + TypeSwitch(user) + .Case([&](auto load) { + auto ptr = load.getPtr(); + auto mask = load.getMask(); + auto other = load.getOther(); + if (result == ptr || result == mask || result == other) + metaUsers.insert(user); + }) + .Case([&](auto store) { + auto ptr = store.getPtr(); + auto mask = store.getMask(); + if (result == ptr || result == mask) + metaUsers.insert(user); + }) + .Case([&](auto dot) { + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) + splat = opc.template getDefiningOp(); + + if (opc && splat && + splat.getSrc().getDefiningOp()) + metaUsers.insert(user); + }) + .Default([&](Operation *op) { + // if all output of user are used as meta data, user is a meta + // user. 
This condition account for addptr, or an addi whose + // output only feeds into addptr + bool allMeta = true; + for (auto res : op->getResults()) { + auto resUse = solver.lookupState(res); + if (resUse->type != UseType::MetaUse) { + allMeta = false; + break; + } + } + if (allMeta) + metaUsers.insert(user); + }); + } + } + + // If the operation doesn't have direct meta users, no need to clone it + if (metaUsers.empty()) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + // Clone the operation; switch all meta users to use the clone + OpBuilder builder(op); + auto clone = builder.clone(*op); + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + + // Setting tag for erasing op later + clone->setAttr("MetaUse", UnitAttr::get(context)); + + for (auto [res_i, result] : llvm::enumerate(op->getResults())) + for (auto user : metaUsers) + for (auto &operand : user->getOpOperands()) + if (operand.get() == result) + operand.set(clone->getResult(res_i)); + }); + + return success(); +} diff --git a/third_party/tsingmicro/lib/AnalysisStructured/CMakeLists.txt b/third_party/tsingmicro/lib/AnalysisStructured/CMakeLists.txt new file mode 100644 index 000000000..0683754ca --- /dev/null +++ b/third_party/tsingmicro/lib/AnalysisStructured/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(ZTCAnalysisStructured + PtrAnalysis.cpp + + DEPENDS + TritonAnalysis + TritonTableGen + TritonStructuredTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonStructuredIR + MLIRAnalysis +) diff --git a/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp new file mode 100644 index 000000000..ee98c6c56 --- /dev/null +++ b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp @@ -0,0 +1,1395 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/AnalysisStructured/PtrAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "triton-shared/Analysis/MaskAnalysis.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" + +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "triton-ptr-analysis" + +namespace mlir { + +// Extract a scalar value from v. +// If v is a scalar, return that directly. Otherwise, parse through operations +// (currently only support splat, sitofp, and truncf) that produce it to +// extract the underlying scalar value. We then reconstruct the chain of +// operations that can produce this constant with the original type. 
If no +// scalar value can be extracted, a nullptr is returned. +static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + builder, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +namespace tts { + +int32_t PtrState::getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size() && + shape.size() == offsets.size()); + return offsets.size(); +} + +bool PtrState::isEmpty() const { + return (getRank() == 0 && !source && !scalar); +} + +bool PtrState::hasModulo() const { + for (int32_t i = 0; i < getRank(); i++) { + if (dimHasModulo(i)) { + return true; + } + } + return false; +} + +bool PtrState::dimHasModulo(uint32_t dim) const { + assert( + !isBlockPtr() && + "Analysis should not check modulo if PtrState describes block pointer"); + + assert(dim < getRank()); + + auto intAttr = getIntAttr(shape[dim]); + if (!intAttr.has_value()) { + return true; + } + + return intAttr.value() != 0; +} + +bool PtrState::isBlockPtr() const { return !order.empty(); } + +LogicalResult PtrState::addState(const PtrState &lhsState, + const PtrState &rhsState, Operation *op, + OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + auto loc = op->getLoc(); + + if (lhsState.source && rhsState.source) { + op->emitRemark( + "PtrAnalysis: do not support adding two pointer states that both " + "have base pointers"); + return failure(); + } + + source = lhsState.source ? lhsState.source : rhsState.source; + + if (lhsState.scalar && rhsState.scalar) { + auto addOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = addOp.getResult(); + } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars + scalar = lhsState.scalar ? 
lhsState.scalar : rhsState.scalar; + } + + for (uint64_t i = 0; i < lhsState.getRank(); i++) { + auto newOffset = + addOFRs(lhsState.offsets[i], rhsState.offsets[i], loc, builder); + offsets.push_back(newOffset); + + auto newStride = + addOFRs(lhsState.strides[i], rhsState.strides[i], loc, builder); + strides.push_back(newStride); + + sizes.push_back(lhsState.sizes[i]); + } + + // AddPtr where both lhs and rhs containing modulo operators not supported + if (lhsState.hasModulo() && rhsState.hasModulo()) { + op->emitRemark("PtrAnalysis: do not support adding two pointer states " + "that both have modulo"); + return failure(); + } + + if (lhsState.hasModulo() || rhsState.hasModulo()) { + // visitOperandSplat and visitOperandExpandDims should enforce below + assert(lhsState.getRank() <= 2); + } + + // dealing with modulo: + // - If lhs has no modulo, skip + // - If rhs has zero offset on dim i, we can just use lhs's modulo + // - If i == 0 and rhs is the result of a splat, we will allow the add. This + // is because the user may be trying to express adding a constant offset to + // increment dim1, but pointer analysis cannot differentiate dim1 vs dim0 in + // this case. + // - Else, the analysis fails + + // An example for the 3rd condition above can look like: + // %0 = tt.splat %scalar + // %1 = tt.splat %ptr + // %2 = tt.arange + // %3 = arith.remsi %2, %size + // %4 = tt.addptr %1, %3 + // %5 = tt.addptr %4, %0 + // %5 may also occur in a loop to increment %4 every iteration. + + // Note that this is not bullet-proof. E.g., broken IR can actually increment + // dim0 while dim0 already has modulo, since Triton offsets are element-wise + // and not in unit of lower dimensions. However, this is highly unlikely but + // the analysis will provide wrong result. Hence we provide a warning in this + // case. + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; + + if (rhs->hasModulo()) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->getRank(); i++) { + if (!lhs->dimHasModulo(i)) { + shape.push_back(lhs->shape[i]); + } else if (hasConstZero(rhs->offsets[i])) { + shape.push_back(lhs->shape[i]); + } else if (i == 0 && lhs->getRank() == 2 && rhs->scalar) { + shape.push_back(lhs->shape[1]); + shape.push_back(lhs->shape[0]); + op->emitWarning( + "PtrAnalysis: allowing adding pointer state with modulo in dim 0 to " + "another pointer state with offset in dim 0.\nPlease verify the " + "operand that contains a scalar is meant to increment pointers in " + "dim1. 
If that is not the case it WILL LEAD TO WRONG COMPILATION " + "RESULTS.\n\nTo avoid this warning, use expand_dims (instead of " + "splat) to explicitly specify which dimension contains the scalar."); + break; + } else { + op->emitRemark( + "PtrAnalysis: do not support adding to operand with modulo"); + return failure(); + } + } + + return success(); +} + +void PtrState::dump() const { + llvm::dbgs() << "PtrState: "; + if (source) { + llvm::dbgs() << "source: " << source << "\n"; + } + if (scalar) { + llvm::dbgs() << "scalar: " << scalar << "\n"; + } + + llvm::dbgs() << "offsets: "; + llvm::interleave(offsets, llvm::dbgs(), "\n"); + llvm::dbgs() << "\nstrides: "; + llvm::interleave(strides, llvm::dbgs(), "\n"); + llvm::dbgs() << "\nsizes: "; + llvm::interleave(sizes, llvm::dbgs(), "\n"); + llvm::dbgs() << "\nshape: "; + llvm::interleave(shape, llvm::dbgs(), "\n"); + llvm::dbgs() << "\norder: "; + llvm::interleave(order, llvm::dbgs(), "\n"); + llvm::dbgs() << "\n"; +} + +LogicalResult PtrState::mulState(const PtrState &lhsState, + const PtrState &rhsState, Operation *op, + OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + auto loc = op->getLoc(); + + // neither lhs nor rhs should have source, since multiplying base pointer + // does not make sense + if (lhsState.source && rhsState.source) { + op->emitRemark("PtrAnalysis: do not support multiplying base pointers"); + return failure(); + } + + // currently do not support both tensors are effectively non-scalar + if (!lhsState.scalar && !rhsState.scalar) { + op->emitRemark( + "PtrAnalysis: only support multiplying pointer states when one of " + "them represent a scalar"); + return failure(); + } + + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; + + if (!rhs->scalar && lhs->scalar) { + std::swap(lhs, rhs); + } + + if (lhsState.scalar && rhsState.scalar) { + scalar = + builder.create(loc, lhsState.scalar, rhsState.scalar); + } + + for (uint64_t i = 0; i < lhs->sizes.size(); i++) { + OpFoldResult newOffset = + mulOFRValue(lhs->offsets[i], rhs->scalar, loc, builder); + OpFoldResult newStride = + mulOFRValue(lhs->strides[i], rhs->scalar, loc, builder); + OpFoldResult newShape = + mulOFRValue(lhs->shape[i], rhs->scalar, loc, builder); + offsets.push_back(newOffset); + strides.push_back(newStride); + shape.push_back(newShape); + sizes.push_back(lhs->sizes[i]); + } + + if (rhs->hasModulo()) { + op->emitRemark( + "PtrAnalysis: do not support multiplying pointer states that has " + "modulos"); + return failure(); + } + + return success(); +} + +tts::MakeTensorPtrOp PtrState::createTTSMakeTensorPtrOp(OpBuilder &builder, + Location loc) { + SmallVector staticSizes; + for (size_t i = 0; i < getRank(); i++) { + auto s = getIntAttr(sizes[i]); + assert(s.has_value()); + staticSizes.push_back(s.value()); + } + + auto op = builder.create( + loc, source, staticSizes, strides, offsets, shape, order); + LLVM_DEBUG({ + llvm::dbgs() << "creating tts::make_tensor_ptr:\n"; + op->dump(); + }); + + return op; +} + +LogicalResult PtrAnalysis::visitOperandAdd(arith::AddIOp addOp, PtrState &state, + const Location loc, + OpBuilder &builder) { + PtrState lhsState; + if (visitOperand(addOp.getLhs(), lhsState, loc, builder).failed()) { + return failure(); + } + + PtrState rhsState; + if (visitOperand(addOp.getRhs(), rhsState, loc, builder).failed()) + return failure(); + + // Checking for higher dimension is done in addState below + if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || + (rhsState.getRank() == 1 
&& rhsState.hasModulo())) { + addOp->emitRemark( + "PtrAnalysis: do not support this pattern: a + arange(0, K) % M"); + return failure(); + } + + return state.addState(lhsState, rhsState, addOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandMul(arith::MulIOp mulOp, PtrState &state, + const Location loc, + OpBuilder &builder) { + PtrState lhsState; + if (visitOperand(mulOp.getLhs(), lhsState, loc, builder).failed()) { + return failure(); + } + + PtrState rhsState; + if (visitOperand(mulOp.getRhs(), rhsState, loc, builder).failed()) { + return failure(); + } + + return state.mulState(lhsState, rhsState, mulOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + PtrState rhsState; + if (visitOperand(remOp.getRhs(), rhsState, loc, builder).failed()) { + return failure(); + } + + if (!rhsState.scalar) { + remOp->emitRemark("PtrAnalysis: only support cases when rhs of remainder " + "contains scalar"); + return failure(); + } + + if (visitOperand(remOp.getLhs(), state, loc, builder).failed()) { + return failure(); + } + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. + if (state.hasModulo()) { + remOp->emitRemark( + "PtrAnalysis: do not support multiple modulo within an expression"); + return failure(); + } + + if (state.getRank() == 1) { + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.shape.back() = rhsState.scalar; + } else if (state.getRank() == 2) { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. 
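+    // Illustrative mapping of the two supported forms (sketch):
+    //   tl.arange(0, end)[None, :] % mod -> result shape (1, end), modulo
+    //                                       recorded in state.shape[1]
+    //   tl.arange(0, end)[:, None] % mod -> result shape (end, 1), modulo
+    //                                       recorded in state.shape[0]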
+ auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.shape[1] = rhsState.scalar; + } else if (shape[1] == 1) { + state.shape[0] = rhsState.scalar; + } else { + remOp->emitRemark( + "PtrAnalysis: taking modulo on a 2D tensor with no singleton " + "dimension not supported"); + return failure(); + } + } else { + remOp->emitRemark("PtrAnalysis: unsupported modulo pattern"); + return failure(); + } + return success(); +} + +LogicalResult PtrAnalysis::visitOperandExtSI(arith::ExtSIOp extOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + return visitOperand(extOp.getIn(), state, loc, builder); +} + +LogicalResult PtrAnalysis::visitOperandMakeRange(triton::MakeRangeOp rangeOp, + PtrState &state, Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto shape = cast(rangeOp.getType()).getShape(); + + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + assert(stride == 1 && + "Expect make_range op to always return tensor of stride 1"); + + state.offsets.push_back(builder.getIndexAttr(start)); + state.sizes.push_back(builder.getIndexAttr(shape[0])); + state.strides.push_back(builder.getIndexAttr(stride)); + state.shape.push_back(builder.getIndexAttr(0)); + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + if (visitOperand(expandDimsOp.getSrc(), state, loc, builder).failed()) { + return failure(); + } + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + + // insert dimension info + state.offsets.insert(state.offsets.begin() + axis, builder.getIndexAttr(0)); + state.sizes.insert(state.sizes.begin() + axis, builder.getIndexAttr(1)); + state.strides.insert(state.strides.begin() + axis, builder.getIndexAttr(0)); + state.shape.insert(state.shape.begin() + axis, builder.getIndexAttr(0)); + + if (state.hasModulo() && state.getRank() > 2) { + expandDimsOp->emitRemark( + "PtrAnalysis: unsupported scenario where expand_dims result " + "has modulo and rank > 2"); + return failure(); + } + + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandBroadcast(triton::BroadcastOp broadcastOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + + if (!isa(src.getType())) { + broadcastOp->emitRemark("PtrAnalysis: Unsupported broadcast source type"); + return failure(); + } + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (visitOperand(src, state, loc, builder).failed()) { + return failure(); + } + + for (size_t i = 0; i < dstShape.size(); i++) { + if (srcShape[i] == dstShape[i]) { + continue; + } else if (srcShape[i] < dstShape[i]) { + state.sizes[i] = builder.getIndexAttr(dstShape[i]); + } else { + llvm_unreachable("unexpected dimensions used in broadcast"); + } + } + return success(); +} + +LogicalResult PtrAnalysis::visitOperandSplat(triton::SplatOp splatOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto src = 
splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + if (visitOperand(src, state, loc, builder).failed()) { + return failure(); + } + + if (isa(src.getType())) { + for (auto s : dstShape) { + state.offsets.push_back(builder.getIndexAttr(0)); + state.sizes.push_back(builder.getIndexAttr(s)); + state.strides.push_back(builder.getIndexAttr(0)); + state.shape.push_back(builder.getIndexAttr(0)); + } + } else { + splatOp->emitRemark("PtrAnalysis: unsupported splat pattern"); + return failure(); + } + + // If we splat a integer value, scalar should become the offset of the outer + // most dimension + if (state.scalar) + state.offsets[0] = state.scalar; + + if (state.hasModulo() && state.getRank() > 2) { + splatOp->emitRemark("PtrAnalysis: unsupported scenario where splat result " + "has modulo and rank > 2"); + return failure(); + } + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandAddptr(triton::AddPtrOp addptrOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + PtrState ptrState; + if (visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), builder) + .failed()) { + // assert(0); + return failure(); + } + + PtrState offsetState; + if (visitOperand(addptrOp.getOffset(), offsetState, addptrOp.getLoc(), + builder) + .failed()) { + return failure(); + } + + assert(ptrState.source && "ptr field should provide source / base pointer"); + + assert(ptrState.getRank() == offsetState.getRank() && + "ptr and offset field should have the same rank"); + + return state.addState(ptrState, offsetState, addptrOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandConstSplat(arith::ConstantOp op, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + // this condition is to handle cases where tt.broadcast and tt.splat are + // folded + auto attr = cast(op.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType)); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = builder.getIndexAttr(value.getSExtValue()); + auto constOp = arith::ConstantOp::materialize(builder, constAttr, + builder.getIndexType(), loc); + + state.scalar = constOp; + + auto resultType = cast(op.getResult().getType()); + for (size_t i = 0; i < resultType.getShape().size(); i++) { + if (i == 0) { + state.offsets.push_back(constOp.getResult()); + } else { + state.offsets.push_back(builder.getIndexAttr(0)); + } + + state.sizes.push_back(builder.getIndexAttr(resultType.getShape()[i])); + state.strides.push_back(builder.getIndexAttr(0)); + state.shape.push_back(builder.getIndexAttr(0)); + } + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + + assert(state.isEmpty()); + state.source = makeTPtrOp.getBase(); + state.offsets = makeTPtrOp.getMixedOffsets(); + state.sizes = makeTPtrOp.getMixedSizes(); + state.strides = makeTPtrOp.getMixedStrides(); + state.shape = makeTPtrOp.getMixedShape(); + state.order = SmallVector(makeTPtrOp.getOrder()); + + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + state.source = makeTPtrOp.getBase(); + + if (makeTPtrOp.getOrder().empty()) { + makeTPtrOp->emitRemark( + "PtrAnalysis: expect tt.make_tensor_ptr 
to have order field set"); + return failure(); + } + + auto resType = cast(makeTPtrOp.getResult().getType()); + auto pointeeType = cast(resType.getPointeeType()); + auto shape = pointeeType.getShape(); + + for (int64_t i = 0; i < pointeeType.getRank(); i++) { + state.sizes.push_back(builder.getIndexAttr(shape[i])); + + auto strideCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); + state.strides.push_back(strideCst.getResult()); + + auto offsetCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); + + auto scaledOffset = builder.create( + loc, offsetCst.getResult(), strideCst.getResult()); + state.offsets.push_back(scaledOffset.getResult()); + + auto shapeCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getShape()[i]); + state.shape.push_back(shapeCst.getResult()); + } + state.order = SmallVector(makeTPtrOp.getOrder()); + assert(state.isBlockPtr() && + "tt.make_tensor_ptr pointer state should describe a block pointer"); + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandForOp(scf::ForOp forOp, Value operand, + PtrState &state, + const Location loc, + OpBuilder &builder) { + + auto it = llvm::find(forOp->getResults(), operand); + auto index = std::distance(forOp->getResults().begin(), it); + + auto newState = getLoopResultPtrState(forOp, index); + if (failed(newState)) { + forOp.emitError( + "Rewrite for-op failed. Could not find PtrState returned by " + "the loop."); + return failure(); + } + + state = newState.value(); + return success(); +} + +LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, + const Location loc, + OpBuilder &builder) { + + if (knownPtrs.find(operand) != knownPtrs.end()) { + state = knownPtrs.lookup(operand); + return success(); + } + + if (isa(operand.getType())) { + OpBuilder::InsertionGuard guard(builder); + if (!isa(operand) && operand.getDefiningOp()) { + builder.setInsertionPointAfter(operand.getDefiningOp()); + } + auto castOp = builder.create( + loc, builder.getIndexType(), operand); + state.scalar = castOp.getResult(); + return success(); + } else if (isa(operand.getType())) { + state.scalar = operand; + return success(); + } + + if (isa(operand.getType())) { + // A scalar pointer can either be produced by AddPtrOp or a block + // argument + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + return visitOperandAddptr(cast(op), state, loc, + builder); + } else if (auto makeTensorOp = dyn_cast(op)) { + llvm_unreachable("Unexpected operand defining operation tts.make_tptr"); + } else { + llvm_unreachable("Unexpected operand defining operation"); + } + } else { + state.source = operand; + return success(); + } + } + + if (auto op = operand.getDefiningOp()) { + return visitOperandAdd(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandMul(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandMakeRange(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandBroadcast(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandSplat(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandExpandDims(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandAddptr(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandConstSplat(op, state, loc, 
builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandRem(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandExtSI(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandForOp(op, operand, state, loc, builder); + } else if (!operand.getDefiningOp()) { + if (!knownPtrs.contains(operand)) { + return failure(); + } + + // This operand must be an iter-arg of an inner-loop in a multiple-level + // nested loop, which means its PtrState must have already been populated + // during rewriteForOp of the parent loop. + state = knownPtrs[operand]; + return success(); + } else { + llvm::dbgs() << "PtrAnalysis: encountered addptr operand produced by an " + "unsupported operation\n"; + operand.dump(); + return failure(); + } +} + +LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) { + OpBuilder builder(op); + + PtrState state; + if (visitOperandAddptr(op, state, op.getLoc(), builder).failed()) { + return failure(); + } + + knownPtrs[op.getResult()] = state; + + if (isa(op.getPtr().getType())) { + auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(op.getResult(), maketptrOp.getResult()); + } else { + // record the ptr as we have visited and built up the state for this scalar + // pointer, which may be used by rewriteForOp later. + ptrMap.map(op.getResult(), op.getResult()); + } + return success(); +} + +LogicalResult PtrAnalysis::rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op) { + OpBuilder builder(op); + + PtrState state; + if (visitOperandMakeTensorPtr(op, state, op.getLoc(), builder).failed()) { + return failure(); + } + + auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); + knownPtrs[op.getResult()] = state; + ptrMap.map(op.getResult(), maketptrOp.getResult()); + return success(); +} + +LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) { + OpBuilder builder(op); + auto loc = op.getLoc(); + + PtrState state; + if (visitOperand(op->getOperand(0), state, loc, builder).failed()) { + op->emitRemark("PtrAnalysis: Failed to analyze ptr of tt.advance"); + return failure(); + } + assert(state.isBlockPtr() && + "tt.advance pointer state should describe a block pointer"); + + auto incrementOffsets = op.getOffsets(); + + SmallVector newOffsets; + for (auto [increment, offset, stride] : + llvm::zip(incrementOffsets, state.offsets, state.strides)) { + Value offsetValue; + if (auto offsetIntAttr = getIntAttr(offset)) { + auto constOp = builder.create( + loc, builder.getIndexAttr(offsetIntAttr.value())); + offsetValue = constOp.getResult(); + } else { + offsetValue = cast(offset); + } + auto castOp = builder.create( + loc, builder.getIndexType(), increment); + auto mulOp = builder.create(loc, castOp.getResult(), + cast(stride)); + auto addOp = + builder.create(loc, mulOp.getResult(), offsetValue); + newOffsets.push_back(addOp.getResult()); + } + + state.offsets = SmallVector(newOffsets); + + auto newOp = state.createTTSMakeTensorPtrOp(builder, loc); + knownPtrs[op.getResult()] = state; + ptrMap.map(op.getResult(), newOp.getResult()); + return success(); +} + +static bool isPointerType(Type t) { + if (auto tensor = llvm::dyn_cast(t)) { + return isa(tensor.getElementType()); + } + return isa(t); +} + +FailureOr PtrAnalysis::getLoopInitArgPtrState(scf::ForOp forOp, + size_t index) { + auto ptr = forOp.getInitArgs()[index]; + + // If the pointer into the scf.for was defined by tts.get_structured_state, + // we can get the 
pointer state from the original pointer (the op's input): + // + // %ptr, %offset_1, %offset_2,..., %stride_1, %stride_2,... = + // tts.get_structured_state %original + // scf.for ... (%ptr) {...} + if (auto getStateOp = ptr.getDefiningOp()) { + auto originalPtr = getStateOp->getOperand(0); + if (knownPtrs.count(originalPtr)) { + return knownPtrs[originalPtr]; + } + } + + // For nested loops scenarios, a pointer in init-args can be returned from + // another loop of the same level: + // e.g.: + // clang-format off + // %22:2 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %11, %arg6 = %15) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + // %23 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %arg5) -> (tensor<2x2x!tt.ptr>) : i32 { + // %26 = tt.addptr %arg8, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + // scf.yield %26 : tensor<2x2x!tt.ptr> + // } + // %24:2 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %23, %arg9 = %arg6) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + // %26 = tt.load %arg8 : tensor<2x2x!tt.ptr> + // %27 = tt.addptr %arg8, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + // ... + // } + // ... + // } + // clang-format on + // Notice %arg8 = %23 comes from the return value of the first loop. + if (auto forOp = ptr.getDefiningOp()) { + return getLoopResultPtrState(forOp, index); + } + + // If the pointer isn't defined by tts.get_structured_state nor another loop, + // it means the current pointer is an iterarg of the outer loop. + // In such cases, the outer loops would have already set up the PtrState for + // us already. + // + // scf.for iterargs(%ptr = %init_arg) { + // scf.for iterargs(%ptr1 = %ptr) { <--- we're dealing with `%ptr1` here. + // ... + // } + // } + if (knownPtrs.count(ptr)) { + assert(!ptr.getDefiningOp() && "Expect the ptr to be an iterarg"); + return knownPtrs[ptr]; + } + + return failure(); +} + +PtrState PtrAnalysis::reconcileLoopPtrState( + scf::ForOp forOp, size_t iterArgIndex, const PtrState &state, + llvm::function_ref getReplacementVal) { + PtrState newState = state; + int cnt = iterArgIndex + 1; + if (newState.getRank() == 0) { + assert(newState.scalar); + // for scalar pointers, the scalar contains the offset and is the only + // relevant newState that could be updated by the loop. 
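+    // Illustrative note (layout assumed from the tts.get_structured_state
+    // expansion shown earlier): the loop carries [%ptr, %offset_0, ...,
+    // %offset_{rank-1}, %stride_0, ..., %stride_{rank-1}] for a tensor
+    // pointer, or just [%ptr, %offset] for a scalar one, which is why the
+    // replacement values are read starting at iter-arg index iterArgIndex + 1.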
+ newState.scalar = getReplacementVal(forOp, cnt); + } else { + for (auto &offset : newState.offsets) { + offset = getReplacementVal(forOp, cnt++); + } + + for (auto &stride : newState.strides) { + stride = getReplacementVal(forOp, cnt++); + } + } + + return newState; +} + +FailureOr PtrAnalysis::getLoopIterArgPtrState(scf::ForOp forOp, + size_t index) { + auto state = getLoopInitArgPtrState(forOp, index); + if (failed(state)) { + return failure(); + } + + return reconcileLoopPtrState( + forOp, index, state.value(), + [](scf::ForOp op, size_t index) { return op.getRegionIterArg(index); }); +} + +FailureOr PtrAnalysis::getLoopResultPtrState(scf::ForOp forOp, + size_t index) { + auto state = getLoopInitArgPtrState(forOp, index); + if (failed(state)) { + return failure(); + } + + return reconcileLoopPtrState( + forOp, index, state.value(), + [](scf::ForOp op, size_t index) { return op->getResult(index); }); +} + +LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { + for (auto [i, arg] : llvm::enumerate(op.getRegionIterArgs())) { + if (!maybeStructuredArgs.contains(arg)) { + continue; + } + + auto state = getLoopIterArgPtrState(op, i); + if (failed(state)) { + // Because the maybeStructuredArgs may contain values that are not + // considered structured by PtrAnalysis, failing to retrieve the PtrState + // should not fail the rewrite process. + // We emit an error for diagnostics and debugging purposes. + op->emitWarning( + "Rewrite for-op failed. Could not find PtrState for iter-arg index " + + std::to_string(i)); + continue; + } + + // Save the current init arg's PtrState + knownPtrs[arg] = state.value(); + + // For tensors of pointers, create a tts.make_tptr at the beginning of the + // loop body that correspond to this region iter arg. In case it is used + // by tt.load/tt.store in the loop body before pointer updates, this will + // make sure rewriteLoadOp/rewriteStoreOp can use the analysis result. + // E.g., given the following input (%tensor_of_ptr is a block arg): + // scf.for (%tensor_of_ptr) { + // %data = tt.load %tensor_of_ptr + // // more operations to update %tensor_of_ptr + // } + // We may produce the following output: + // scf.for (%base_ptr, %stride, %offset) { + // %tensor_of_ptr = tts.make_tptr(%base_ptr, %stride, %offset) + // %data = tts.load %tensor_of_ptr + // // more operations to update %offset + // } + // If %tensor_of_ptr is not used (i.e., %tensor_of_ptr is updated before + // used in the original IR), it will simply be removed by + // canonicalization. + + // For scalar pointers, there is no need to create a tts.addptr at the + // beginning of the loop body. We don't lower tt.load and tt.store on + // scalars in this pass; pointer arithmetics can also just use the + // original pointer. + // Note that there can be tensor of indices in iter-arg, so we only create + // the make_tensor_ptr op when the arg is of pointer type. 
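+    // (Illustration only, hypothetical type: an iter-arg of type
+    // tensor<128xi32> that feeds pointer arithmetic still has its PtrState
+    // recorded in knownPtrs above, but no tts.make_tptr is materialized for
+    // it since it is not a pointer.)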
+ if (isPointerType(arg.getType())) { + if (state->getRank() != 0) { + OpBuilder builder(op.getRegion()); + auto maketptrOp = state->createTTSMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(arg, maketptrOp.getResult()); + } + } + } + + // Recursively rewrite the inner ops + if (rewriteOp(op).failed()) { + op->emitRemark( + "PtrAnalysis: update loop body failed when rewriting for op"); + return failure(); + } + + return success(); +} + +LogicalResult +PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { + auto tritonValue = op->getOperand(0); + + // If this triton value isn't known, it means PtrAnalysis has failed to + // analyze this pointer. In such cases, simply remap all uses of the + // structured value back to its original triton value. + if (!knownPtrs.contains(tritonValue)) { + op.emitRemark( + "Rewrite GetStructuredStateOp failed. Could not find PtrState."); + op.getResult(0).replaceAllUsesWith(tritonValue); + return failure(); + } + + tts::PtrState state = knownPtrs[tritonValue]; + Value remappedValue = + ptrMap.contains(tritonValue) ? ptrMap.lookup(tritonValue) : tritonValue; + + SmallVector<Value> replacements{remappedValue}; + OpBuilder builder(op); + + if (state.getRank() == 0) { + // For scalar pointers, the scalar contains the offset and is the only + // relevant state that could be updated by the loop. + if (state.scalar) { + replacements.push_back(state.scalar); + } else { + // This operand is a pointer directly from the kernel arguments. + // Use offset 0. + assert(!tritonValue.getDefiningOp()); + replacements.push_back(builder.create<arith::ConstantOp>( + op.getLoc(), builder.getIndexAttr(0))); + } + } else { + for (auto [j, s] : llvm::enumerate(state.offsets)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = builder.create<arith::ConstantOp>( + op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + replacements.push_back(constOp.getResult()); + } else { + replacements.push_back(cast<Value>(s)); + } + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = builder.create<arith::ConstantOp>( + op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + replacements.push_back(constOp.getResult()); + } else { + replacements.push_back(cast<Value>(s)); + } + } + } + + op->replaceAllUsesWith(replacements); + op->erase(); + return success(); +} + +LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op, + bool useUnsafeMask) { + auto ptr = ptrMap.lookupOrNull(op.getPtr()); + auto mask = op.getMask(); + auto other = op.getOther(); + auto loc = op.getLoc(); + + if (!ptr) { + op->emitRemark("PtrAnalysis: pointer is not replaced with tts.make_tptr so " + "loadOp cannot be rewritten"); + return failure(); + } + + auto ptrType = dyn_cast<triton::PointerType>(ptr.getType()); + if (ptrType && !isa<ShapedType>(ptrType.getPointeeType())) { + op->emitRemark("PtrAnalysis: scalar loadOp will not be rewritten"); + return failure(); + } + + ArrayRef<OpFoldResult> dims; + mlir::triton::MaskState mstate(useUnsafeMask); + Value scalarOther; + + OpBuilder builder(op); + // Analyze the mask operand to determine at runtime the size of the data we + // are moving.
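+  // For example (hypothetical kernel, illustration only): a 1-D load of
+  // block size 128 guarded by a mask of the form "offs < n" yields
+  // dims = [min(n, 128)], so the tts.load created below only reads the
+  // in-bounds prefix and fills the remainder with the scalar `other` value
+  // when one is provided.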
+ if (mask) { + if (mstate.parse(mask, loc, builder).failed()) { + op->emitRemark("MaskAnalysis failed"); + return failure(); + } + dims = mstate.dims; + } + + if (other) { + assert(mask && "other value used while no masks are specified"); + + scalarOther = getScalarValue(other, loc, builder); + if (!scalarOther) { + op->emitRemark("other value used in masked load produced by " + "unsupported instruction"); + return failure(); + } + } + + auto loadOp = builder.create<tts::LoadOp>(loc, ptr, dims, scalarOther); + + LLVM_DEBUG({ + llvm::dbgs() << "creating tts::load:\n"; + loadOp->dump(); + }); + + op.replaceAllUsesWith(loadOp.getResult()); + op->erase(); + return success(); +} + +// Structured values from the TritonStructuredDialect have offsets and strides +// that might change in each loop iteration and hence will appear in an scf.for +// iter-args like so: +// +// %structured, %offsets, %strides = tts.get_structured_state +// scf.for (%arg0 = %structured, %arg1 = %offsets, %arg2 = %strides) { +// %a = %arg0 + 1 +// %b = %a + 2 +// scf.for (%arg1 = %b) { +// ... +// } +// } +// +// In `rewriteForOp`, we have to recognize such structured values in order to +// rewrite their PtrState accordingly. Previously, only values of pointer-like +// type (e.g., tensor<!tt.ptr<...>> or !tt.ptr<tensor<...>>) could appear in a +// loop's iter-args, so detecting these values was as easy as checking the +// type. +// +// Now, tensors of indices could also appear in a loop's iter-args. To reliably +// detect all such cases, we perform a BFS-like traversal of the IR where the +// sources are the results of `tts.get_structured_state`. All values that +// originate from the results of `tts.get_structured_state` are considered +// "maybeStructured". If a loop's iter-arg is considered "maybeStructured", we +// must set up its PtrState during `rewriteForOp`. +void PtrAnalysis::initializeMaybeStructuredArgs(Operation *op) { + std::queue<Value> q; + DenseSet<Value> visited; + + op->walk([&q, &visited](tts::GetStructuredStateOp getStateOp) { + Value value = getStateOp->getResult(0); + visited.insert(value); + q.push(value); + }); + + while (!q.empty()) { + auto v = q.front(); + q.pop(); + for (auto user : v.getUsers()) { + // scf.for is a special case. We have 2 sets of values to consider: + // - iter-args + // - loop results + // For every init arg that originates from a `tts.get_structured_state` + // op, its corresponding iter-arg and loop result will also be considered + // "maybeStructured".
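+      // Hypothetical example (illustration only): if %0 is a result of
+      // tts.get_structured_state and is used as an init arg,
+      //   %res = scf.for ... iter_args(%arg = %0) -> (...) { ... }
+      // then both %arg and %res become "maybeStructured" and are enqueued.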
+ if (auto forOp = dyn_cast<scf::ForOp>(user)) { + auto it = llvm::find(forOp.getInitArgs(), v); + + if (it == forOp.getInitArgs().end()) { + continue; + } + + auto argIndex = std::distance(forOp.getInitArgs().begin(), it); + auto iterArg = forOp.getRegionIterArg(argIndex); + auto tiedLoopRes = forOp.getTiedLoopResult(iterArg); + + SmallVector<Value> neighbors{iterArg, tiedLoopRes}; + for (auto neighbor : neighbors) { + maybeStructuredArgs.insert(neighbor); + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + q.push(neighbor); + } + } + + } else { + for (auto res : user->getResults()) { + if (res.getType() != v.getType()) { + continue; + } + maybeStructuredArgs.insert(res); + if (!visited.contains(res)) { + visited.insert(res); + q.push(res); + } + } + } + } + } +} + +LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op, + bool useUnsafeMask) { + auto ptr = ptrMap.lookupOrNull(op.getPtr()); + auto val = op.getValue(); + auto mask = op.getMask(); + auto loc = op.getLoc(); + + if (!ptr) { + op->emitRemark("PtrAnalysis: pointer is not replaced with tts.make_tptr so " + "storeOp cannot be rewritten"); + return failure(); + } + + auto ptrType = dyn_cast<triton::PointerType>(ptr.getType()); + if (ptrType && !isa<ShapedType>(ptrType.getPointeeType())) { + op->emitRemark("PtrAnalysis: scalar storeOp will not be rewritten"); + return failure(); + } + + ArrayRef<OpFoldResult> dims; + mlir::triton::MaskState mstate(useUnsafeMask); + + OpBuilder builder(op); + + // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + if (mask) { + if (mstate.parse(mask, loc, builder).failed()) { + op->emitRemark("MaskAnalysis failed"); + return failure(); + } + dims = mstate.dims; + } + + auto storeOp = builder.create<tts::StoreOp>(loc, ptr, val, dims); + + LLVM_DEBUG({ + llvm::dbgs() << "creating tts::store:\n"; + storeOp->dump(); + }); + + op->erase(); + return success(); +} + +LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) { + LLVM_DEBUG({ + llvm::dbgs() << "rewriting rootOp\n"; + rootOp->dump(); + }); + + rootOp->walk([&](Operation *op) { + if (op == rootOp) { + return WalkResult::advance(); + } + return TypeSwitch<Operation *, WalkResult>(op) + .Case<triton::AddPtrOp>([&](auto addptr) { + if (rewriteAddptrOp(addptr).failed()) { + addptr->emitRemark("PtrAnalysis: Failed to rewrite AddPtrOp"); + } + return WalkResult::advance(); + }) + .Case<triton::MakeTensorPtrOp>([&](auto maketptr) { + if (rewriteMakeTensorPtrOp(maketptr).failed()) { + maketptr->emitRemark( + "PtrAnalysis: Failed to rewrite MakeTensorPtrOp"); + } + return WalkResult::advance(); + }) + .Case<triton::AdvanceOp>([&](auto advance) { + if (rewriteAdvanceOp(advance).failed()) { + advance->emitRemark("PtrAnalysis: Failed to rewrite AdvanceOp"); + } + return WalkResult::advance(); + }) + .Case<triton::LoadOp>([&](auto load) { + if (rewriteLoadOp(load, useUnsafeMask).failed()) { + load->emitRemark("PtrAnalysis: Failed to rewrite LoadOp"); + return WalkResult::advance(); + } + return WalkResult::skip(); + }) + .Case<triton::StoreOp>([&](auto store) { + if (rewriteStoreOp(store, useUnsafeMask).failed()) { + store->emitRemark("PtrAnalysis: Failed to rewrite StoreOp"); + return WalkResult::advance(); + } + return WalkResult::skip(); + }) + .Case<scf::ForOp>([&](auto forOp) { + // `rewriteForOp` recursively visits its children, so regardless of + // whether the rewrite succeeds or not, we need to return "skip" so + // that the walk does not visit the for-op's child operations + // a second time.
+ if (rewriteForOp(forOp).failed()) { + forOp->emitRemark("PtrAnalysis: Failed to rewrite ForOp"); + } + return WalkResult::skip(); + }) + .Case<tts::GetStructuredStateOp>( + [&](tts::GetStructuredStateOp getStateOp) { + // For tensors of indices potentially being used in a pointer + // arithmetic sequence, we need to manually populate the state if + // none already exists. + // This is necessary because, unlike triton pointers in a loop, + // which always have a `tt.addptr` that triggers the rewrite + // process (including generating the ops for updating offsets + // and strides), tensors of indices only have a simple `arith.addi` + // (or other arith ops). + // Without visiting these ops manually, the ops to update the + // offsets and strides would not be generated. + auto tritonValue = getStateOp->getOperand(0); + if (!knownPtrs.contains(tritonValue)) { + PtrState state; + OpBuilder b(getStateOp); + if (succeeded(visitOperand(tritonValue, state, + getStateOp->getLoc(), b))) { + knownPtrs[tritonValue] = state; + } else { + getStateOp->emitRemark("PtrAnalysis: Failed to populate ptr " + "state for tensor of indices"); + } + } + + return WalkResult::skip(); + }) + .Default([&](auto) { return WalkResult::advance(); }); + }); + + return success(); +} + +} // namespace tts +} // namespace mlir diff --git a/third_party/tsingmicro/lib/CMakeLists.txt b/third_party/tsingmicro/lib/CMakeLists.txt new file mode 100644 index 000000000..eff85b208 --- /dev/null +++ b/third_party/tsingmicro/lib/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(Analysis) +add_subdirectory(AnalysisStructured) +add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/third_party/tsingmicro/lib/Conversion/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..6177901f5 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/CMakeLists.txt @@ -0,0 +1,10 @@ +add_subdirectory(TritonToLinalg) +add_subdirectory(TritonToStructured) +add_subdirectory(TritonArithToLinalg) +add_subdirectory(StructuredToMemref) +add_subdirectory(Tx81MemrefToLLVM) +add_subdirectory(LinalgToMK) +add_subdirectory(MKToTx81) +add_subdirectory(Tx81ToLLVM) +add_subdirectory(TritonToCoreDialects) +add_subdirectory(CoreDialectsToMK) diff --git a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt new file mode 100644 index 000000000..c4d27c2b3 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt @@ -0,0 +1,23 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +# All rights reserved.
+# +#===------------------------------------------------------------------------===# + +add_triton_library(CoreDialectsToMK + CoreDialectsToMKPass.cpp + + DEPENDS + CoreDialectsToMKConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport +) diff --git a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp new file mode 100644 index 000000000..a9ffab977 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp @@ -0,0 +1,60 @@ +//===------------------- CoreDialectsToMKPass.cpp -------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Lowering core dialects to backend dialects +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h" +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h.inc" + +namespace { + +class CoreDialectsToMKPass : public CoreDialectsToMKBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + PassManager pm(&getContext(), moduleOp.getOperationName()); + + pm.addPass(createLinalgToMKPass()); + + // Erase dead code and fold constants created during lowering + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> triton::createCoreDialectsToMKPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/LinalgToMK/CMakeLists.txt new file mode 100644 index 000000000..409f42d88 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(LinalgToMagicKernel + LinalgToMK.cpp + LinalgToMKPass.cpp + + DEPENDS + MagicKernelTableGen + LinalgToMKConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp new file mode 100644 index 000000000..a3970cf81 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp @@ -0,0 +1,56 @@ +//===------------------- LinalgToMK.cpp -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" + +#define DEBUG_TYPE "linalg-to-mk" + +using namespace mlir; +using namespace mk; + +#define GEN_PASS_CLASSES +#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc" + +namespace { + +// Convert tensor.empty + linalg.fill + linalg.matmul to mk.matmul +struct MatmulConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(linalg::MatmulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); +#if 0 + auto tensorType = *op->getResultTypes().begin(); + + Value output = op->getResult(0); + auto fillOp = output.getDefiningOp(); + Value emptyTensor = fillOp->getResult(0); + auto tensorEmptyOp = emptyTensor.getDefiningOp(); + + auto dotOp = rewriter.create(loc, tensorType, op->getOperand(0), + op->getOperand(1), + op.getNumOperands() == 3 ? op->getOperand(2) : nullptr); + rewriter.replaceOp(op, dotOp); +#endif + return success(); + } +}; + +} // namespace + +void mlir::triton::populateLinalgToMKCanonicalizationPatterns( + RewritePatternSet &patterns) {} + +void mlir::triton::populateLinalgToMKConversionPatterns( + RewritePatternSet &patterns) { + // patterns.add(patterns.getContext()); +} diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp new file mode 100644 index 000000000..eaba1f34f --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp @@ -0,0 +1,70 @@ +//===------------------- LinalgToMKPass.cpp -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "linalg-to-mk" + +using namespace mlir; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_LINALGTOMK +#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class LinalgToMKPass : public triton::impl::LinalgToMKBase { + using LinalgToMKBase::LinalgToMKBase; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + // TODO: Enable this when all conversion pattern has been implemented. 
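+    // Note (relies only on applyPartialConversion semantics, added for
+    // clarity): while the addIllegalDialect call below remains commented out,
+    // linalg ops without a registered conversion pattern (e.g. linalg.matmul
+    // while MatmulConverter is disabled) are simply left in place for later
+    // passes to lower.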
+ // target.addIllegalDialect(); + + target.addLegalDialect(); + + target.addLegalOp(); + + triton::populateLinalgToMKConversionPatterns(patterns); + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> triton::createLinalgToMKPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt new file mode 100644 index 000000000..b7e951a04 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(MKToTx81 + MKToTx81.cpp + MKToTx81Pass.cpp + + DEPENDS + Tx81TableGen + MKToTx81ConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp new file mode 100644 index 000000000..3b51faf43 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp @@ -0,0 +1,949 @@ +//===--------------------- MKToTx81.cpp -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This file implements the patterns to convert operations from mk dialect to +// tx81 dialect. It converts memory operations to RdmaOp/WdmaOp and converts +// mk.dot to tx.gemm etc. +// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h" +#include "Tx81/tx81.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "mk-to-tx81" + +using namespace mlir; +using namespace tx; + +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Type Conversion +//===----------------------------------------------------------------------===// + +class MKToTx81TypeConverter : public TypeConverter { +public: + MKToTx81TypeConverter() { + // Add conversions for MemRef types to UI64 (representing SPM addresses) + addConversion([](MemRefType type) -> Type { + return IntegerType::get(type.getContext(), 64, IntegerType::Unsigned); + }); + + // Add conversions for Tensor types to UI64 (representing SPM addresses) + addConversion([](TensorType type) -> Type { + return IntegerType::get(type.getContext(), 64, IntegerType::Unsigned); + }); + + // Keep other types as is + addConversion([](Type type) -> Type { return type; }); + } + +private: + MLIRContext *context; +}; + +//===----------------------------------------------------------------------===// +// 
Utilities +//===----------------------------------------------------------------------===// + +// Get format code for tensor element type +// This maps MLIR types to Tx81 format codes +Data_Format getFormatCode(MemRefType type) { + auto elemType = type.getElementType(); + if (elemType.isF32()) { + return Fmt_FP32; + } else if (elemType.isF16()) { + return Fmt_FP16; + } else if (elemType.isBF16()) { + return Fmt_BF16; + } else if (elemType.isInteger(8)) { + return Fmt_INT8; + } else { + llvm_unreachable("Tx8 unsupported the element type\n"); + } + // Default to F32 format + return Fmt_FP32; +} + +// Helper function to extract shape from tensor type +SmallVector getShapeFromTensorType(TensorType type) { + SmallVector shape; + for (auto dim : type.getShape()) + shape.push_back(static_cast(dim)); + return shape; +} + +// Helper function to extract dimensions from memref or tensor type +SmallVector getDimsFromType(Type type) { + SmallVector dims; + if (auto memrefType = mlir::dyn_cast(type)) { + for (auto dim : memrefType.getShape()) + dims.push_back(static_cast(dim)); + } else if (auto tensorType = mlir::dyn_cast(type)) { + for (auto dim : tensorType.getShape()) + dims.push_back(static_cast(dim)); + } + return dims; +} + +static uint64_t getElemByte(Type type) { + static DataLayout dataLayout; + auto typeSize = dataLayout.getTypeSize(type); + if (!typeSize.isFixed()) { + llvm::llvm_unreachable_internal("All element type should have fixed size."); + } + return typeSize.getFixedValue(); +} + +static Value createAddressFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value memref) { + auto stridedMetadata = + rewriter.create(loc, memref); + Value indexBasePtr = rewriter.create( + loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer()); + auto elemType = mlir::cast(memref.getType()).getElementType(); + Value elemByte = + rewriter.create(loc, getElemByte(elemType)); + Value offset = stridedMetadata.getOffset(); + Value byteOffset = + rewriter.create(loc, offset.getType(), offset, elemByte); + Value offsetPtr = rewriter.create(loc, indexBasePtr.getType(), + indexBasePtr, byteOffset); + Value i64SPMPtr = rewriter.create( + loc, rewriter.getI64Type(), offsetPtr); + return i64SPMPtr; +} + +static std::tuple +createMetadata(ConversionPatternRewriter &rewriter, Location loc, + Value operand) { + auto stridedMetadata = + rewriter.create(loc, operand); + Value indexBasePtr = rewriter.create( + loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer()); + auto elemType = mlir::cast(operand.getType()).getElementType(); + Value elemByte = + rewriter.create(loc, getElemByte(elemType)); + Value offset = stridedMetadata.getOffset(); + Value byteOffset = + rewriter.create(loc, offset.getType(), offset, elemByte); + Value offsetPtr = rewriter.create(loc, indexBasePtr.getType(), + indexBasePtr, byteOffset); + Value i64SPMPtr = rewriter.create( + loc, rewriter.getI64Type(), offsetPtr); + + // FIXME: For multi-dimensional(rank > 2), strides need to be multiplied. + return {i64SPMPtr, stridedMetadata.getSizes(), stridedMetadata.getStrides()}; +} + +static SmallVector padSizesToNHWC(ConversionPatternRewriter &rewriter, + Location loc, ValueRange sizes) { + Value one = rewriter.create(loc, 1); + int numPad = 4 - sizes.size(); + SmallVector nhwcShape; + while (numPad--) { + nhwcShape.push_back(one); + } + for (auto dim : sizes) { + nhwcShape.push_back(dim); + } + return nhwcShape; +} + +// The last stride is always 1, skip it, nhwcStrides.size() will be 3. 
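+// For example (illustration only): a rank-3 memref with strides
+// (%s0, %s1, 1) is left-padded with a constant 1 to (1, %s0, %s1, 1), and the
+// trailing unit stride is then dropped, giving the three NHWC strides
+// (1, %s0, %s1) consumed by the DMA ops below.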
+static SmallVector +padStridesToNHWC(ConversionPatternRewriter &rewriter, Location loc, + ValueRange strides) { + Value one = rewriter.create(loc, 1); + int numPad = 4 - strides.size(); + SmallVector nhwcStrides; + while (numPad--) { + nhwcStrides.push_back(one); + } + for (auto dim : strides) { + nhwcStrides.push_back(dim); + } + nhwcStrides.pop_back(); + return nhwcStrides; +} + +static Value calculateElemCount(ConversionPatternRewriter &rewriter, + Location loc, ValueRange sizes) { + // If we get scalar data, sizes is empty, return 1 + if (sizes.empty()) { + return rewriter.create(loc, 1); + } + + Value elemCount = sizes[0]; + for (int i = 1; i < sizes.size(); i++) { + elemCount = rewriter.create(loc, elemCount.getType(), + elemCount, sizes[i]); + } + return elemCount; +} + +// Extract the operations from a linalg op region +template llvm::SmallVector getRegionOps(T linalgOp) { + auto regionBlock = linalgOp.getBody(); + return llvm::map_to_vector(regionBlock->without_terminator(), + [](Operation &op) { return &op; }); +} + +// Convert integer type to float type for CGRA instruction +// Return the convert float type format code +// TODO: Directly convert memref type? +Data_Format insertConvertTypeOp(Value valuePtr, MemRefType valueType, + Value elemCount, + ConversionPatternRewriter &rewriter, + Location loc) { + + // TODO: Other integer type. May need realloc the memory + auto elemType = valueType.getElementType(); + + if (!isa(elemType)) + return getFormatCode(valueType); + + Data_Format fmt = Fmt_FP32; + // Get the bit width from the element type + auto bitWidth = elemType.getIntOrFloatBitWidth(); + switch (bitWidth) { + case 16: { // 16 bit integer + rewriter.create(loc, rewriter.getI64Type(), valuePtr, + valuePtr, elemCount); + fmt = Fmt_FP16; + break; + } + case 32: { // 32 bit integer + rewriter.create(loc, rewriter.getI64Type(), valuePtr, + valuePtr, elemCount, + rewriter.getI16IntegerAttr(0)); + break; + } + default: { + llvm_unreachable("Unsupported integer type\n"); + } + } + return fmt; +} + +// Restore float type to integer type to for CGRA instruction +Value insertRestoreTypeOp(Value valuePtr, MemRefType valueType, Value elemCount, + ConversionPatternRewriter &rewriter, Location loc) { + // TODO: Other integer type. 
May need realloc the memory + auto elemType = valueType.getElementType(); + auto newValue = valuePtr; + if (!isa(elemType)) + return newValue; + + // Get the bit width from the element type + auto bitWidth = elemType.getIntOrFloatBitWidth(); + switch (bitWidth) { + case 16: { // 16 bit integer + newValue = rewriter.create( + loc, rewriter.getI64Type(), valuePtr, valuePtr, elemCount, + rewriter.getI16IntegerAttr(0)); + break; + } + case 32: { // 32 bit integer + newValue = rewriter.create( + loc, rewriter.getI64Type(), valuePtr, valuePtr, elemCount, + rewriter.getI16IntegerAttr(0)); + break; + } + default: { + llvm_unreachable("Unsupported integer type\n"); + } + } + return newValue; +} + +class MemoryCopyConvertPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + // Workaround: Avoid analyzing control flow as much as possible + bool isOperandMemorySpaceSPM(Value operand) const { + + while (auto op = operand.getDefiningOp()) { + if (isa(op)) + return true; + operand = op->getOperand(0); + } + return false; + } + + LogicalResult + matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(op->hasAttr("srcSpm") && op->hasAttr("dstSpm") && + "Can't get memory space attribute\n"); + bool isSrcSPM = mlir::cast(op->getAttr("srcSpm")).getInt(); + bool isDstSPM = mlir::cast(op->getAttr("dstSpm")).getInt(); + + // DDR to DDR + if (!isSrcSPM && !isDstSPM) + return rewriter.notifyMatchFailure( + op, "Can not copy memory from DDR to DDR.\n"); + + auto [srcPtr, srcSizes, srcStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getSource()); + auto [dstPtr, dstSizes, dstStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getTarget()); + + auto inputType = dyn_cast(op.getSource().getType()); + // SPM to SPM + if (isSrcSPM && isDstSPM) { + // FIXME: Only support 1d for now, take sizes[0] as elemCount. + auto elemCount = calculateElemCount(rewriter, op->getLoc(), srcSizes); + + // WORKAROUND: Assume no mask. 
+ auto constValue = rewriter.create( + op.getLoc(), 0, rewriter.getI32Type()); + + rewriter.create( + op->getLoc(), rewriter.getI64Type(), srcPtr, constValue, dstPtr, + elemCount, // Element count + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(getFormatCode(inputType)) // Format + ); + } else if (isDstSPM) { + auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), srcSizes); + auto nhwcStrides = padStridesToNHWC(rewriter, op->getLoc(), srcStrides); + + auto rdmaOp = rewriter.create( + op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr, + nhwcShape, // NHWC shape + nhwcStrides, // NHWC stride + rewriter.getI32IntegerAttr(getFormatCode(inputType)) // Format + ); + } else { + auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), dstSizes); + auto nhwcStrides = padStridesToNHWC(rewriter, op->getLoc(), dstSizes); + + auto wdmaOp = rewriter.create( + op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr, + nhwcShape, // NHWC shape + nhwcStrides, // NHWC stride + rewriter.getI32IntegerAttr(getFormatCode(inputType)) // Format + ); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +// Convert linalg.fill to MemsetOp +class LinalgFillOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(linalg::FillOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the value to fill with + Value fillValue = op.getInputs()[0]; // adaptor.getValue(); + + if (op.getOutputs().size() != 1) + return rewriter.notifyMatchFailure(op, "Only support single output\n"); + + auto [srcPtr, srcSizes, srcStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto inputType = op.getInputs()[0].getType(); + auto bitWidth = op.getInputs()[0].getType().getIntOrFloatBitWidth(); + assert(bitWidth == 16 || + bitWidth == 32 && "Only support 16/32 fill value\n"); + + // AddVS value need has fmt with input fmt and only support float type + Data_Format fmt = bitWidth == 16 ? Fmt_FP16 : Fmt_FP32; + + if (inputType.isInteger()) { + auto floatType = + bitWidth == 16 ? rewriter.getF16Type() : rewriter.getF32Type(); + fillValue = + rewriter.create(op.getLoc(), floatType, fillValue); + } + + auto bitcastType = + bitWidth == 16 ? rewriter.getI16Type() : rewriter.getI32Type(); + fillValue = + rewriter.create(op.getLoc(), bitcastType, fillValue); + + if (bitWidth == 16) { + fillValue = rewriter.create( + op.getLoc(), rewriter.getI32Type(), fillValue); + } + + // TODO: For scalar data, instead of function call, we should convert + // linalg.fill to memref.store directly to get better performance. + + // Use xor + addvs to simulate memset operation. Only support type fp32 and + // fp16 + // 1. xor srcPtr with itself to get zero + // 2. 
addvs srcPtr with value to get the fill value + auto elemCount = calculateElemCount(rewriter, op->getLoc(), srcSizes); + + auto init = + rewriter.create(op.getLoc(), rewriter.getI64Type(), srcPtr, + srcPtr, srcPtr, elemCount, fmt); + auto resultOp = rewriter.create( + op.getLoc(), rewriter.getI64Type(), srcPtr, fillValue, srcPtr, + elemCount, + rewriter.getI16IntegerAttr(0), // round_mode + rewriter.getI16IntegerAttr(fmt)); + + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// mk.dot to tx.gemm Conversion Pattern +//===----------------------------------------------------------------------===// + +class MKDotToTx81GemmOpConversion + : public OpConversionPattern { + + void fp32ToTF32(ConversionPatternRewriter &rewriter, Location loc, + ValueRange sizes, Value spmAddr) const { + // Warning for neural engine that fp32 is not supported + llvm::errs() + << "\nNeural engine not support FP32. Convert FP32 to TF32 for " + "tx.Gemm Op\n"; + auto elemCount = calculateElemCount(rewriter, loc, sizes); + rewriter.create( + loc, rewriter.getI64Type(), spmAddr, spmAddr, + elemCount, // element_count + rewriter.getI16IntegerAttr(0) // round_mode + ); + } + +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::mk::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Extract dimensions from tensor types + MemRefType aTensorType = mlir::cast(op.getA().getType()); + MemRefType bTensorType = mlir::cast(op.getB().getType()); + assert(aTensorType.getElementType() == bTensorType.getElementType() && + "a and b must have the same element type"); + MemRefType zeroTensorType = + mlir::cast(op.getZeroes().getType()); + Data_Format srcFmt = getFormatCode(aTensorType); + Data_Format dstFmt = getFormatCode(zeroTensorType); + + // Get converted operands + auto loc = op.getLoc(); + + auto aShape = aTensorType.getShape(); + auto bShape = bTensorType.getShape(); + + // Matrix dimensions M, K, N for GEMM + int32_t M = aShape[0]; + int32_t K = aShape[1]; + int32_t N = bShape[1]; + + // Create dimensions array attribute [M, K, N] + auto dims = rewriter.getI32ArrayAttr({M, K, N}); + + // Get operand ptr + auto [aPtr, aSizes, aStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getA()); + auto [bPtr, bSizes, bStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getB()); + auto [cPtr, cSizes, cStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getC()); + // Assume input type is same. Tx neural engine not support fp32 for input + if (aTensorType.getElementType().isF32()) { + srcFmt = Data_Format::Fmt_TF32; + fp32ToTF32(rewriter, op->getLoc(), aSizes, aPtr); + fp32ToTF32(rewriter, op->getLoc(), bSizes, bPtr); + fp32ToTF32(rewriter, op->getLoc(), cSizes, cPtr); + } + + auto dst = createAddressFromMemref(rewriter, loc, adaptor.getZeroes()); + + auto zero = rewriter.create(op.getLoc(), 0, + rewriter.getI64Type()); + + // Create GemmOp + rewriter.create( + op.getLoc(), rewriter.getI64Type(), + aPtr, // src_a (Matrix A in SPM) + bPtr, // src_b (Matrix B in SPM) + cPtr, // src_bias (optional accumulation) + dst, // dst, + dims, // dimensions [M,K,N] + rewriter.getBoolAttr(false), // en_psum + dst, // WORKAROUND: psum_addr (using dst buffer) + rewriter.getBoolAttr(false), // trans_src_a + // NOTE: (N, K) is thought not trans in hardware + rewriter.getBoolAttr(true), // trans_src_b. 
+ rewriter.getI32IntegerAttr(1), // batch_src_a + rewriter.getI32IntegerAttr(1), // batch_src_b + rewriter.getI32IntegerAttr(ActFuncMode::None), // relu_mode. + rewriter.getBoolAttr(op.getC() != nullptr), // en_bias + rewriter.getBoolAttr(false), // en_neg_scale + zero, // src_neg_scale + rewriter.getBoolAttr(false), // en_pos_scale + zero, // src_pos_scale + rewriter.getI32IntegerAttr(srcFmt), // src_fmt + rewriter.getI32IntegerAttr(dstFmt) // dst_fmt + ); + // Op has no result value + rewriter.eraseOp(op); + + return success(); + } +}; + +struct ElementwiseConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + template + LogicalResult convertUnaryOp(linalg::GenericOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input = createAddressFromMemref(rewriter, loc, adapter.getInputs()[0]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adapter.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = dyn_cast(op.getInputs()[0].getType()); + auto outputType = dyn_cast(op.getOutputs()[0].getType()); + // Data format after conversion + Data_Format srcFmt = + insertConvertTypeOp(input, inputType, elemCount, rewriter, loc); + Data_Format dstFmt = + insertConvertTypeOp(output, outputType, elemCount, rewriter, loc); + // Create the unary operation + rewriter.create(loc, rewriter.getI64Type(), input, output, elemCount, + rewriter.getI16IntegerAttr(srcFmt)); + insertRestoreTypeOp(input, inputType, elemCount, rewriter, loc); + insertRestoreTypeOp(output, outputType, elemCount, rewriter, loc); + + rewriter.eraseOp(op); + return success(); + } + + template + LogicalResult convertBinaryOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input0 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto input1 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = dyn_cast(op.getInputs()[0].getType()); + // Data format after conversion + Data_Format srcFmt = + insertConvertTypeOp(input0, inputType, elemCount, rewriter, loc); + insertConvertTypeOp(input1, inputType, elemCount, rewriter, loc); + insertConvertTypeOp(output, inputType, elemCount, rewriter, loc); + + // Create the elementwise operation + // TODO: Fix attribute + rewriter.create(loc, rewriter.getI64Type(), input0, input1, output, + elemCount, + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(srcFmt)); + + insertRestoreTypeOp(input0, inputType, elemCount, rewriter, loc); + insertRestoreTypeOp(input1, inputType, elemCount, rewriter, loc); + insertRestoreTypeOp(output, inputType, elemCount, rewriter, loc); + + rewriter.eraseOp(op); + return success(); + } + + template + LogicalResult NormalConvertOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input = createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + rewriter.create(loc, rewriter.getI64Type(), input, output, + elemCount); + 
rewriter.eraseOp(op); + return success(); + } + + template + LogicalResult RoundConvertOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input = createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + // TODO: Fix attribute + auto result = + rewriter.create(loc, + rewriter.getI64Type(), // Result type + input, // Input + output, // Output + elemCount, // Element count + rewriter.getI16IntegerAttr(0) // Round mode + ); + rewriter.eraseOp(op); + return success(); + } + + template + LogicalResult BoolRelationVVOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input0 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto input1 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = dyn_cast(op.getInputs()[0].getType()); + + // Create the elementwise operation + // TODO: Fix attribute + rewriter.create( + loc, rewriter.getI64Type(), input0, input1, output, elemCount, + rewriter.getI16IntegerAttr(getFormatCode(inputType)) // Format + ); + + rewriter.eraseOp(op); + return success(); + } + + LogicalResult FmaConvertOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input0 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto input1 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]); + auto input2 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[2]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = dyn_cast(op.getInputs()[0].getType()); + + auto mulResult = rewriter.create( + loc, rewriter.getI64Type(), input0, input1, output, elemCount, + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(getFormatCode(inputType))); + auto addResult = rewriter.create( + loc, rewriter.getI64Type(), output, input2, output, elemCount, + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(getFormatCode(inputType))); + rewriter.eraseOp(op); + return success(); + } + + LogicalResult + matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto regionOps = getRegionOps(op); + + // Check if the operation is elementwise + if (op.getIteratorTypesArray().front() != utils::IteratorType::parallel) + return rewriter.notifyMatchFailure(op, "Only support elementwise op."); + + if (regionOps.size() != 1) { + if (failed(linalg::linalgOpToLoops(rewriter, op))) + return rewriter.notifyMatchFailure(op, + "Element-wise op not yet supported"); + rewriter.eraseOp(op); + return success(); + } + + auto elemWiseOp = regionOps[0]; + auto resultType = elemWiseOp->getResult(0).getType(); + return llvm::TypeSwitch(elemWiseOp) + .Case([&](auto elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case([&](auto 
elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case( + [&](auto elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return NormalConvertOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return FmaConvertOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + // TODO: Need add more int to fp convert. + auto inputType = mlir::cast(op.getInputs()[0].getType()) + .getElementType(); + auto outputType = mlir::cast(op.getOutputs()[0].getType()) + .getElementType(); + if (inputType.isInteger(16) && outputType.isF32()) { + return RoundConvertOp(op, adaptor, rewriter); + } else if (inputType.isInteger(16) && outputType.isF16()) { + return NormalConvertOp(op, adaptor, rewriter); + } else if (inputType.isInteger(32) && outputType.isF16()) { + return RoundConvertOp(op, adaptor, rewriter); + } else if (inputType.isInteger(32) && outputType.isF32()) { + return RoundConvertOp(op, adaptor, rewriter); + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported input/output type combination for integer to " + "FP conversion"); + } + }) + .Case([&](auto elemWiseOp) { + // TODO: Need add more int to fp convert. + auto inputType = mlir::cast(op.getInputs()[0].getType()) + .getElementType(); + auto outputType = mlir::cast(op.getOutputs()[0].getType()) + .getElementType(); + if (inputType.isF16() && outputType.isInteger(8)) { + return RoundConvertOp(op, adaptor, rewriter); + } else if (inputType.isF16() && outputType.isInteger(16)) { + return RoundConvertOp(op, adaptor, rewriter); + } else if (inputType.isF16() && outputType.isInteger(32)) { + return RoundConvertOp(op, adaptor, rewriter); + } else if (inputType.isF32() && outputType.isInteger(8)) { + return RoundConvertOp(op, adaptor, rewriter); + } else if (inputType.isF32() && outputType.isInteger(16)) { + return RoundConvertOp(op, adaptor, rewriter); + } else if (inputType.isF32() && outputType.isInteger(32)) { + return RoundConvertOp(op, adaptor, rewriter); + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported input/output type combination for fp to " + "integer conversion"); + } + }) +// FIXME: Now BoolLessThenOp run fail on board. 
Need more op information from +// Tx81 +#if 0 + .Case([&](auto elemWiseOp) { + arith::CmpIPredicate predicate = elemWiseOp.getPredicate(); + switch (predicate) { + case arith::CmpIPredicate::eq: + return BoolRelationVVOp(op, adaptor, rewriter); + case arith::CmpIPredicate::ne: + return BoolRelationVVOp(op, adaptor, rewriter); + case arith::CmpIPredicate::sge: + return BoolRelationVVOp(op, adaptor, + rewriter); + case arith::CmpIPredicate::sgt: + return BoolRelationVVOp(op, adaptor, rewriter); + case arith::CmpIPredicate::sle: + return BoolRelationVVOp(op, adaptor, rewriter); + case arith::CmpIPredicate::slt: + return BoolRelationVVOp(op, adaptor, rewriter); + default: + llvm_unreachable("Not yet supported"); + break; + } + }) +#endif + .Case([&](auto elemWiseOp) { + if (resultType.isF16()) + return RoundConvertOp(op, adaptor, rewriter); + else if (resultType.isBF16()) + return RoundConvertOp(op, adaptor, rewriter); + else + return rewriter.notifyMatchFailure( + op, "Unsupported input/output type combination for trunc " + "conversion"); + }) + .Default([&](auto elemWiseOp) { + // WORKAROUND: Used to handle tl.arange(0, BLOCK_SIZE) which will + // lower to linalg.generic + linalg.index + arith.index_cast and + // other unsupported case now (eg: arith::extf) + // TODO: Lower ops to tx81 if is supported + + // Affine dialect should handled before this pass. So here lower it to + // scf.for + if (failed(linalg::linalgOpToLoops(rewriter, op))) + return rewriter.notifyMatchFailure( + op, "Element-wise op not yet supported"); + rewriter.eraseOp(op); + return success(); + }); + } +}; + +struct ReduceConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + bool isReductionOpSupported(Operation *redOp) const { + return isa( + redOp); + } + + template + LogicalResult convertToReduceOp(linalg::ReduceOp op, + typename linalg::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto dims = op.getDimensions(); + if (dims.size() != 1) + return rewriter.notifyMatchFailure(op, "Only support one dim reduce."); + auto dim = dims[0]; + auto input = + createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInputs()[0]); + auto output = + createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInits()[0]); + auto inputType = dyn_cast(op.getInputs()[0].getType()); + auto inputShape = inputType.getShape(); + // TODO: Support any rank + if (inputShape.size() > 1) + return rewriter.notifyMatchFailure(op, "Rank > 1 unsupported yet."); + + if (dim && dim >= inputShape.size()) + return rewriter.notifyMatchFailure(op, + "Dimensions attribute > input rank !"); + + int64_t inputSize = inputShape.empty() ? 
1 : inputShape[0]; + + SmallVector reduceShape = {1, 1, 1, inputSize}; + auto format = getFormatCode(inputType); + auto reduceOp = rewriter.create( + op->getLoc(), TypeRange{}, input, output, + rewriter.getUI32IntegerAttr(dim), rewriter.getI64ArrayAttr(reduceShape), + rewriter.getI16IntegerAttr(format)); + rewriter.replaceOp(op, reduceOp); + return success(); + } + +public: + LogicalResult + matchAndRewrite(linalg::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto reductionOps = getRegionOps(op); + + if (reductionOps.size() != 1 || + !isReductionOpSupported(reductionOps.front())) { + return rewriter.notifyMatchFailure( + op, "Only support lowering reduction with body " + "containing 1 max(i/f) or addf."); + } + auto redOp = reductionOps[0]; + + return llvm::TypeSwitch(redOp) + .Case([&](auto redOp) { + return convertToReduceOp(op, adaptor, rewriter); + }) + .Case([&](auto redOp) { + return convertToReduceOp(op, adaptor, rewriter); + }) + .Case([&](auto redOp) { + return convertToReduceOp(op, adaptor, rewriter); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return failure(); + }); + } +}; + +} // namespace + +void mlir::triton::populateMKToTx81CanonicalizationPatterns( + RewritePatternSet &patterns) {} + +void mlir::triton::populateMKToTx81ConversionPatterns( + RewritePatternSet &patterns) { + + MKToTx81TypeConverter typeConverter; + + // Add type conversion patterns + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + + // clang-format off + patterns.add( + patterns.getContext()); + // clang-format on +} diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp new file mode 100644 index 000000000..371c2faeb --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp @@ -0,0 +1,139 @@ +//===--------------------- MKToTx81Pass.cpp -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "mk-to-tx81" + +using namespace mlir; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_MKTOTX81 +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class MKToTx81Pass : public triton::impl::MKToTx81Base { + using MKToTx81Base::MKToTx81Base; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + bool isOperandMemorySpaceSPM(Value operand) { + Operation *lastOp = operand.getDefiningOp(); + Operation *op = lastOp; + + do { + if (isa(op)) + return true; + else if (auto forOp = dyn_cast(op)) { + // Here we assume that yieldResults (inner loop region) and + // loopResults (outer loop region) correspond one-to-one to obtain the + // inner loop region definingOp of the outer loop region value. + // FIXME: Need reference the standard loop analysis to refactor this. + + auto yieldResults = forOp.getYieldedValues(); + mlir::ResultRange loopResults = forOp.getLoopResults().value(); + assert(yieldResults.size() == loopResults.size()); + + auto idx = std::distance( + loopResults.begin(), + std::find(loopResults.begin(), loopResults.end(), operand)); + operand = yieldResults[idx]; + + } else { + operand = op->getOperand(0); + } + lastOp = op; + op = operand.getDefiningOp(); + } while (op); + return false; + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // Use to memory::CopyOp to tx dialect op + moduleOp->walk([&](Operation *op) { + if (isa(op)) { + auto copyOp = cast(op); + op->setAttr("srcSpm", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), + llvm::APInt(32, isOperandMemorySpaceSPM( + copyOp.getSource())))); + op->setAttr("dstSpm", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), + llvm::APInt(32, isOperandMemorySpaceSPM( + copyOp.getTarget())))); + } + }); + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + // Register illegal ops for Dialect Conversion + target.addIllegalDialect(); + + target.addLegalDialect(); + + target.addIllegalOp(); + target.addLegalOp(); + + triton::populateMKToTx81ConversionPatterns(patterns); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + + // linalg::linalgOpToLoops will generate memref::LoadOp/memref::StoreOp + // before and after the arith calculation. + // Use to check whether add spm mapping offset in + // memref::LoadOp/memref::StoreOp lowering + moduleOp->walk([&](Operation *op) { + if (isa(op)) { + bool isSpm = isa(op) + ? 
isOperandMemorySpaceSPM(op->getOperand(0)) + : isOperandMemorySpaceSPM(op->getOperand(1)); + + op->setAttr("isSpm", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), + llvm::APInt(32, isSpm))); + } + }); + } +}; + +} // namespace + +std::unique_ptr> triton::createMKToTx81Pass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/CMakeLists.txt new file mode 100644 index 000000000..0883bca4f --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(StructuredToMemref + StructuredToMemref.cpp + StructuredToMemrefPass.cpp + + DEPENDS + StructuredToMemrefConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRSCFTransforms + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms + TritonTilingExtIR + TritonStructuredIR +) diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp new file mode 100644 index 000000000..b5e1165a7 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp @@ -0,0 +1,859 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR//MemRef.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#include +#include +#include + +#define DEBUG_TYPE "structured-to-memref" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc" + +static const std::string WRAP_SIDE_BY_SIDE = "wrap_side_by_side"; +static const std::string WRAP_STACKED = "wrap_stacked"; + +static memref::SubViewOp getSubview(int rank, ArrayRef dims, + Value source, Location loc, OpBuilder &b) { + auto sourceType = cast(source.getType()); + SmallVector offsets(rank, b.getIndexAttr(0)); + SmallVector strides(rank, b.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); + + return b.create(loc, cast(dstType), source, + offsets, dims, strides); +} + +namespace { + +struct 
MakeTensorPtrConverter + : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + static Type getElementTypeStructuredPtr(tts::MakeTensorPtrOp op) { + assert(!op.isBlockPtr()); + // tensor<1024x!tt.ptr> + auto ptrType = cast( + cast(op.getType()).getElementType()); + return ptrType.getPointeeType(); + } + + static Type getElementTypeBlockPtr(tts::MakeTensorPtrOp op) { + assert(op.isBlockPtr()); + // !tt.ptr, 1> + auto shapedType = cast( + cast(op.getType()).getPointeeType()); + return shapedType.getElementType(); + } + + static MemRefType getResultMemrefType(tts::MakeTensorPtrOp op, int64_t offset, + ArrayRef staticStrides, + ArrayRef resultShape) { + auto layout = + StridedLayoutAttr::get(op.getContext(), offset, staticStrides); + Type elemType; + if (op.isBlockPtr()) { + elemType = getElementTypeBlockPtr(op); + } else { + elemType = getElementTypeStructuredPtr(op); + } + return MemRefType::get(resultShape, elemType, layout); + } + + // If there are dimensions with size 1 and stride 0, replace 0 stride with + // the product of sizes of all lower dimensions. This avoids creating memref + // with zero stride. + static llvm::SmallVector + getMixedStridesForMemref(tts::MakeTensorPtrOp op, OpBuilder &b) { + llvm::SmallVector strides; + auto accumulate = 1; + for (auto [size, stride] : + llvm::reverse(llvm::zip(op.getSizes(), op.getMixedStrides()))) { + auto strideIntAttr = getIntAttr(stride); + if (size == 1 && strideIntAttr && strideIntAttr.value() == 0) { + strides.push_back(b.getIndexAttr(accumulate)); + } else { + strides.push_back(stride); + } + accumulate *= size; + } + std::reverse(strides.begin(), strides.end()); + return strides; + } + + static OpFoldResult accumulateTargetOffset(tts::MakeTensorPtrOp op, + OpBuilder &b) { + Location loc = op->getLoc(); + OpFoldResult targetOffset = b.getIndexAttr(0); + for (auto o : op.getMixedOffsets()) { + targetOffset = addOFRs(targetOffset, o, loc, b); + } + return targetOffset; + } + + std::pair + createSideBySideCastOps(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto resultShape = cast(op.getType()).getShape(); + + auto targetOffset = + ofrToIndexValue(accumulateTargetOffset(op, rewriter), loc, rewriter); + + //////////////////////////////////////////////////////////////////////////// + // + // Handling side-by-side wraparound + // + // Note: We do not support cases where the target has already overflown the + // number of columns! This is because in PtrAnalysis, the offset has already + // been collapsed into a single dimension, so it is ambiguous to determine + // whether the offset actually overflows or just refers to an element on the + // subsequent rows. + // + // Same limitations apply to the stacked wraparound case. 
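+    //
+    // As a concrete illustration (sizes are hypothetical): loading a 2x6 tile
+    // from a row-major tensor with N = 10 columns, starting at column 7,
+    // overflows the row end (7 + 6 > 10). The load is therefore split into a
+    // 2x3 chunk starting at column 7 and a 2x3 chunk wrapping back to column 0.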
+ // + //////////////////////////////////////////////////////////////////////////// + // + // nextOffset - targetOffset = colSize + // d1 + d2 = colSize + // N + // x clampedOffset + // --------------------------*----------------*-----* + // | | nextOffset (might + // | targetOffset | overflow) + // y *----- *----------------| + // | | | | + // M |----- -----------------| + // | d2 d1 | + // -------------------------------------------- + // + // x = targetOffset % N + // nextOffset = x + colSize + // clampedOffset = min(nextOffset, N) + // d1 = clampedOffset - x + // + //////////////////////////////////////////////////////////////////////////// + + auto resultType = getResultMemrefType( + op, /* offset */ ShapedType::kDynamic, + /* staticStrides */ + SmallVector(resultShape.size(), ShapedType::kDynamic), + /* result shape */ + SmallVector{ + + // Row stays the same + resultShape[0], + + // Column is dynamic, in most cases, this + // should be the same as the original column. + // The last chunk may be smaller due to + // wrapping around. + ShapedType::kDynamic}); + + Value rowSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[1])); + + Value modN = ofrToIndexValue(op.getMixedShape()[1], loc, rewriter); + + Value x = rewriter.create(loc, targetOffset, modN); + Value y = rewriter.create(loc, targetOffset, x); + + SmallVector strideVals = + ofrsToIndexValues(op.getMixedStrides(), loc, rewriter); + + // First chunk + Value nextOffset = rewriter.create(loc, x, colSize); + Value clampedOffset = + rewriter.create(loc, nextOffset, modN); + Value d1 = rewriter.create(loc, clampedOffset, x); + SmallVector sizes1{rowSize, d1}; + + auto cast1 = rewriter.create( + loc, resultType, adaptor.getBase(), targetOffset, sizes1, strideVals); + + // Second chunk + Value d2 = rewriter.create(loc, colSize, d1); + SmallVector sizes2{rowSize, d2}; + + auto cast2 = rewriter.create( + loc, resultType, adaptor.getBase(), y, sizes2, strideVals); + + return {cast1, cast2}; + } + + std::pair + createStackedCastOps(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto loc = op->getLoc(); + auto resultShape = cast(op.getType()).getShape(); + + assert(resultShape.size() == 2); + + auto targetOffset = + ofrToIndexValue(accumulateTargetOffset(op, rewriter), loc, rewriter); + + //////////////////////////////////////////////////////////////////////////// + // + // Handling stacked wraparound + // + // We do not support cases where the target offset has already overflown the + // number of rows. See side-by-side wraparound for details. 
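+    //
+    // As a concrete illustration (sizes are hypothetical): loading an 8x4 tile
+    // from a tensor with 10 rows, starting at row 7, overflows the row count
+    // (7 + 8 > 10). The load is therefore split into a 3x4 chunk covering rows
+    // 7..9 and a 5x4 chunk wrapping back to rows 0..4.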
+ // + //////////////////////////////////////////////////////////////////////////// + // We're loading a tensor of dim (rowSize, colSize) + // d1 + d2 = rowSize + // d2 is the number of rows that overflow + // + // cols + // + // wrappedAroundOff + // --------------*------------*-------- + // | d2 | | | + // | |------------| | + // rows| | + // | | + // | targetOffset | + // | *------------| | + // | | | | + // | d1 | | | + // | | clampedOff | | + // --------------*--------------------- + // | overflow | + // *------------- + // nextOff + // + // wrappedAroundOff = targetOffset % cols + // clampedOff = (rows * strideRows) + wrappedAroundOff + // ~~~~~~~~~~~~~~~~~ + // ^ + // | + // We have already computed + // rows * strideRows = modRow = shape[1] + // in TritonToStructured + // + // clampedOff - targetOffset + // d1 = -------------------- + // strideRows + + auto resultType = getResultMemrefType( + op, /* offset */ ShapedType::kDynamic, + /* staticStrides */ + SmallVector(resultShape.size(), ShapedType::kDynamic), + /* result shape */ + SmallVector{ + // Row is dynamic, in most cases, this should + // be the same as the original row. The last + // chunk may be smaller due to wrapping + // around. + ShapedType::kDynamic, + + // Col stays the same. + resultShape[1], + }); + + Value rowSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[1])); + + Value strideRow = ofrToIndexValue(op.getMixedStrides()[0], loc, rewriter); + Value strideCol = ofrToIndexValue(op.getMixedStrides()[1], loc, rewriter); + + Value modRow = op.getShape()[0]; + + // First chunk + Value wrappedAroundOff = + rewriter.create(loc, targetOffset, strideRow); + Value clampedOff = + rewriter.create(loc, modRow, wrappedAroundOff); + Value d1 = rewriter.create(loc, clampedOff, targetOffset); + d1 = rewriter.create(loc, d1, strideRow); + + SmallVector sizes1{d1, colSize}; + memref::ReinterpretCastOp cast1 = + rewriter.create( + loc, resultType, adaptor.getBase(), targetOffset, sizes1, + ValueRange{strideRow, strideCol}); + + // Second chunk + Value d2 = rewriter.create(loc, rowSize, d1); + SmallVector sizes2{d2, colSize}; + memref::ReinterpretCastOp cast2 = + rewriter.create( + loc, resultType, adaptor.getBase(), wrappedAroundOff, sizes2, + ValueRange{strideRow, strideCol}); + + return {cast1, cast2}; + } + + LogicalResult rewriteSplitPtr(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto parentShape = op.getStaticShape(); + + SmallVector casts; + StringRef wrapType; + + if (parentShape[0] == ShapedType::kDynamic) { + // Stacked case + assert(parentShape[1] == 0); + auto [cast1, cast2] = createStackedCastOps(op, adaptor, rewriter); + casts = {cast1.getResult(), cast2.getResult()}; + wrapType = WRAP_STACKED; + } else { + assert(parentShape[0] == 0); + auto [cast1, cast2] = createSideBySideCastOps(op, adaptor, rewriter); + casts = {cast1.getResult(), cast2.getResult()}; + wrapType = WRAP_SIDE_BY_SIDE; + } + + auto combinedCast = rewriter.create( + op.getLoc(), op.getType(), casts); + + combinedCast->setAttr(wrapType, rewriter.getUnitAttr()); + + rewriter.replaceOp(op, combinedCast); + + return success(); + } + + LogicalResult rewritePtr(ArrayRef resultShape, bool isBlockPtr, + tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto mixedStrides = getMixedStridesForMemref(op, rewriter); + SmallVector staticStrides; + SmallVector 
dynamicStrides; + dispatchIndexOpFoldResults(mixedStrides, dynamicStrides, staticStrides); + + auto targetOffset = accumulateTargetOffset(op, rewriter); + auto staticTargetOffset = getIntAttr(targetOffset); + auto resultType = getResultMemrefType( + op, staticTargetOffset.value_or(ShapedType::kDynamic), staticStrides, + resultShape); + + // The base ptr, which is from one of the args, would have already been + // converted to memref<*> at this point, so get the base from adaptor. + // + // For block pointers, the base could come from a sequence of `tt.addptr`, + // which at this point has already been lowered to a sequence of + // `memref.reinterpret_cast` ops. The offset in such cases are dynamic. + // (see test/Conversion/StructuredToMemref/block_ptr_complex_offset.mlir) + // + // For non-block pointer cases, the base is the reinterpret_cast of a + // function argument. Assert that the offset is a constant 0 in such cases. + auto ptr = adaptor.getBase(); + if (auto reinterpretCast = ptr.getDefiningOp()) { + auto offset = reinterpretCast.getMixedOffsets()[0]; + auto intAttr = getIntAttr(offset); + assert(isBlockPtr || (intAttr.has_value() && intAttr.value() == 0)); + targetOffset = addOFRs(targetOffset, reinterpretCast.getMixedOffsets()[0], + op->getLoc(), rewriter); + } + + auto castOp = rewriter.create( + op.getLoc(), resultType, ptr, targetOffset, op.getMixedSizes(), + mixedStrides); + + rewriter.replaceOp(op, castOp); + + return success(); + } + + LogicalResult + rewriteStructuredPtr(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ArrayRef resultShape = cast(op.getType()).getShape(); + return rewritePtr(resultShape, false, op, adaptor, rewriter); + } + + LogicalResult rewriteBlockPtr(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Block pointers are basically the same as structured pointers except that + // the return types are !tt.ptr> instead of + // tensor> + ArrayRef resultShape = + cast( + cast(op.getType()).getPointeeType()) + .getShape(); + return rewritePtr(resultShape, true, op, adaptor, rewriter); + } + +public: + LogicalResult + matchAndRewrite(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!llvm::is_sorted(op.getOrder(), std::greater<>())) { + emitError(op.getLoc()) << "non-decreasing dimension order on tensor " + "pointers are not yet supported"; + return failure(); + } + + if (op.isBlockPtr()) { + return rewriteBlockPtr(op, adaptor, rewriter); + } + + if (op.isStructuredPtr()) { + return rewriteStructuredPtr(op, adaptor, rewriter); + } + + if (op.isSplitPtr()) { + return rewriteSplitPtr(op, adaptor, rewriter); + } + + return failure(); + } +}; + +struct LoadConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + void createSideBySideCopies(Value block1, Value block2, Value dst, + Location loc, + ConversionPatternRewriter &rewriter) const { + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto 
block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); + } + + void createStackedCopies(Value block1, Value block2, Value dst, Location loc, + ConversionPatternRewriter &rewriter) const { + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); + } + + memref::SubViewOp createSubview(Value src, ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides, Location loc, + ConversionPatternRewriter &rewriter) const { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, cast(dstType), + src, offsets, sizes, strides); + } + + std::pair + getSideBySideSubviews(ArrayRef dims, Value block1, Value block2, + Location loc, + ConversionPatternRewriter &rewriter) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult col1 = + rewriter.create(loc, block1, 1).getResult(); + OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter); + OpFoldResult subviewCol2 = + subOFRs(subviewColFull, subviewCol1, loc, rewriter); + + SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); + SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); + auto sv1 = createSubview(block1, offsets, {subviewRowFull, subviewCol1}, + strides, loc, rewriter); + auto sv2 = createSubview(block2, offsets, {subviewRowFull, subviewCol2}, + strides, loc, rewriter); + + return {sv1, sv2}; + } + + std::pair + getStackedSubviews(ArrayRef dims, Value block1, Value block2, + const Location loc, + ConversionPatternRewriter &rewriter) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult row1 = + rewriter.create(loc, block1, 0).getResult(); + OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter); + OpFoldResult subviewRow2 = + subOFRs(subviewRowFull, subviewRow1, loc, rewriter); + + SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); + SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); + auto sv1 = createSubview(block1, offsets, {subviewRow1, subviewColFull}, + strides, loc, rewriter); + auto sv2 = createSubview(block2, offsets, {subviewRow2, subviewColFull}, + strides, loc, rewriter); + return {sv1, sv2}; + } + + LogicalResult + rewriteStructuredLoad(tts::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(!op.hasMask()); + + auto loc = op->getLoc(); + auto ptr = adaptor.getPtr(); + auto other = op.getOther(); + + auto tensorType = cast(op.getType()); + auto elemType = tensorType.getElementType(); + + auto alloc = 
rewriter.create( + loc, MemRefType::get(tensorType.getShape(), elemType)); + + // No mask + assert(!other && "other value used in non-masked load"); + + if (auto unrealizedCast = ptr.getDefiningOp()) { + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (unrealizedCast->hasAttr(WRAP_SIDE_BY_SIDE)) { + createSideBySideCopies(block1, block2, alloc, loc, rewriter); + } else if (unrealizedCast->hasAttr(WRAP_STACKED)) { + createStackedCopies(block1, block2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + } else { + rewriter.create(loc, ptr, alloc); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } + + LogicalResult rewriteMaskedLoad(tts::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(op.hasMask()); + + auto loc = op->getLoc(); + auto ptr = adaptor.getPtr(); + + auto tensorType = cast(op.getType()); + auto elemType = tensorType.getElementType(); + + auto alloc = rewriter.create( + loc, MemRefType::get(tensorType.getShape(), elemType)); + + SmallVector mixedDims = op.getMixedMaskDims(); + + // Fill load destination with other value + if (op.getOther()) { + // For each dimension check if dims[i] < shape[i], or-accumulate + // the result + auto shape = tensorType.getShape(); + auto accBase = + rewriter.create(loc, rewriter.getBoolAttr(false)) + .getResult(); + for (size_t i = 0; i < shape.size(); i++) { + auto shapei = rewriter.create( + loc, rewriter.getIndexAttr(shape[i])); + + Value dimi = dyn_cast(mixedDims[i]); + if (!dimi) { + dimi = rewriter.create( + loc, rewriter.getIndexAttr(op.getStaticMaskDims()[i])); + } + + Value cmp = rewriter.create( + loc, arith::CmpIPredicate::slt, dimi, shapei); + accBase = rewriter.create(loc, accBase, cmp); + } + + // condition the memset on the or-accumulation + // initialize with padding prior to CopyOp + rewriter.create(loc, accBase, [&](OpBuilder &b, Location loc) { + b.create(loc, ValueRange{op.getOther()}, + ValueRange{alloc}); + b.create(loc); + }); + } + + if (auto unrealizedCast = ptr.getDefiningOp()) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (unrealizedCast->hasAttr(WRAP_SIDE_BY_SIDE)) { + auto [subview1, subview2] = + getSideBySideSubviews(mixedDims, block1, block2, loc, rewriter); + createSideBySideCopies(subview1, subview2, alloc, loc, rewriter); + } else if (unrealizedCast->hasAttr(WRAP_STACKED)) { + auto [subview1, subview2] = + getStackedSubviews(mixedDims, block1, block2, loc, rewriter); + createStackedCopies(subview1, subview2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + + rewriter.eraseOp(unrealizedCast); + + } else { + memref::SubViewOp srcSubview = + getSubview(tensorType.getRank(), mixedDims, ptr, loc, rewriter); + memref::SubViewOp dstSubview = + getSubview(tensorType.getRank(), mixedDims, alloc, loc, rewriter); + rewriter.create(loc, srcSubview, dstSubview); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } + +public: + LogicalResult + matchAndRewrite(tts::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.hasMask()) { + return rewriteMaskedLoad(op, adaptor, rewriter); + } else { + return 
rewriteStructuredLoad(op, adaptor, rewriter); + } + } +}; + +struct StoreConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + static tensor::ExtractSliceOp + getExtractSlice(int rank, ArrayRef dims, Value source, + const Location loc, OpBuilder &b) { + auto sourceType = cast(source.getType()); + SmallVector offsets(rank, b.getIndexAttr(0)); + SmallVector strides(rank, b.getIndexAttr(1)); + + auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, + dims, strides); + + return b.create(loc, dstType, source, offsets, dims, + strides); + } + +public: + LogicalResult + matchAndRewrite(tts::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ptr = adaptor.getPtr(); + auto storeValue = op.getValue(); + auto rank = cast(storeValue.getType()).getRank(); + + if (op.hasMask()) { + auto mixedDims = op.getMixedMaskDims(); + + auto srcSlice = + getExtractSlice(rank, mixedDims, storeValue, loc, rewriter); + auto dstSubview = getSubview(rank, mixedDims, ptr, loc, rewriter); + + auto storeOp = rewriter.create( + loc, srcSlice, dstSubview); + storeOp.setWritable(true); + } else { + auto storeOp = rewriter.create( + loc, storeValue, ptr); + storeOp.setWritable(true); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct ScalarLoadConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getType().isIntOrIndexOrFloat()) { + return failure(); + } + + auto loc = op->getLoc(); + auto memrefPtr = adaptor.getPtr(); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + auto loadOp = rewriter.create(loc, memrefPtr, zeroMap, + std::nullopt); + rewriter.replaceOp(op, loadOp.getResult()); + + return success(); + } +}; + +struct ScalarStoreConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getValue().getType().isIntOrIndexOrFloat()) { + return failure(); + } + + auto loc = op->getLoc(); + auto memrefPtr = adaptor.getPtr(); + auto val = op.getValue(); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + + rewriter.create(loc, val, memrefPtr, zeroMap, + std::nullopt); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct UnrealizedCastConverter + : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + +public: + UnrealizedCastConverter(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, + context) {} + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resType = op->getResultTypes()[0]; + auto input = op.getInputs()[0]; + auto inputType = input.getType(); + + if (!isa(resType) || + !isa(inputType)) { + return failure(); + } + + if (auto reinterpretCast = + input.getDefiningOp()) { + rewriter.replaceOp(op, reinterpretCast); + } else { + auto ptrType = cast(resType); + auto memrefType = + cast(getTypeConverter()->convertType(ptrType)); + + auto cast = rewriter.create( + op->getLoc(), memrefType, op.getInputs()[0], 0 /*offset*/, + SmallVector{1} /*sizes*/, + SmallVector{1} /*strides*/); + + 
rewriter.replaceOp(op, cast); + } + + return success(); + } +}; + +} // namespace + +void mlir::triton::populateStructuredToMemrefConversionPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter) { + patterns.add(typeConverter, patterns.getContext()); + patterns.add( + patterns.getContext()); +} diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp new file mode 100644 index 000000000..7decf7148 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -0,0 +1,414 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LogicalResult.h" +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Transforms/OneToNTypeConversion.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" + +#include +#include + +#define DEBUG_TYPE "structured-to-memref" + +using namespace mlir; +using namespace triton; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_STRUCTUREDTOMEMREF +#include "triton-shared/Conversion/StructuredToMemref/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +static MemRefType getMemrefTypeForScalarPtr(triton::PointerType ptrType, + MLIRContext *context) { + SmallVector strides{1}; + auto layout = StridedLayoutAttr::get(context, ShapedType::kDynamic, strides); + + auto elemType = ptrType.getPointeeType(); + auto memrefType = MemRefType::get({1}, elemType, layout); + return memrefType; +} + +class TritonFunctionSignatureConverter : public TypeConverter { +public: + TritonFunctionSignatureConverter() { + // The order of type conversion is important: later ones are tried earlier. + addConversion([](Type type) { return type; }); + addConversion([](triton::PointerType ptrType) { + return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); + }); + // Used for converting memref<*> back to tt.ptr type, these ops will then be + // handled when we convert addptr op later. 
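+    // For example, such a materialization inserts roughly
+    //   %ptr = builtin.unrealized_conversion_cast %arg : memref<*xf32> to !tt.ptr<f32>
+    // (f32 is just an illustrative element type).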
+ addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + + addArgumentMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + } +}; + +class LoopTypeConverter : public TypeConverter { +public: + LoopTypeConverter(MLIRContext *context) { + // The order of type conversion is important: later ones are tried earlier. + addConversion([](Type type) { return type; }); + addConversion([context](triton::PointerType ptrType) { + return getMemrefTypeForScalarPtr(ptrType, context); + }); + + // A tensor of pointers can be passed in as scf.for's init-args, in such + // cases, we convert the type to a memref with dynamic offsets and + // strides. + addConversion( + [context](RankedTensorType tensorType) -> std::optional { + if (auto ptrType = llvm::dyn_cast( + tensorType.getElementType())) { + auto layout = StridedLayoutAttr::get( + context, ShapedType::kDynamic, + SmallVector(tensorType.getRank(), + ShapedType::kDynamic)); + Type elemType = ptrType.getPointeeType(); + return MemRefType::get(tensorType.getShape(), elemType, layout); + } + + return std::nullopt; + }); + + // Convert the current memref type to a memref type with dynamic offsets and + // strides through another reinterpret_cast with the same offsets. + // Canonicalization will simplify this sequence by removing the inital + // reinterpret_cast. + addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType, + ValueRange inputs, Location loc) -> Value { + auto reinterpretCast = + inputs[0].getDefiningOp(); + return builder.create( + loc, memrefType, inputs[0], reinterpretCast.getMixedOffsets()[0], + reinterpretCast.getMixedSizes(), reinterpretCast.getMixedStrides()); + }); + } +}; + +struct ScalarAddptrConverter + : public OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + if (isa(op.getType())) { + return failure(); + } + + auto loc = op->getLoc(); + + auto offsetIndex = rewriter.create( + loc, rewriter.getIndexType(), op.getOffset()); + + auto ptrInfo = adaptor.getPtr(); + assert(ptrInfo.size() == 2); + auto ptr = ptrInfo[0]; + auto offset = ptrInfo[1]; + + auto newOffset = rewriter.create(loc, offset, offsetIndex); + + auto castOp = rewriter.create( + loc, + getMemrefTypeForScalarPtr( + cast(op.getPtr().getType()), + rewriter.getContext()), + ptr, getAsOpFoldResult(newOffset) /*offset*/, + ArrayRef{rewriter.getIndexAttr(1)} /*sizes*/, + ArrayRef{rewriter.getIndexAttr(1)} /*strides*/); + + rewriter.replaceOp(op, SmallVector{castOp.getResult(), newOffset}, + adaptor.getResultMapping()); + + return success(); + } +}; + +static std::optional> +buildCastAndOffsetOps(OpBuilder &builder, TypeRange resultTypes, Value input, + Location loc) { + assert(resultTypes.size() == 2 && isa(resultTypes[0]) && + isa(resultTypes[1]) && + "Unexpected result types when converting addptr"); + assert(isa(input.getType()) && + "Unexpected input type when converting addptr"); + + // There are only two types of ops that can produce a result of type tt.ptr + // 1) tt.addptr, this is already handled by ScalarAddptrConverter + // 2) unrealized_conversion_cast, which are inserted during the conversion + // of function arguments. 
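+  //
+  // For such an input, this helper emits roughly
+  //   %cast = memref.reinterpret_cast %buffer to offset: [0], sizes: [1], strides: [1]
+  //   %zero = arith.constant 0 : index
+  // i.e. the buffer viewed as a single-element memref plus a zero offset.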
+ // We assert that there can only be input that comes from + // unrealized_conversion_cast. + auto castOp = input.getDefiningOp(); + assert(castOp && "Unexpected defining op for input of type tt.ptr"); + + // Compute the memref type + auto buffer = castOp.getOperand(0); + auto bufferType = cast(buffer.getType()); + auto layout = + StridedLayoutAttr::get(builder.getContext(), ShapedType::kDynamic, {1}); + auto memrefType = MemRefType::get({1}, bufferType.getElementType(), layout); + + // Create ops to convert the triton input type to a pair of {memref, index} + auto cast = builder.create( + loc, memrefType, buffer, 0 /*offset*/, ArrayRef{(1)} /*sizes*/, + ArrayRef{(1)} /*strides*/); + auto zero = builder.create(loc, builder.getIndexAttr(0)); + + return SmallVector{cast, zero}; +} + +static Value buildCastOp(OpBuilder &builder, Type resultType, ValueRange inputs, + Location loc) { + assert(isa(resultType)); + assert(inputs.size() && isa(inputs[0].getType()) && + isa(inputs[1].getType())); + return builder.create(loc, resultType, inputs[0]) + .getResult(0); +} + +class StructuredToMemrefPass + : public triton::impl::StructuredToMemrefBase { + using StructuredToMemrefBase::StructuredToMemrefBase; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + LogicalResult convertArgsToMemrefType() { + auto moduleOp = getOperation(); + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + TritonFunctionSignatureConverter typeConverter; + + // Update function signatures and calls to use memrefs + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()); + }); + + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + + target.addDynamicallyLegalOp([&](func::CallOp op) { + return typeConverter.isLegal(op.getResultTypes()) && + typeConverter.isLegal(op.getOperandTypes()); + }); + + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + + return applyPartialConversion(moduleOp, target, std::move(patterns)); + } + + // We leverage the 1->N conversion infrastructure to convert tt.addptr for + // scalar to memref.reinterpret_cast. + // + // A tt.addptr has the following form: + // + // %new_ptr = tt.addptr %ptr %offset + // + // where %new_ptr and %ptr have tt.ptr type, and %offset is of index type. + // + // With this form, there can be a chain of tt.addptr where we keep adding + // offsets to an existing pointer: + // + // %ptr_1 = tt.addptr %arg0 %offset + // %ptr_2 = tt.addptr %ptr_1 %offset + // %ptr_3 = tt.addptr %ptr_2 %offset + // + // Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that + // the pointers can be used by affine.load and affine.store (lowered from + // tt.load and tt.store). 
+ // + // A memref.reinterpret_cast op also takes an offset and returns a memref in a + // similar fashion to tt.addptr: + // + // %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: + // [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: + // ?>> + // + // However, since the semantic of memref.reinterpret_cast is different, + // the following lowering would be incorrect for the sequence of tt.addptr + // above: + // + // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset] + // %cast_2 = memref.reinterpret_cast %cast_1 to offset [%offset] + // %cast_3 = memref.reinterpret_cast %cast_2 to offset [%offset] + // + // The above sequence is equivalent to: + // + // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset] + // %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset] + // %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset] + // + // In other word, memref.reinterpret_cast ignores the current offset of the + // input buffer. + // + // Therefore, we have to manually track the offset for each addptr by lowering + // to the following form: + // + // %offset_1 = arith.addi %cst_0 %offset + // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset_1] + // + // %offset_2 = arith.addi %offset_1 %offset + // %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset_2] + // + // %offset_3 = arith.addi %offset_2 %offset + // %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset_3] + // + // Each tt.addptr is lowered to a pair of arith.addi that accumulates the + // current offset before using that offset to the reinterpret_cast. + LogicalResult convertAddPtrToReinterpretCast() { + auto moduleOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + + auto context = &getContext(); + TypeConverter converter; + converter.addConversion([](Type type) { return type; }); + + // We are doing a 1->2 type conversion here, where a triton pointer type + // maps to a pair of {memref, index} type for the the buffer and offset. + converter.addConversion( + [context](triton::PointerType ptrType, SmallVectorImpl &types) + -> std::optional { + types = SmallVector{getMemrefTypeForScalarPtr(ptrType, context), + IndexType::get(context)}; + return success(); + }); + + // Hooks to compute the correct materialization, "argument" and "source" + // materialization are used when we need to convert a pair of {memref, + // index} type back to the original triton pointer type. + // These are used when there are ops that still need to use the original + // pointer type. For instance, we convert the result of tt.addptr from + // tt.ptr type to a pair of {memref, index}, but the original ptr result is + // still being used by another tt.load or tt.store. + converter.addArgumentMaterialization(buildCastOp); + converter.addSourceMaterialization(buildCastOp); + + // Compute the target materialization, given a value with the pointer type, + // convert that value to a pair of {memref, index} type. 
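+    //
+    // Roughly, a value of type !tt.ptr<f32> maps to the pair
+    //   (memref<1xf32, strided<[1], offset: ?>>, index)
+    // where the index carries the accumulated offset (f32 is illustrative).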
+#if 0 // FIXME: Incompatible MILR interface + converter.addTargetMaterialization(buildCastAndOffsetOps); +#endif + + patterns.add(converter, context); + + scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); + + if (failed(applyPartialOneToNConversion(getOperation(), converter, + std::move(patterns)))) { + return failure(); + } + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + return failure(); + } + + return success(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + if (failed(convertArgsToMemrefType())) { + signalPassFailure(); + return; + } + + if (failed(convertAddPtrToReinterpretCast())) { + signalPassFailure(); + return; + } + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, ttx::TritonTilingExtDialect, + memref::MemRefDialect>(); + + target.addIllegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + auto resType = op->getResultTypes()[0]; + return !isa(resType); + }); + + LoopTypeConverter loopTypeConverter(patterns.getContext()); + + mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + loopTypeConverter, patterns, target); + + triton::populateStructuredToMemrefConversionPatterns(patterns, + loopTypeConverter); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + + // Erase dead code and fold constants created during lowering + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +triton::createStructuredToMemrefPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/CMakeLists.txt new file mode 100644 index 000000000..f20b2102c --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(TritonArithToLinalg + TritonArithToLinalg.cpp + TritonArithToLinalgPass.cpp + + DEPENDS + TritonArithToLinalgConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRLinalgTransforms + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms + TritonTilingExtIR + TritonStructuredIR +) diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp new file mode 100644 index 000000000..637732fc7 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" + +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include +#include + +#define DEBUG_TYPE "triton-arith-to-linalg" +#include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc" + +void mlir::triton::populateTritonArithToLinalgCanonicalizationPatterns( + RewritePatternSet &patterns) { + patterns.add, MinMaxConverter>( + patterns.getContext()); +} + +void mlir::triton::populateTritonArithToLinalgConversionPatterns( + bool pidsToFuncArgs, bool addptrToLinalg, bool assertToCf, + RewritePatternSet &patterns) { + + if (pidsToFuncArgs) { + // Need use tx interface to get pid. + patterns.add( + patterns.getContext()); + } + if (addptrToLinalg) { + patterns.add(patterns.getContext()); + } + if (assertToCf) { + patterns.add(patterns.getContext()); + } + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + populateExternElementwiseOpToMLIROps(patterns); + + // Reduce converters + // Triton's reduce op is idential to linalg.reduce op, so we can clone + // `tt.reduce` body to `linalg.reduce`. Unfortunately, we still need to + // perform pattern matching to know what reduce ops we are dealing with + // so that we know how to initialize the initial reduce values correctly. + // + // We can do this in a generic way without pattern matching by always using + // the first elements along the reduction axis and perform the reduction on + // the remaining elements. However, this results in creatings sub-tensors that + // aren't always multiple of 2s, which are sub-optimal for certain hardwares. + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + // Note: the ordering here matters! + // These patterns are added last to they will be tried last. 
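+  //
+  // For example, an elementwise arith.addf on tensors that none of the
+  // converters above matched is rewritten into a linalg.generic over its
+  // operands by the patterns added below.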
+ linalg::populateElementwiseToLinalgConversionPatterns(patterns); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp new file mode 100644 index 000000000..bae1bd6ba --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp @@ -0,0 +1,228 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-arith-to-linalg" + +using namespace mlir; +using namespace triton; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_TRITONARITHTOLINALG +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class TritonArithToLinalgPass + : public triton::impl::TritonArithToLinalgBase { + using TritonArithToLinalgBase< + TritonArithToLinalgPass>::TritonArithToLinalgBase; + + static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + + // Add additional I32 arguments to represent: + // - num_programs, 3 in total, one for each axis of the launch grid + // - program_id, 3 in total, one for each axis of the launch grid + static void addProgramInfo(triton::FuncOp func) { + OpBuilder b(func); + + auto origFuncType = func.getFunctionType(); + auto origInputTypes = origFuncType.getInputs(); + SmallVector newInputTypes(origInputTypes); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); + + auto newFuncType = + b.getFunctionType(newInputTypes, origFuncType.getResults()); + + func.setFunctionType(newFuncType); + + // Add empty attributes for each new argument if needed + if (func.getAllArgAttrs()) { + SmallVector newArgAttrs; + func.getAllArgAttrs(newArgAttrs); + newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); + func.setAllArgAttrs(newArgAttrs); + } + + // Add the corresponding arguments to function body + for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { + func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); + } + } + + LogicalResult applyTensorConcatDecomposition() { + auto moduleOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + tensor::populateDecomposeTensorConcatPatterns(patterns); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + return failure(); + } + return success(); + } + +public: 
+ void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + { + RewritePatternSet patterns(&getContext()); + populateTritonArithToLinalgCanonicalizationPatterns(patterns); + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, ttx::TritonTilingExtDialect, + tts::TritonStructuredDialect, mk::MagicKernelDialect>(); + + target.addLegalOp(); + + target.addLegalOp(); + + target.addDynamicallyLegalDialect( + [](Operation *op) { + // Lower dense constant to linalg.fill + if (auto constOp = dyn_cast(op)) { + if (!isa(constOp.getResult().getType())) { + return true; + } + + if (auto denseAttr = + dyn_cast(constOp.getValue())) { + if (denseAttr.isSplat() && + isa(denseAttr.getElementType())) { + return false; + } + } + return true; + } + + bool operateOnTensors = + llvm::all_of(op->getOperandTypes(), [](Type type) { + return isa(type); + }); + + return !operateOnTensors; + }); + + if (pidsToFuncArgs) { + target + .addIllegalOp(); + } + + if (addptrToLinalg) { + target.addDynamicallyLegalOp([](triton::AddPtrOp op) { + return !isa(op.getResult().getType()); + }); + } + + if (!assertToCf) { + target.addLegalOp(); + } + + triton::populateTritonArithToLinalgConversionPatterns( + pidsToFuncArgs, addptrToLinalg, assertToCf, patterns); + + if (pidsToFuncArgs) { + for (auto func : getOperation().getOps()) { + addProgramInfo(func); + } + } + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + + if (failed(applyTensorConcatDecomposition())) { + signalPassFailure(); + } + + // Convert tt.func and tt.return into func's counterparts + if (ttToFuncFunc) { + moduleOp.walk([&](triton::FuncOp func) { + OpBuilder builder(func); + + auto name = func.getName(); + auto type = func.getFunctionType(); + + SmallVector argAttrs, resAttrs; + func.getAllArgAttrs(argAttrs); + func.getAllResultAttrs(resAttrs); + + auto funcFunc = builder.create(func.getLoc(), name, type); + funcFunc.setAllArgAttrs(argAttrs); + funcFunc.setAllResultAttrs(resAttrs); + + auto &funcFuncBody = funcFunc.getBody(); + auto &funcBody = func.getBody(); + + IRMapping map; + funcBody.cloneInto(&funcFuncBody, map); + + for (Block &block : funcFuncBody.getBlocks()) { + auto term = block.getTerminator(); + // Only convert to func.return if the terminator is a tt.return. + // Otherwise, we will accidentally convert cf.br ops which are also + // considered terminators. 
+ if (isa(term)) { + builder.setInsertionPoint(term); + builder.create(func.getLoc(), term->getOperands()); + term->erase(); + } + } + func.erase(); + }); + } + } +}; + +} // namespace + +std::unique_ptr> +triton::createTritonArithToLinalgPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt new file mode 100644 index 000000000..d703f9150 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt @@ -0,0 +1,31 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +add_triton_library(TritonToCoreDialects + TritonToCoreDialectsPass.cpp + + DEPENDS + TritonToCoreDialectsConversionPassIncGen + + LINK_LIBS PUBLIC + TritonTilingExtIR + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonAnalysis + TritonIR + TritonTransforms + ZTCAnalysis + + TritonArithToLinalg + StructuredToMemref + TritonToStructured +) diff --git a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp new file mode 100644 index 000000000..0f4b8dfc6 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" +#include "triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h" +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h.inc" + +namespace { + +class TritonToCoreDialectsPass + : public TritonToCoreDialectsBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createTritonToStructuredPass()); + + // Erase dead code and fold constants created during lowering + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + pm.addPass(createTritonArithToLinalgPass()); + pm.addPass(createStructuredToMemrefPass()); + + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + pm.addPass(createCSEPass()); + 
pm.addPass(createCanonicalizerPass()); + + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +triton::createTritonToCoreDialectsPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/CMakeLists.txt new file mode 100644 index 000000000..e9ebf49c3 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/CMakeLists.txt @@ -0,0 +1,28 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +add_triton_library(TritonToLinalg + TritonToLinalg.cpp + TritonToLinalgPass.cpp + + DEPENDS + TritonToLinalgConversionPassIncGen + + LINK_LIBS PUBLIC + TritonTilingExtIR + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonAnalysis + TritonIR + TritonTransforms + ZTCAnalysis +) diff --git a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp new file mode 100644 index 000000000..ea6d32593 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" + +#define DEBUG_TYPE "triton-to-linalg" +#include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" + +void mlir::triton::populateTritonToLinalgCanonicalizationPatterns( + RewritePatternSet &patterns) { + patterns.add, MinMaxConverter>( + patterns.getContext()); +} + +void mlir::triton::populateTritonToLinalgConversionPatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns, + unsigned int launchGridRank) { + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + 
patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + populateExternElementwiseOpToMLIROps(patterns); + + // Reduce converters + // Triton's reduce op is idential to linalg.reduce op, so we can clone + // `tt.reduce` body to `linalg.reduce`. Unfortunately, we still need to + // perform pattern matching to know what reduce ops we are dealing with + // so that we know how to initialize the initial reduce values correctly. + // + // We can do this in a generic way without pattern matching by always using + // the first elements along the reduction axis and perform the reduction on + // the remaining elements. However, this results in creatings sub-tensors that + // aren't always multiple of 2s, which are sub-optimal for certain hardwares. + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + // Note: the ordering here matters! + // MetaOpConverter has PatternBenefit == 10 which should take precedence over + // these linalg patterns, but to be safe, add these patterns last so that they + // will be tried last. Incorrect ordering or having MetaOpConverter has lower + // PatternBenefit will result in element-wise meta ops being converted to + // linalg.generic ops. + linalg::populateElementwiseToLinalgConversionPatterns(patterns); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp new file mode 100644 index 000000000..25b7db85f --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp @@ -0,0 +1,229 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Analysis/UseAnalysis.h" +#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-to-linalg" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" + +namespace { + +class TritonTypeConverter : public TypeConverter { +public: + TritonTypeConverter() { + // The order of type conversion is important: later ones are tried earlier. 
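A note on the comment above: MLIR's `TypeConverter` tries `addConversion` callbacks from the most recently registered back to the first, so the identity conversion registered first acts as a catch-all fallback while the pointer and tensor conversions registered afterwards are attempted first. A minimal standalone sketch of that lookup order (a toy model for illustration, not the MLIR API):

```cpp
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Toy model of "later addConversion callbacks are tried first"; it only
// illustrates the ordering rule, not MLIR's actual TypeConverter.
struct ToyTypeConverter {
  using Rule = std::function<std::optional<std::string>(const std::string &)>;
  std::vector<Rule> rules;

  void addConversion(Rule r) { rules.push_back(std::move(r)); }

  std::string convert(const std::string &type) const {
    // Walk rules from the most recently added back to the first one.
    for (auto it = rules.rbegin(); it != rules.rend(); ++it)
      if (auto converted = (*it)(type))
        return *converted;
    return type;
  }
};

int main() {
  ToyTypeConverter conv;
  // Registered first: identity fallback that keeps the type unchanged.
  conv.addConversion(
      [](const std::string &t) { return std::optional<std::string>(t); });
  // Registered later, so tried earlier: rewrite pointer-like types.
  conv.addConversion([](const std::string &t) -> std::optional<std::string> {
    if (t == "!tt.ptr<f32>")
      return std::string("memref<*xf32>");
    return std::nullopt; // defer to the earlier (identity) rule
  });
  // "!tt.ptr<f32>" -> "memref<*xf32>", while "i32" stays "i32".
  return conv.convert("!tt.ptr<f32>") == "memref<*xf32>" ? 0 : 1;
}
```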
+ addConversion([](Type type) { return type; }); + addConversion([](triton::PointerType ptrType) { + return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); + }); + addConversion([](TensorType tensorType) -> Type { + auto elemType = tensorType.getElementType(); + if (auto ptrType = dyn_cast(elemType)) { + elemType = ptrType.getPointeeType(); + } + return MemRefType::get(tensorType.getShape(), elemType); + }); + } +}; + +class TritonToLinalgPass : public TritonToLinalgBase { + + static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + + // Add additional I32 arguments to represent: + // - num_programs, 3 in total, one for each axis of the launch grid + // - program_id, 3 in total, one for each axis of the launch grid + static void addProgramInfo(triton::FuncOp func) { + OpBuilder b(func); + + auto origFuncType = func.getFunctionType(); + auto origInputTypes = origFuncType.getInputs(); + SmallVector newInputTypes(origInputTypes); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); + + auto newFuncType = + b.getFunctionType(newInputTypes, origFuncType.getResults()); + + func.setFunctionType(newFuncType); + + // Add empty attributes for each new argument if needed + if (func.getAllArgAttrs()) { + SmallVector newArgAttrs; + func.getAllArgAttrs(newArgAttrs); + newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); + func.setAllArgAttrs(newArgAttrs); + } + + // Add the corresponding arguments to function body + for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { + func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); + } + } + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + { + RewritePatternSet patterns(&getContext()); + populateTritonToLinalgCanonicalizationPatterns(patterns); + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } + + moduleOp.walk([this](triton::FuncOp op) { + if (failed(runUseAnalysis(op))) { + signalPassFailure(); + } + }); + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + TritonTypeConverter tritonTypeConverter; + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, memref::MemRefDialect, + ttx::TritonTilingExtDialect>(); + + target.addLegalOp(); + + // Update function signature to use memrefs + target.addDynamicallyLegalOp([&](triton::FuncOp op) { + return tritonTypeConverter.isSignatureLegal(op.getFunctionType()); + }); + + // Lower dense constant to linalg.fill + target.addDynamicallyLegalOp([](arith::ConstantOp op) { + if (!isa(op.getResult().getType())) { + return true; + } + + if (auto denseAttr = dyn_cast(op.getValue())) { + if (denseAttr.isSplat() && + isa(denseAttr.getElementType())) { + return false; + } + } + return true; + }); + + target.addDynamicallyLegalOp([](Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type t) { + if (isa(t)) { + return false; + } + if (auto shapedType = dyn_cast(t)) { + return shapedType.getElementType().isIntOrFloat(); + } + assert(t.isIntOrIndexOrFloat()); + return true; + }); + }); + + target.addDynamicallyLegalDialect( + [](Operation *op) { + if 
(op->hasAttr("MetaUse")) { + return false; + } + + if (isa(op)) { + return true; + } + + bool operateOnTensors = + llvm::all_of(op->getOperandTypes(), [](Type type) { + return isa(type); + }); + + return !operateOnTensors; + }); + + triton::populateTritonToLinalgConversionPatterns( + tritonTypeConverter, patterns, LAUNCH_GRID_RANK); + + for (auto func : getOperation().getOps()) + addProgramInfo(func); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) + signalPassFailure(); + + // Convert tt.func and tt.return into func's counterparts + moduleOp.walk([&](triton::FuncOp func) { + OpBuilder builder(func); + + auto name = func.getName(); + auto type = func.getFunctionType(); + + SmallVector argAttrs, resAttrs; + func.getAllArgAttrs(argAttrs); + func.getAllResultAttrs(resAttrs); + + auto funcFunc = builder.create(func.getLoc(), name, type); + funcFunc.setAllArgAttrs(argAttrs); + funcFunc.setAllResultAttrs(resAttrs); + + auto &funcFuncBody = funcFunc.getBody(); + auto &funcBody = func.getBody(); + + IRMapping map; + funcBody.cloneInto(&funcFuncBody, map); + + for (Block &block : funcFuncBody.getBlocks()) { + auto term = block.getTerminator(); + builder.setInsertionPoint(term); + builder.create(func.getLoc(), term->getOperands()); + term->erase(); + } + func.erase(); + }); + + // Erase dead code and fold constants created during lowering + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> triton::createTritonToLinalgPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToStructured/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonToStructured/CMakeLists.txt new file mode 100644 index 000000000..743d09138 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToStructured/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(TritonToStructured + TritonToStructuredPass.cpp + + DEPENDS + TritonStructuredTableGen + TritonToStructuredConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + MLIRReconcileUnrealizedCasts + TritonIR + TritonTransforms + ZTCAnalysisStructured + TritonStructuredIR +) diff --git a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp new file mode 100644 index 000000000..71d694290 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -0,0 +1,344 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/AnalysisStructured/PtrAnalysis.h" +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/OneToNTypeConversion.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include + +#define DEBUG_TYPE "triton-to-structured" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc" + +namespace { + +class TritonToStructuredPass + : public TritonToStructuredBase { + + static TupleType getStructuredStateTupleType(MLIRContext *context, Type t) { + SmallVector tupleTypes{t}; + auto [offsetTypes, strideTypes] = + *tts::GetStructuredStateOp::getOffsetAndStrideTypes(context, t); + tupleTypes.append(offsetTypes); + tupleTypes.append(strideTypes); + return TupleType::get(context, tupleTypes); + } + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + LogicalResult convertToPointerTupleWithOffsetsAndStrides() { + auto moduleOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + + auto context = &getContext(); + TypeConverter converter; + converter.addConversion([](Type type) { return type; }); + + // We are doing a 1->1 type conversion here, where a triton pointer type + // maps to a tuple of {pointer, offset_0, offset_1,..., stride_0, + // stride_1,...} type. + // + // Case 1: Unstructured pointers (tensor>) + converter.addConversion([context](RankedTensorType tensorType, + SmallVectorImpl &types) + -> std::optional { + // Important note: + // We only care about tensor of index / int (in addition to pointer type) + // because only values of int and index type can potentially be part of a + // pointer arithmetic sequence. + if (!isa(tensorType.getElementType()) && + !tensorType.getElementType().isIntOrIndex()) { + // There's a subtle difference between returning failure() and + // std::nullopt. From the documentation: + // + // If std::nullopt is returned, the converter is allowed to try another + // conversion function to perform the conversion. + // + // Say we have type tensor<4x256xbf16> which is a RankedTensorType. Even + // though this RankedTensorType matches the converter that handles the + // tuple conversion, we want to keep this type as is because the inner + // type isn't a pointer. + // + // By returning failure(), the TypeConverters will stop trying the + // remaining converters. 
In our case, the last type converter which + // simply returns the same type is skipped. And because the conversion + // for this type has failed, the whole conversion process is also + // skipped. + // + // Relevant links to the implementation: + // + // https://github.com/llvm/llvm-project/blob/cb5dc1faa8b3702e0d03426ee5dfc5e1b903ec47/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2958 + // https://github.com/llvm/llvm-project/blob/cb5dc1faa8b3702e0d03426ee5dfc5e1b903ec47/mlir/lib/Transforms/Utils/DialectConversion.cpp#L3033 + return std::nullopt; + } + types = + SmallVector{getStructuredStateTupleType(context, tensorType)}; + return success(); + }); + + // Case 2: Block pointers (!tt.ptr> or !tt.ptr) + converter.addConversion([context](triton::PointerType ptrType, + SmallVectorImpl &types) + -> std::optional { + types = SmallVector{getStructuredStateTupleType(context, ptrType)}; + return success(); + }); + + // Hooks to compute the correct materialization, "argument" and "source" + // materialization are used when we need to convert the tuple type back to + // the original triton pointer type. These are used when there are ops that + // still need to use the original pointer type. For instance, we convert the + // result of tt.addptr from tt.ptr type to a tuple, but the original ptr + // result is still being used by another tt.load or tt.store. + auto materialize = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return builder.create(loc, resultType, inputs) + .getResult(0); + }; + + converter.addArgumentMaterialization(materialize); + converter.addSourceMaterialization(materialize); + + // Compute the target materialization, given a value with the pointer type, + // convert that value to a tuple type. +#if 0 // FIXME: Incompatible MILR interface + converter.addTargetMaterialization( + [](OpBuilder &builder, TypeRange resultTypes, Value input, + Location loc) -> std::optional> { + return builder + .create(loc, resultTypes, input) + ->getResults(); + }); +#endif + + scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); + + if (failed(applyPartialOneToNConversion(getOperation(), converter, + std::move(patterns)))) { + return failure(); + } + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + return failure(); + } + + return success(); + } + + LogicalResult decomposePointerTuple() { + auto moduleOp = getOperation(); + + auto context = &getContext(); + TypeConverter converter; + converter.addConversion([](Type type) { return type; }); + + // We are doing a 1->N type conversion here, where a pointer tuple type + // maps to a sequence of {pointer, offset_0, offset_1,..., stride_0, + // stride_1,...} + converter.addConversion( + [context](TupleType tupleType, SmallVectorImpl &types) + -> std::optional { + tupleType.getFlattenedTypes(types); + return success(); + }); + + // Hooks to compute the correct materialization, "argument" and "source" + // materialization are used when we need to convert a series of {pointer, + // offset_0, offset_1,..., stride_0, stride_1,...} type back to the "pointer + // tuple type". + // + // Because we actually want to get rid of the tuple type, return `inputs[0]` + // which corresponds to a "triton pointer type". 
This approach will work as + // intended because the ops that currently take "pointer tuple type" are + // `unrealized_conversion_cast` ops which will get removed below during + // reconcile-unrealized-conversion-casts. + auto materialize = [](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) { return inputs[0]; }; + converter.addArgumentMaterialization(materialize); + converter.addSourceMaterialization(materialize); + + // For each value of "pointer tuple type" that gets decomposed into a + // sequence of {pointer, offset_0, offset_1,..., stride_0, stride_1,...}, + // create a `tts.get_structured_state` op that serves as a placeholder. + // The return values for this op will be used as the init-args for scf.for. + // At the end of pointer analysis, we will use the PtrState to create the + // correct offsets, strides, and remove these ops. +#if 0 // FIXME: Incompatible MILR interface + converter.addTargetMaterialization([](OpBuilder &builder, + TypeRange resultTypes, Value input, + Location loc) { + auto placeholder = builder.create( + loc, input.getDefiningOp()->getOperand(0)); + assert(llvm::equal(placeholder.getResultTypes(), resultTypes)); + return placeholder.getResults(); + }); +#endif + + RewritePatternSet patterns(&getContext()); + scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); + if (failed(applyPartialOneToNConversion(getOperation(), converter, + std::move(patterns)))) { + return failure(); + } + + // Note: + // Be careful not to run canonicalization here, because the + // tts.get_structured_state ops created above are just placeholders and + // don't have any effects. Canonicalization will remove them altogether. + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + + return success(); + } + + // Prepass that inserts `tts.get_structured_state` ops. These ops are used as + // placeholders to make passing structured pointer state into scf.for loop's + // init args easier, especially with multiple levels of loops. + // + // Background: + // + // PtrAnalysis computes a PtrState for every operand (or triton value) + // involved in a sequence of pointer arithmetic; some examples include: triton + // pointer, offsets (which could be a tensor of indices or just a simple index + // value). + // + // If a triton value is updated and returned in a scf.for op, it means + // that we have to carry its offsets and strides in the scf.for's iterargs. + // + // Previously, we have to manually rewrite the loops to include the + // relevant information from a PtrState which was rather involved and + // error-prone; this was also hard to scale up to multiple level of loops + // because there are several book-keeping data structures that we have to + // maintain. + // + // With the introduction of the prepass that inserts + // `tts.get_structured_state`. The return values of these ops, which include a + // triton value with its original result type and its corresponding offsets + // and strides, will be used as "placeholders" into the scf.for's init-args. + // We leverage standard MLIR infrastructure 1->N conversion to perform this + // rewrite, which helps simplify the logic significantly. 
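To see why a value used in pointer arithmetic needs its offsets and strides carried through `scf.for` iter-args, it can help to write the same access as a plain loop: the base pointer alone is not enough, because the running offset is loop-carried state too. A hand-written C++ analogue, purely illustrative and not generated code:

```cpp
#include <cstdint>

// Plain-loop analogue of a strided access inside scf.for: besides the base
// pointer, the running offset is loop-carried state, which is exactly what
// the extra iter-args introduced by the prepass hold.
float sum_strided(const float *base, int64_t stride, int n) {
  int64_t offset = 0; // loop-carried, like an added iter-arg
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc += base[offset];
    offset += stride; // the per-iteration tt.addptr step, made explicit
  }
  return acc;
}
```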
+ // + // After PtrAnalysis finishes, the return values of these + // `tts.get_structured_state` ops will be remapped to the correct + // initialization of the value's offsets and strides through the value's + // computed PtrState. + // + // Implementation details: + // In essence, what we really want to do in the prepass is, for every value + // of triton-pointer-like type (tt.ptr or tensor>) and tensor of + // indices (tensor) which might be used in a sequence of pointer + // arithmetic, we want to create an op `tts.get_structured_state` that takes + // in the original triton value and returns a series of values: + // + // {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...} + // + // Applying the above conversion will also mean that any structural ops such + // as scf.for and scf.yield that originally takes the triton pointer will + // then take {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...}. + // + // The 1->N type conversion is a perfect fit for this transformation. + // Unfortunately, we cannot do this is one pass, because the current 1->N + // type conversion implementation for scf.for ops doesn't provide us with a + // way to detect that a type conversion is recursive. So a triton_value type + // that gets converted to a {triton_value, offset_0, offset_1, ..., stride_0, + // stride_1,...} will recursively trigger other conversions. + // + // To fix this issue, we have to first convert triton_value to + // tuple. + // Finally, we decompose these tuples into the desired sequence. + // + // Note that even though the type conversion happens for every integer tensor + // appearing in loops' iter-args, this conversion is reversible. If the + // integer tensor isn't used in a pointer arithmetic sequence, + // canonicalization will remove all the `tts.get_structured_state` ops and + // revert the IR back to its original form. + LogicalResult runTritonToStructuredPrepass() { + if (failed(convertToPointerTupleWithOffsetsAndStrides())) { + return failure(); + } + + return decomposePointerTuple(); + } + + void runOnOperation() override { + if (!skipPrepass && failed(runTritonToStructuredPrepass())) { + signalPassFailure(); + return; + } + + if (runPrepassOnly) { + return; + } + + auto moduleOp = getOperation(); + mlir::tts::PtrAnalysis ptrAnalysis; + ptrAnalysis.initializeMaybeStructuredArgs(moduleOp); + + if (failed(ptrAnalysis.rewriteOp(moduleOp, useUnsafeMask))) { + moduleOp->emitWarning("PtrAnalysis failed"); + } + + // Now that all the PtrStates have been populated, we can wire up the states + // with the tts.get_structured_state ops inserted in the prepass. 
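Concretely, each placeholder expands a single pointer-typed value into the flat sequence {value, offset_0, ..., stride_0, ...} described above. For a rank-2 access that amounts to five loop-carried values; shown here as a plain struct with made-up field names, only to picture the layout:

```cpp
#include <cstdint>

// Illustration of the flat state that replaces one pointer-typed init-arg
// after the prepass, for a rank-2 access pattern. Field names are invented;
// in the IR these are simply five separate SSA values.
struct StructuredPtrState2D {
  void *ptr;         // the original triton pointer value
  int64_t offset[2]; // offset_0, offset_1
  int64_t stride[2]; // stride_0, stride_1
};
// PtrAnalysis later rewrites each tts.get_structured_state placeholder so
// these components are initialized from the computed PtrState.
```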
+ moduleOp.walk([&ptrAnalysis](tts::GetStructuredStateOp op) { + if (failed(ptrAnalysis.rewriteGetStructuredStateOp(op))) { + op.emitWarning("Rewriting GetStructuredStateOp failed."); + } + }); + } +}; +} // namespace + +std::unique_ptr> +triton::createTritonToStructuredPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt new file mode 100644 index 000000000..9a20a8c10 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_triton_library(Tx81MemrefToLLVM + Tx81MemrefToLLVM.cpp + Tx81MemrefToLLVMPass.cpp + + DEPENDS + Tx81MemrefToLLVMConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp new file mode 100644 index 000000000..a857fbb3e --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp @@ -0,0 +1,349 @@ +//===------------------- Tx81MemrefToLLVM.cpp------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include +#include + +#define DEBUG_TYPE "tx81-memref-to-llvm" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc" + +// Used for allocate spm memory +uint64_t spmPointer = 0x10000; + +namespace { +// Used for kcore load/store data from/to spm +const int64_t spmMappingOffset = 0x30400000; + +//===----------------------------------------------------------------------===// +// Tx81 Custom MemRef Op Conversion Patterns +//===----------------------------------------------------------------------===// + +struct TsmMemRefAllocOpLowering : public AllocLikeOpLLVMLowering { + TsmMemRefAllocOpLowering(const LLVMTypeConverter &converter) + : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), + converter) {} + + std::tuple + allocateBufferFromSPM(ConversionPatternRewriter &rewriter, Location loc, + Operation *op) const { + // create GEPOp for spm address. 
+ MemRefType memRefType = getMemRefResultType(op); + Value spmOffsetOp = rewriter.create( + loc, getIndexType(), rewriter.getI32IntegerAttr(spmPointer)); + Type elementType = typeConverter->convertType(memRefType.getElementType()); + auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + Value spmAddr = rewriter.create(loc, elementPtrType); + + spmAddr = rewriter.create(op->getLoc(), + rewriter.getI64Type(), spmAddr); + spmAddr = rewriter.create(op->getLoc(), rewriter.getI64Type(), + spmAddr, spmOffsetOp); + + spmAddr = rewriter.create(op->getLoc(), elementPtrType, + spmAddr); + Value allocatedPtr = spmAddr; + if (!allocatedPtr) + return std::make_tuple(Value(), Value()); + Value alignedPtr = allocatedPtr; + + // update spm pointer + auto elemCount = memRefType.getNumElements(); + auto bitWidth = memRefType.getElementTypeBitWidth(); + auto allocOp = dyn_cast(op); + if (allocOp.getAlignment().has_value()) + bitWidth = allocOp.getAlignment().value(); + uint64_t totalByte = (elemCount * bitWidth + 7) / 8; + spmPointer += totalByte; + + return std::make_tuple(allocatedPtr, alignedPtr); + } + + std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, + Location loc, Value sizeBytes, + Operation *op) const override { + return allocateBufferFromSPM(rewriter, loc, op); + } +}; + +template +struct MemrefLoadOrStoreOpLowering : public ConvertOpToLLVMPattern { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using OpAdaptor = typename MemrefOp::Adaptor; + + LogicalResult + matchAndRewrite(MemrefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = op.getMemRefType(); + + Value dataPtr = ConvertToLLVMPattern::getStridedElementPtr( + op.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(), rewriter); + + // TODO: Add spm offset according the memory space + MemRefDescriptor memRefDescriptor(adaptor.getMemref()); + auto intPtrType = ConvertToLLVMPattern::getIntPtrType( + memRefDescriptor.getElementPtrType().getAddressSpace()); + Value ptrValue = + rewriter.create(op.getLoc(), intPtrType, dataPtr); + + // Workaround: Should add memory space analysis pass. 
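The allocation lowering above is a simple bump allocator over the on-chip SPM: each `memref.alloc` becomes base `0x10000` plus the running `spmPointer`, which is then advanced by the allocation's byte size (when an alignment is specified on the alloc, the code substitutes it for the element bit width in that computation). Loads and stores tagged `isSpm` additionally rebase the address into the kcore window at `0x30400000`. A standalone sketch of the same bookkeeping, for illustration only:

```cpp
#include <cstdint>

// Standalone sketch of the SPM bookkeeping used by the patterns above
// (illustrative only; in the pass these values end up in the lowered LLVM IR).
namespace {
uint64_t spmPointer = 0x10000;                   // bump-allocator cursor
constexpr int64_t spmMappingOffset = 0x30400000; // kcore view of the SPM

// memref.alloc: reserve elemCount elements of bitWidth bits, rounded up to
// whole bytes, and advance the cursor. Returns the SPM-relative address.
uint64_t spmAlloc(uint64_t elemCount, uint64_t bitWidth) {
  uint64_t addr = spmPointer;
  spmPointer += (elemCount * bitWidth + 7) / 8;
  return addr;
}

// Loads/stores carrying the "isSpm" attribute rebase the address into the
// kcore's mapped window before emitting llvm.load / llvm.store.
uint64_t kcoreAddress(uint64_t spmAddr) { return spmAddr + spmMappingOffset; }
} // namespace

// Example: a 1024 x f32 buffer starts at 0x10000 and moves the cursor to
// 0x11000; the final spmPointer is what Tx81MemrefToLLVMPass later records
// as the module attribute "triton_tsm.spm_use".
```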
+ Operation *opBase = op; + if (!opBase->hasAttr("isSpm")) { + return rewriter.notifyMatchFailure( + op, "Load/Store should have isSpm attribute."); + } + int isSpm = + cast(opBase->getAttr("isSpm")).getValue().getSExtValue(); + + Value adjustedPtr = dataPtr; + if (isSpm) { + auto spmMemoryOffset = rewriter.create( + op.getLoc(), rewriter.getI64Type(), + rewriter.getI64IntegerAttr(spmMappingOffset)); + auto spmMemoryAddr = rewriter.create( + op.getLoc(), rewriter.getI64Type(), + SmallVector({ptrValue, spmMemoryOffset})); + + auto ptrTy = LLVM::LLVMPointerType::get( + rewriter.getContext(), + *ConvertToLLVMPattern::getTypeConverter()->getMemRefAddressSpace( + type)); + auto spmMemoryAddrPtr = + rewriter.create(op.getLoc(), ptrTy, spmMemoryAddr); + + adjustedPtr = spmMemoryAddrPtr; + } + + // Wether need memoryspace cast + if constexpr (std::is_same()) { + + rewriter.replaceOpWithNewOp(op, op.getType(), adjustedPtr, + 0, false, op.getNontemporal()); + } else { + rewriter.replaceOpWithNewOp( + op, adaptor.getValue(), adjustedPtr, 0, false, op.getNontemporal()); + } + + return success(); + } +}; + +struct MemRefReinterpretCastOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + memref::ReinterpretCastOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = castOp.getSource().getType(); + + Value descriptor; + if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, + adaptor, &descriptor))) + return failure(); + rewriter.replaceOp(castOp, {descriptor}); + return success(); + } + +private: + /// Extracts allocated, aligned pointers and offset from a ranked or unranked + /// memref type. In unranked case, the fields are extracted from the + /// underlying ranked descriptor. + void extractPointersAndOffset(Location loc, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &typeConverter, + Value originalOperand, Value convertedOperand, + Value *allocatedPtr, Value *alignedPtr, + Value *offset = nullptr) const { + Type operandType = originalOperand.getType(); + if (isa(operandType)) { + MemRefDescriptor desc(convertedOperand); + *allocatedPtr = desc.allocatedPtr(rewriter, loc); + *alignedPtr = desc.alignedPtr(rewriter, loc); + if (offset != nullptr) + *offset = desc.offset(rewriter, loc); + return; + } + + // These will all cause assert()s on unconvertible types. + unsigned memorySpace = *typeConverter.getMemRefAddressSpace( + cast(operandType)); + auto elementPtrType = + LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace); + + // Extract pointer to the underlying ranked memref descriptor and cast it to + // ElemType**. + UnrankedMemRefDescriptor unrankedDesc(convertedOperand); + + // FIXME: workaround, take memRefDescPtr as naked ptr. 
+ Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); + *allocatedPtr = underlyingDescPtr; + *alignedPtr = underlyingDescPtr; + + if (offset != nullptr) { + *offset = rewriter.create( + loc, getIndexType(), rewriter.getI32IntegerAttr(0)); + } + } + + LogicalResult convertSourceMemRefToDescriptor( + ConversionPatternRewriter &rewriter, Type srcType, + memref::ReinterpretCastOp castOp, + memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { + MemRefType targetMemRefType = + cast(castOp.getResult().getType()); + auto llvmTargetDescriptorTy = dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); + if (!llvmTargetDescriptorTy) + return failure(); + + // Create descriptor. + Location loc = castOp.getLoc(); + auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy); + + // Set allocated and aligned pointers. + Value allocatedPtr, alignedPtr; + extractPointersAndOffset(loc, rewriter, *getTypeConverter(), + castOp.getSource(), adaptor.getSource(), + &allocatedPtr, &alignedPtr); + desc.setAllocatedPtr(rewriter, loc, allocatedPtr); + desc.setAlignedPtr(rewriter, loc, alignedPtr); + + // Set offset. + if (castOp.isDynamicOffset(0)) + desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]); + else + desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); + + // Set sizes and strides. + unsigned dynSizeId = 0; + unsigned dynStrideId = 0; + for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { + if (castOp.isDynamicSize(i)) + desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]); + else + desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); + + if (castOp.isDynamicStride(i)) + desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]); + else + desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); + } + *descriptor = desc; + return success(); + } +}; + +/// Materialize the MemRef descriptor represented by the results of +/// ExtractStridedMetadataOp. +class ExtractStridedMetadataOpLowering + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) + return failure(); + + // Create the descriptor. + MemRefDescriptor sourceMemRef(adaptor.getSource()); + Location loc = extractStridedMetadataOp.getLoc(); + Value source = extractStridedMetadataOp.getSource(); + + auto sourceMemRefType = cast(source.getType()); + int64_t rank = sourceMemRefType.getRank(); + SmallVector results; + results.reserve(2 + rank * 2); + + // Base buffer. + Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc); + Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc); + MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape( + rewriter, loc, *getTypeConverter(), + cast(extractStridedMetadataOp.getBaseBuffer().getType()), + baseBuffer, alignedBuffer); + results.push_back((Value)dstMemRef); + + // Offset. + results.push_back(sourceMemRef.offset(rewriter, loc)); + + // Sizes. + for (unsigned i = 0; i < rank; ++i) + results.push_back(sourceMemRef.size(rewriter, loc, i)); + // Strides. 
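Both the reinterpret-cast and the extract-strided-metadata lowerings in this file populate or unpack the standard MLIR memref descriptor; the strides loop just below fills its last array. For a rank-2 `memref<?x?xf32>` the descriptor carries the fields sketched here as a plain struct, purely for orientation:

```cpp
#include <cstdint>

// Plain-struct picture of the standard MLIR memref descriptor for a rank-2
// memref<?x?xf32>. ReinterpretCast fills these fields, and
// ExtractStridedMetadata hands them back in the order base buffer, offset,
// sizes..., strides...
struct MemRefDescriptor2D {
  float *allocated;   // pointer returned by the allocator
  float *aligned;     // aligned pointer actually used for addressing
  int64_t offset;     // element offset applied to the aligned pointer
  int64_t sizes[2];   // extent of each dimension
  int64_t strides[2]; // elements to skip per step in each dimension
};
```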
+ for (unsigned i = 0; i < rank; ++i) + results.push_back(sourceMemRef.stride(rewriter, loc, i)); + + rewriter.replaceOp(extractStridedMetadataOp, results); + return success(); + } +}; + +/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index. +class ConvertExtractAlignedPointerAsIndex + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + BaseMemRefType sourceTy = extractOp.getSource().getType(); + + // FIXME: We want allocated ptr instead of aligned ptr. + Value alignedPtr; + if (sourceTy.hasRank()) { + MemRefDescriptor desc(adaptor.getSource()); + alignedPtr = desc.allocatedPtr(rewriter, extractOp->getLoc()); + } else { + auto elementPtrTy = LLVM::LLVMPointerType::get( + rewriter.getContext(), sourceTy.getMemorySpaceAsInt()); + + UnrankedMemRefDescriptor desc(adaptor.getSource()); + Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc()); + + alignedPtr = UnrankedMemRefDescriptor::allocatedPtr( + rewriter, extractOp->getLoc(), descPtr, elementPtrTy); + } + + rewriter.replaceOpWithNewOp( + extractOp, getTypeConverter()->getIndexType(), alignedPtr); + return success(); + } +}; + +} // namespace + +void mlir::triton::populateTx81MemrefToLLVMConversionPatterns( + RewritePatternSet &patterns, LLVMTypeConverter &converter) { + // clang-format off + patterns.add, + MemrefLoadOrStoreOpLowering>( + converter); + // clang-format on +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp new file mode 100644 index 000000000..7eaf1da0d --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp @@ -0,0 +1,93 @@ +//===------------------- Tx81MemrefToLLVMPass.cpp--------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "tx81-memref-to-llvm" + +using namespace mlir; + +namespace mlir { +namespace triton { +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class Tx81MemrefToLLVMPass + : public mlir::triton::Tx81MemrefToLLVMBase { + using Tx81MemrefToLLVMBase::Tx81MemrefToLLVMBase; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + target.addIllegalOp(); + + target.addLegalDialect(); + + target.addLegalOp(); + + LowerToLLVMOptions options(context); + options.useBarePtrCallConv = false; + LLVMTypeConverter llvmTypeConverter(context, options); + triton::populateTx81MemrefToLLVMConversionPatterns(patterns, + llvmTypeConverter); + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + + // Record spm usage. + moduleOp->setAttr("triton_tsm.spm_use", + mlir::IntegerAttr::get( + mlir::IntegerType::get(context, 32), spmPointer)); + } +}; + +} // namespace + +std::unique_ptr> triton::createTx81MemrefToLLVMPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt new file mode 100644 index 000000000..441858a94 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt @@ -0,0 +1,25 @@ +add_triton_library(Tx81ToLLVM + Tx81ToLLVM.cpp + KernelArgBufferPass.cpp + + DEPENDS + Tx81ToLLVMConversionPassIncGen + KernelArgBufferPassIncGen + MLIRMemRefToLLVM + + LINK_LIBS PUBLIC + MLIRMemRefToLLVM + MLIRArithDialect + MLIRArithToLLVM + MLIRFuncDialect + MLIRFuncToLLVM + MLIRLLVMDialect + MLIRMemRefDialect + MLIRMemRefToLLVM + MLIRArithToLLVM + MLIRAffineToStandard + MLIRLinalgToStandard + MLIRSCFDialect + MLIRSCFToControlFlow + MLIRTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp new file mode 100644 index 000000000..9458fdc95 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp @@ -0,0 +1,143 @@ +//===- KernelArgBufferPass.cpp - Convert kernel args to single buffer -----===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// +// +// This pass transforms kernel function signatures by converting multiple +// arguments into a single void* buffer containing all the arguments. +// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +namespace mlir { +namespace triton { +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class KernelArgBufferPass + : public mlir::triton::KernelArgBufferPassBase { + using KernelArgBufferPassBase::KernelArgBufferPassBase; + +public: + StringRef getArgument() const final { return "kernel-arg-buffer"; } + StringRef getDescription() const final { + return "Convert kernel arguments to a single buffer argument"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; + +private: + // Insert load op to get real kernel args from new buffered argument + // Side effect: calculate offset and create ops + Value insertKernelArgLoad(OpBuilder &builder, Location loc, Value argsBuffer, + Type argType, int64_t ¤tOffset); +}; + +Value KernelArgBufferPass::insertKernelArgLoad(OpBuilder &builder, Location loc, + Value argsBuffer, Type argType, + int64_t ¤tOffset) { + // Get pointer to the current position in args buffer + auto offsetValue = builder.create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(currentOffset)); + + // NOTE: GEPOp need distinguish the scalar and ptr type. So here ptr + offset + Value elementPtr = + builder.create(loc, builder.getI64Type(), argsBuffer); + elementPtr = builder.create(loc, builder.getI64Type(), + elementPtr, offsetValue); + elementPtr = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), elementPtr); + + // Increment offset. Assume all args are 8 bytes + currentOffset += sizeof(int64_t); + + // Load the real kernel arg value + return builder.create(loc, argType, elementPtr); +} + +void KernelArgBufferPass::runOnOperation() { + ModuleOp module = getOperation(); + OpBuilder builder(module.getContext()); + + // Collect functions to process + SmallVector kernelFuncs; + for (auto func : module.getOps()) { + kernelFuncs.push_back(func); + } + // NOTE: We move this pass before tx81-to-llvm pass. 
+ // So we assume the func op must be only one and must be the triton kernel + assert(kernelFuncs.size() == 1 && "Only one kernel function expected"); + + // Process each kernel function + // TODO: Delete the for loop if the assert is always true for all examples + for (auto func : kernelFuncs) { + // Create new function with bufferized signature + builder.setInsertionPointAfter(func); + // Save the old block arguments + SmallVector blockArguments = + llvm::to_vector<8>(func.getArguments()); + auto numArguments = blockArguments.size(); + + // New bufferized arg type + auto voidPtrType = LLVM::LLVMPointerType::get(builder.getContext()); + + // New bufferized function type + auto newFuncType = LLVM::LLVMFunctionType::get( + func.getFunctionType().getReturnType(), voidPtrType); + func.setFunctionType(newFuncType); + SmallVector newArgAttrs({DictionaryAttr()}); + func.setAllArgAttrs(newArgAttrs); + + // Add the new bufferized argument + Location loc = func.getLoc(); + Block &entryBlock = func.getBlocks().front(); + entryBlock.insertArgument((unsigned)0, voidPtrType, func.getLoc()); + + OpBuilder builder(&entryBlock, entryBlock.begin()); + // Get the bufferized argument + Value argsBuffer = entryBlock.getArgument(0); + + // Offset tracking for buffer access + int64_t currentOffset = 0; + + // Process each original argument + for (auto argIndex : llvm::seq(0, numArguments)) { + auto oldArg = blockArguments[argIndex]; + Type argType = oldArg.getType(); + Value loadedArg = insertKernelArgLoad(builder, func.getLoc(), argsBuffer, + argType, currentOffset); + + if (blockArguments[argIndex].use_empty()) + continue; + oldArg.replaceAllUsesWith(loadedArg); + } + // Remove the old arguments when replace the use-chain + entryBlock.eraseArguments(1, numArguments); + } +} + +} // namespace + +std::unique_ptr triton::createKernelArgBufferPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp new file mode 100644 index 000000000..68c7e75ca --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp @@ -0,0 +1,1631 @@ +//===--------------------- Tx81ToLLVM.cpp ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This file implements the patterns to convert operations from tx dialect to +// LLVM IR dialect. 
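The KernelArgBufferPass shown above rewrites the kernel to take a single `void*` and assumes every original argument occupies one 8-byte slot (`currentOffset += sizeof(int64_t)`). A host-side sketch of the matching buffer layout follows; the names and the packing helper are hypothetical, since the real launcher/runtime API is not part of this diff:

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical host-side packing that matches the layout expected by the
// rewritten kernel: one 8-byte slot per original argument, pointers stored
// as 64-bit addresses and scalars widened into their own slot.
struct PackedKernelArgs {
  uint64_t slots[3]; // e.g. {in_ptr, out_ptr, n} for a 3-argument kernel
};

inline void packArgs(PackedKernelArgs &buf, const float *in, float *out,
                     int32_t n) {
  buf.slots[0] = reinterpret_cast<uint64_t>(in);
  buf.slots[1] = reinterpret_cast<uint64_t>(out);
  uint64_t widened = 0;
  std::memcpy(&widened, &n, sizeof(n)); // i32 widened into an 8-byte slot
  buf.slots[2] = widened;
}
// The kernel entry point then receives &buf as its single pointer argument
// and, mirroring insertKernelArgLoad, reads each value back at byte offsets
// 0, 8, and 16.
```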
+// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "tx81-to-llvm" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc" + +namespace { +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// +// Crt func name +const char addVVFuncName[] = "__AddVV"; +const char subVVFuncName[] = "__SubVV"; +const char mulVVFuncName[] = "__MulVV"; +const char divVVFuncName[] = "__DivVV"; +const char absVVFuncName[] = "__AbsVV"; +const char rsqrtVVFuncName[] = "__RsqrtVV"; +const char sqrtVVFuncName[] = "__SqrtVV"; +const char lnFuncName[] = "__Ln"; +const char log2FuncName[] = "__Log2"; +const char expFuncName[] = "__Exp"; +const char pow2FuncName[] = "__Pow2"; +const char sinFuncName[] = "__Sin"; +const char cosFuncName[] = "__Cos"; +const char addVSFuncName[] = "__AddVS"; +const char subVSFuncName[] = "__SubVS"; +const char mulVSFuncName[] = "__MulVS"; +const char divVSFuncName[] = "__DivVS"; +const char reduceSumFuncName[] = "__ReduceSum"; +const char reduceMaxFuncName[] = "__ReduceMax"; +const char reduceMinFuncName[] = "__ReduceMin"; +const char fp16ToFp32FuncName[] = "__FP16_FP32"; +const char int16ToFp16FuncName[] = "__INT16_FP16"; +const char int16ToFp32FuncName[] = "__INT16_FP32"; +const char int32ToFp16FuncName[] = "__INT32_FP16"; +const char int32ToFp32FuncName[] = "__INT32_FP32"; +const char fp16ToInt8FuncName[] = "__FP16_INT8"; +const char fp16ToInt16FuncName[] = "__FP16_INT16"; +const char fp16ToInt32FuncName[] = "__FP16_INT32"; +const char fp32ToInt8FuncName[] = "__FP32_INT8"; +const char fp32ToInt16FuncName[] = "__FP32_INT16"; +const char fp32ToInt32FuncName[] = "__FP32_INT32"; +const char boolEqualVVFuncName[] = "__BoolEqualVV"; +const char boolUnEqualVVFuncName[] = "__BoolUnEqualVV"; +const char boolGreaterEqualVVFuncName[] = "__BoolGreaterEqualVV"; +const char boolGreaterVVFuncName[] = "__BoolGreaterVV"; +const char boolLessEqualVVFuncName[] = 
"__BoolLessEqualVV"; +const char boolLessVVFuncName[] = "__BoolLessThenVV"; +const char fp32ToFp16FuncName[] = "__FP32_FP16"; +const char fp32ToBf16FuncName[] = "__FP32_BF16"; +const char fp32ToTF32FuncName[] = "__FP32_TF32"; +const char andVVFuncName[] = "__AndVV"; +const char orVVFuncName[] = "__OrVV"; +const char xorVVFuncName[] = "__XorVV"; +const char MaxVVFuncName[] = "__MaxVV"; +const char MinVVFuncName[] = "__MinVV"; + +// Function to declare Tx81 runtime function +Value declareTx81Function(ModuleOp module, OpBuilder &builder, Location loc, + StringRef name, Type resultType, + ArrayRef argumentTypes) { + // Check if the function already exists + Operation *funcOp = module.lookupSymbol(name); + if (funcOp) + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), name); + + // Create function type + Type funcType = LLVM::LLVMFunctionType::get(resultType, argumentTypes, + /*isVarArg=*/false); + + // Create a function declaration + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(module.getBody()); + + builder.create(loc, name, funcType, + LLVM::Linkage::External); + + builder.restoreInsertionPoint(ip); + + // Return function pointer + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), name); +} + +static Value adjustElemCountType(ConversionPatternRewriter &rewriter, + Location loc, Value elemCount) { + Value newElemCount = elemCount; + if (isa(elemCount.getType())) { + newElemCount = rewriter.create( + loc, rewriter.getI32Type(), elemCount); + } else if (isa(elemCount.getType())) { + auto elemCountType = dyn_cast(elemCount.getType()); + if (elemCountType.isInteger(64)) + newElemCount = rewriter.create( + loc, rewriter.getI32Type(), elemCount); + } + return newElemCount; +} + +static Value castIndexToInt32(ConversionPatternRewriter &rewriter, Location loc, + Value indexOp) { + return rewriter.create(loc, rewriter.getI32Type(), + indexOp); +} + +//===----------------------------------------------------------------------===// +// Arith Operation Conversion Patterns +//===----------------------------------------------------------------------===// + +// Convert constant operations to LLVM constants +struct ConstantOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the constant value + auto constAttr = op.getValue(); + + // Get the result type + auto resultType = getTypeConverter()->convertType(op.getResult().getType()); + + // Handle different attribute types + if (auto intAttr = mlir::dyn_cast(constAttr)) { + // Convert integer attribute + rewriter.replaceOpWithNewOp(op, resultType, intAttr); + return success(); + } else if (auto floatAttr = mlir::dyn_cast(constAttr)) { + // Convert float attribute + rewriter.replaceOpWithNewOp(op, resultType, floatAttr); + return success(); + } else if (auto boolAttr = mlir::dyn_cast(constAttr)) { + // Convert bool attribute to i1 + rewriter.replaceOpWithNewOp( + op, resultType, + rewriter.getIntegerAttr(resultType, boolAttr.getValue())); + return success(); + } + + return failure(); + } +}; + +// Convert arith.index_cast to appropriate LLVM conversions +struct IndexCastOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { 
+ // Get source and result types + auto srcType = adaptor.getIn().getType(); + auto dstType = getTypeConverter()->convertType(op.getResult().getType()); + + // Convert from index to specific integer type + if (mlir::isa(srcType) && + mlir::isa(dstType)) { + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getIn()); + return success(); + } + + // Convert from specific integer type to index + if (mlir::isa(srcType) && + mlir::isa(dstType)) { + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getIn()); + return success(); + } + + // Handle integer to integer casts + if (mlir::isa(srcType) && mlir::isa(dstType)) { + unsigned srcWidth = mlir::cast(srcType).getWidth(); + unsigned dstWidth = mlir::cast(dstType).getWidth(); + + if (srcWidth < dstWidth) { + // Sign extend if source is signed, zero extend otherwise + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn()); + } else if (srcWidth > dstWidth) { + // Truncate + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getIn()); + } else { + // Same width, just pass through + rewriter.replaceOp(op, adaptor.getIn()); + } + return success(); + } + + return failure(); + } +}; + +// Convert arith.addi to LLVM add +struct AddIOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +// Convert arith.muli to LLVM mul +struct MulIOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Tx81 Operation Conversion Patterns +//===----------------------------------------------------------------------===// + +// Convert tx81.rdma to LLVM call to crt __Rdma function +struct RdmaOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::RdmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Rdma runtime function if not already declared + /* + void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int + shape_w, int shape_c, int stride_n, int stride_h, int stride_w, int + *strides, uint32_t fmt) + */ + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // src + i8PtrTy, // target + i32Ty, // shape_n + i32Ty, // shape_h + i32Ty, // shape_w + i32Ty, // shape_c + i32Ty, // stride_n + i32Ty, // stride_h + i32Ty, // stride_w + i32Ty // fmt + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Rdma", + i8PtrTy, argTypes); + + // Get the operands + Value src = adaptor.getSource(); + src = rewriter.create(op.getLoc(), i8PtrTy, src); + + // Get the operands + Value target = adaptor.getTarget(); + target = rewriter.create(op.getLoc(), i8PtrTy, target); + + ValueRange shape = 
adaptor.getShape(); + Value shape0 = castIndexToInt32(rewriter, op->getLoc(), shape[0]); + Value shape1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]); + Value shape2 = castIndexToInt32(rewriter, op->getLoc(), shape[2]); + Value shape3 = castIndexToInt32(rewriter, op->getLoc(), shape[3]); + + ValueRange strides = adaptor.getStrides(); + Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[0]); + Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]); + Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[2]); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call to __Rdma + auto call = rewriter.create( + op.getLoc(), TypeRange{i8PtrTy}, "__Rdma", // funcPtr, + ValueRange{src, target, shape0, shape1, shape2, shape3, stride0, + stride1, stride2, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.wdma to LLVM call to __Wdma function +struct WdmaOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::WdmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Wdma runtime function if not already declared + /* + void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int + shape_w, int shape_c, int stride_n, int stride_h, int stride_w, int + *strides, uint32_t fmt) + */ + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // src + i8PtrTy, // target + i32Ty, // shape_n + i32Ty, // shape_h + i32Ty, // shape_w + i32Ty, // shape_c + i32Ty, // stride_n + i32Ty, // stride_h + i32Ty, // stride_w + i32Ty // fmt + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Wdma", + i8PtrTy, argTypes); + + // Get the operands + Value src = adaptor.getSource(); + + // Need to bitcast src to i8* + src = rewriter.create(op.getLoc(), i8PtrTy, src); + + // Get the operands + Value target = adaptor.getTarget(); + + // Need to bitcast src to i8* + target = rewriter.create(op.getLoc(), i8PtrTy, target); + + ValueRange shape = adaptor.getShape(); + Value shape0 = castIndexToInt32(rewriter, op->getLoc(), shape[0]); + Value shape1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]); + Value shape2 = castIndexToInt32(rewriter, op->getLoc(), shape[2]); + Value shape3 = castIndexToInt32(rewriter, op->getLoc(), shape[3]); + + ValueRange strides = adaptor.getStrides(); + Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[0]); + Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]); + Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[2]); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call to __Wdma + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Wdma", // funcPtr, + ArrayRef{src, target, shape0, shape1, shape2, shape3, stride0, + stride1, stride2, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// 
Convert tx81.mask_move to LLVM call to __MaskMove function
+struct MaskMoveOpConversion : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tx::MaskMoveOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get the module for function declarations
+    auto module = op->getParentOfType();
+
+    // Declare the __MaskMove runtime function if not already declared
+    // Signature: void* __MaskMove(void* source, void* target, uint32_t
+    // elem_count, int32_t* masks, uint32_t fmt);
+    auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+    auto i32Ty = rewriter.getI32Type();
+    auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    // Types for function declaration
+    SmallVector argTypes = {
+        i8PtrTy,  // source
+        i8PtrTy,  // target
+        i32Ty,    // elem_count
+        i32PtrTy, // masks
+        i32Ty     // fmt
+    };
+
+    // Declare the function
+    Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(),
+                                        "__MaskMove", i8PtrTy, argTypes);
+
+    // Get the operands
+    Value src = adaptor.getSource();
+
+    // Need to bitcast src to i8*
+    src = rewriter.create(op.getLoc(), i8PtrTy, src);
+
+    Value target = adaptor.getTarget();
+
+    // Need to bitcast target to i8*
+    target = rewriter.create(op.getLoc(), i8PtrTy, target);
+    Value elemCount = adaptor.getElemCount();
+    elemCount = castIndexToInt32(rewriter, op->getLoc(), elemCount);
+
+    // Handle mask arrays
+    // For simplicity, pass a null mask pointer
+    Value nullPtr = rewriter.create(op.getLoc(), i32PtrTy);
+
+    // Handle format attribute
+    Value fmt = rewriter.create(
+        op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt()));
+
+    // Create the call to __MaskMove
+    auto call = rewriter.create(
+        op.getLoc(), i8PtrTy, "__MaskMove", // funcPtr,
+        ArrayRef{src, target, elemCount, nullPtr, fmt});
+
+    // Replace the op with the result of the call
+    rewriter.replaceOp(op, call.getResult());
+
+    return success();
+  }
+};
+
+// Convert tx81 reduce ops to LLVM calls to the corresponding runtime functions
+template
+struct ReduceOpConversion : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+  using OpAdaptor = typename Tx81Op::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(Tx81Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get the module for function declarations
+    auto module = op->template getParentOfType();
+
+    // Declare the runtime function if not already declared
+    // Signature:
+    // __ReduceSum(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n,
+    // uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt)
+    auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+    auto i32Ty = rewriter.getI32Type();
+    auto i16Ty = rewriter.getI16Type();
+
+    // Types for function declaration
+    SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty,
+                            i16Ty,   i16Ty,   i16Ty, i16Ty};
+
+    Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(),
+                                        funcPrefix, i8PtrTy, argTypes);
+
+    // Convert operands
+    Value src = adaptor.getSrc();
+    // Need to bitcast src to i8*
+    src = rewriter.create(op.getLoc(), i8PtrTy, src);
+    Value dst = adaptor.getDst();
+    // Need to bitcast dst to i8*
+    dst = rewriter.create(op.getLoc(), i8PtrTy, dst);
+
+    // Convert dim attribute to Value
+    Value dim = rewriter.create(
+        op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getDim()));
+
+    // Convert shape attribute to Value
+    Value shape_n =
+        rewriter.create(op.getLoc(), i16Ty, op.getShape()[0]);
+    Value
shape_h = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[1]); + Value shape_w = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[2]); + Value shape_c = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[3]); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{src, dst, dim, shape_n, shape_h, shape_w, shape_c, + fmt}); + + // Erase the old op + rewriter.eraseOp(op); + + return success(); + } +}; + +// Convert tx81.elementwise op to LLVM call +template +struct ElementWiseOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + // using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void* __Add(void* a, void* b, void* out, uint32_t elem_count, + // uint32_t rnd_mode, uint32_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i8PtrTy, + + i32Ty, i32Ty, i32Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + Value srcB = adaptor.getInput1(); + // Need to bitcast src to i8* + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle round attribute + Value rnd_mode = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getRndMode())); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, rnd_mode, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +template +struct UnaryOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void* __Abs(void* src, void* dst, uint32_t elem_count, + // uint16_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // 
Convert operands + Value input = adaptor.getInput(); + // Need to bitcast src to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + Value out = adaptor.getOut(); + // Need to bitcast out to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{input, out, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// FIXME: Use trait to refactor the BinaryVSOpConversion and +// ElementWiseOpConversion +template +struct BinaryVSOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void* __Add(void* a, void* b, void* out, uint32_t elem_count, + // uint32_t rnd_mode, uint32_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i32Ty, i8PtrTy, + i32Ty, i32Ty, i32Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + + Value srcB = adaptor.getValue(); + + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle round attribute + Value rnd_mode = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getRndMode())); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, rnd_mode, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +template +struct BinaryLogicVVOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void* __XorVV(void* a, void* b, void* out, uint32_t + // elem_count, uint32_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // src0_addr + i8PtrTy, // src1_addr + 
i8PtrTy, // dst_addr + i32Ty, // elem_count + i32Ty // fmt + }; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + Value srcB = adaptor.getInput1(); + // Need to bitcast src to i8* + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +template +struct BoolRelationVVOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename BoolRelationVVOp::Adaptor; + // using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(BoolRelationVVOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void __BoolLessEqualVV(uint64_t *src0, uint64_t *src1, + // uint64_t *dst, uint32_t elem_count, uint16_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + Value srcB = adaptor.getInput1(); + // Need to bitcast src to i8* + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.NormalConvertOp op to LLVM +template +struct NormalConvertOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename NormalConvertOp::Adaptor; + + LogicalResult + matchAndRewrite(NormalConvertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // 
Signature: void __FP16_FP32(uint64_t *src, uint64_t *dst, uint32_t + // elem_count); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value input = adaptor.getInput(); + Value output = adaptor.getOutput(); + Value elemCount = adaptor.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Bitcast all pointers to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + output = rewriter.create(op.getLoc(), i8PtrTy, output); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{input, output, elemCount}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.RoundConvertOp op to LLVM +template +struct RoundConvertOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename RoundConvertOp::Adaptor; + + LogicalResult + matchAndRewrite(RoundConvertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void __INT16_FP32(uint64_t *src, uint64_t *dst, uint32_t + // elem_count, RND_MODE round); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value input = adaptor.getInput(); + Value output = adaptor.getOutput(); + Value elemCount = adaptor.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + Value rnd_mode = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getRndMode())); + + // Bitcast all pointers to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + output = rewriter.create(op.getLoc(), i8PtrTy, output); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{input, output, elemCount, rnd_mode}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.gemm to LLVM call to __Gemm function +struct GemmOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::GemmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Gemm runtime function if not already declared + // Signature: void __Gemm(int64_t* srcA, int64_t *srcB, int64_t * srcBias, + // int64_t *dst, int32_t *dims, bool enPsum, int64_t *psum, bool enTransA, + // bool enTransB, int64_t batchSizeA, int64_t batchSizeB, bool enLeakyRelu, + // bool enBias,bool enNegScale, int64_t *negScale, bool enPosScale, int64_t + // *posScale, int64_t srcFmt, int64_t dstFmt) + auto i8PtrTy = 
LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i64Ty = rewriter.getI64Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i1Ty = rewriter.getI1Type(); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // srcA + i8PtrTy, // srcB + i8PtrTy, // srcBias + i8PtrTy, // dst + i32PtrTy, // dims + i1Ty, // enPsum + i8PtrTy, // psum + i1Ty, // enTransA + i1Ty, // enTransB + i32Ty, // batchSizeA + i32Ty, // batchSizeB + i32Ty, // reluMode + i1Ty, // enBias + i1Ty, // enNegScale + i8PtrTy, // negScale + i1Ty, // enPosScale + i8PtrTy, // posScale + i32Ty, // srcFmt + i32Ty // dstFmt + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Gemm", + i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getSrcA(); + Value srcB = adaptor.getSrcB(); + Value srcBias = adaptor.getSrcBias(); + Value dst = adaptor.getDst(); + + Value psumAddr = adaptor.getPsumAddr(); + Value srcNegScale = adaptor.getSrcNegScale(); + Value srcPosScale = adaptor.getSrcPosScale(); + + // Bitcast all pointers to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + srcBias = rewriter.create(op.getLoc(), i8PtrTy, srcBias); + dst = rewriter.create(op.getLoc(), i8PtrTy, dst); + psumAddr = + rewriter.create(op.getLoc(), i8PtrTy, psumAddr); + srcNegScale = + rewriter.create(op.getLoc(), i8PtrTy, srcNegScale); + srcPosScale = + rewriter.create(op.getLoc(), i8PtrTy, srcPosScale); + + // Handle dims array - need to convert from attribute to runtime array + auto dimsAttr = op.getDims(); + SmallVector dimsValues; + for (auto dimAttr : dimsAttr) + dimsValues.push_back(mlir::cast(dimAttr).getInt()); + + // Allocate memory for the dims array + Value dimsArraySize = rewriter.create( + op.getLoc(), i64Ty, rewriter.getI64IntegerAttr(dimsValues.size())); + + // Use alloc to allocate memory for dims array + auto dimsArrayI32Ptr = rewriter.create( + op.getLoc(), i32PtrTy, rewriter.getI32Type(), dimsArraySize, + /*alignment=*/0); + + // Store each dimension in the array + for (size_t i = 0; i < dimsValues.size(); i++) { + // Create the index + Value idx = rewriter.create( + op.getLoc(), i64Ty, rewriter.getI32IntegerAttr(i)); + + // Create GEP to get pointer to array element + Value elemPtr = rewriter.create( + op.getLoc(), i64PtrTy, i32Ty, dimsArrayI32Ptr, ArrayRef{idx}); + + // Create the dimension value + Value dimValue = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(dimsValues[i])); + + // Store the value + rewriter.create(op.getLoc(), dimValue, elemPtr); + } + + // Convert boolean attributes + Value transA = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getTransSrcA())); + Value transB = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getTransSrcB())); + Value enPSum = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnPsum())); + Value reluMode = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getReluMode())); + Value enBias = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnBias())); + Value enNegScale = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnNegScale())); + Value enPosScale = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnPosScale())); + + // Convert integer attributes + Value batchA = rewriter.create( 
+ op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getBatchSrcA())); + Value batchB = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getBatchSrcB())); + Value srcFmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getSrcFmt())); + Value dstFmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getDstFmt())); + + // Create the call to __Gemm + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Gemm", // funcPtr, + ArrayRef{srcA, srcB, srcBias, dst, dimsArrayI32Ptr, enPSum, + psumAddr, transA, transB, batchA, batchB, reluMode, + enBias, enNegScale, srcNegScale, enPosScale, + srcPosScale, srcFmt, dstFmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.memset to LLVM call to __Memset function +struct MemsetOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::MemsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Memset runtime function if not already declared + // Signature: void* __Memset(void* dst, int64_t value, uint32_t elem_count, + // int32_t* strides, int32_t* iterations, uint16_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64Ty = rewriter.getI64Type(); + auto i32Ty = rewriter.getI32Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // Spm addr + i32Ty, // value + i32Ty, // shape_n/iterator_2 + i32Ty, // shape_h/iterator_1 + i32Ty, // shape_w/iterator_0 + i32Ty, // shape_c/elem_count + i32Ty, // stride_n + i32Ty, // stride_h + i32Ty, // stride_w, + i16Ty // fmt + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + "__Memset", i8PtrTy, argTypes); + + // Get operands + Value src = adaptor.getSrc(); + src = rewriter.create(op.getLoc(), i8PtrTy, src); + + Value value = adaptor.getValue(); + + // Handle strides and iterations arrays + ValueRange shape = adaptor.getShape(); + Value iteration2 = castIndexToInt32(rewriter, op->getLoc(), shape[0]); + Value iteration1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]); + Value iteration0 = castIndexToInt32(rewriter, op->getLoc(), shape[2]); + Value elemCount = castIndexToInt32(rewriter, op->getLoc(), shape[3]); + + ValueRange strides = adaptor.getStrides(); + Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[0]); + Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]); + Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[2]); + + // Convert fmt attribute to Value + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call to __Memset + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Memset", // funcPtr, + ArrayRef{src, value, elemCount, stride0, iteration0, stride1, + iteration1, stride2, iteration2, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Conversion pattern for linalg.fill operation with tensor arguments +struct LinalgFillOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + 
LinalgFillOpConversion(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(linalg::FillOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The operation should have tensor as output + if (op.getOutputs().size() != 1) { + return rewriter.notifyMatchFailure(op, "expects single output tensor"); + } + + // Check if the output is a tensor type + Value outputTensor = op.getOutputs()[0]; + auto tensorType = mlir::dyn_cast(outputTensor.getType()); + if (!tensorType) { + return rewriter.notifyMatchFailure(op, "expects ranked tensor type"); + } + + // Check for static shape + if (!tensorType.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, + "dynamic shapes not yet supported"); + } + + auto context = rewriter.getContext(); + auto loc = op.getLoc(); + Value value = adaptor.getInputs()[0]; + + // Get the element type + Type elemType = tensorType.getElementType(); + + // Convert the tensor type to the LLVM pointer type + auto llvmPtrType = mlir::dyn_cast( + typeConverter->convertType(tensorType)); + if (!llvmPtrType) { + return rewriter.notifyMatchFailure( + op, "failed to convert tensor type to LLVM pointer type"); + } + + // Calculate total number of elements + int64_t totalElements = 1; + for (int64_t dim : tensorType.getShape()) { + totalElements *= dim; + } + + // Get index type + auto indexType = rewriter.getI64Type(); + + // Implement the following steps: + // 1. Allocate memory for the tensor + // 2. Fill it using memset if applicable + // 3. Return the pointer as the result + + // Calculate element size in bytes + int64_t elemSizeInBytes = 0; + if (auto intType = mlir::dyn_cast(elemType)) { + elemSizeInBytes = + (intType.getWidth() + 7) / 8; // Round up to nearest byte + } else if (auto floatType = mlir::dyn_cast(elemType)) { + elemSizeInBytes = + (floatType.getWidth() + 7) / 8; // Round up to nearest byte + } else { + return rewriter.notifyMatchFailure(op, "unsupported element type"); + } + + // Calculate total size in bytes + auto totalSizeInBytes = totalElements * elemSizeInBytes; + auto totalSizeVal = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(totalSizeInBytes)); + + // Allocate memory + auto mallocFunc = + getOrInsertMalloc(rewriter, op->getParentOfType()); + auto allocated = rewriter.create( + loc, LLVM::LLVMPointerType::get(context), mallocFunc, + ArrayRef{totalSizeVal}); + + auto llvmVoidPtr = LLVM::LLVMPointerType::get(context); + + // Cast the allocated memory to the appropriate pointer type + auto castPtr = rewriter.create(loc, llvmPtrType, + allocated.getResult()); + + // Check if we can use memset for filling + bool useMemset = false; + Value byteValue; + + // For memset to work correctly, we need to have a consistent byte pattern + if (auto constOp = value.getDefiningOp()) { + if (auto intAttr = mlir::dyn_cast(constOp.getValue())) { + // For integer constants + auto intVal = intAttr.getInt(); + // Check if all bytes in the pattern are the same + bool allBytesEqual = true; + uint8_t firstByte = intVal & 0xFF; + for (unsigned i = 1; i < elemSizeInBytes; i++) { + if (((intVal >> (i * 8)) & 0xFF) != firstByte) { + allBytesEqual = false; + break; + } + } + + if (allBytesEqual) { + useMemset = true; + byteValue = rewriter.create( + loc, rewriter.getIntegerType(8), + rewriter.getIntegerAttr(rewriter.getIntegerType(8), firstByte)); + } + } else if (auto floatAttr = + 
mlir::dyn_cast(constOp.getValue())) { + // For floating point constants + if (floatAttr.getValue().isZero()) { + // Zero float can use memset with zero byte value + useMemset = true; + byteValue = rewriter.create( + loc, rewriter.getIntegerType(8), rewriter.getI8IntegerAttr(0)); + } + } + } + + if (useMemset) { + // Use memset for filling + auto memsetFunc = + getOrInsertMemset(rewriter, op->getParentOfType()); + rewriter.create( + loc, llvmVoidPtr, memsetFunc, + ArrayRef{castPtr, byteValue, totalSizeVal}); + } else { + // Create a loop to manually fill the tensor with the value + // We'll use SCF dialect for structured loops + auto llvmElemType = typeConverter->convertType(elemType); + + // Create loop initialization + auto zero = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(0)); + auto upperBound = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(totalElements)); + auto one = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(1)); + + // Create the fill loop + auto loopOp = + rewriter.create(loc, zero, upperBound, one, ValueRange{}); + + // Set insertion point inside the loop + rewriter.setInsertionPointToStart(loopOp.getBody()); + + // Calculate pointer for the current element + auto currentPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(context), + LLVM::LLVMPointerType::get(context), castPtr, + ArrayRef({loopOp.getInductionVar()})); + + // Store the fill value to the current memory location + rewriter.create(loc, value, currentPtr); + + // Reset insertion point after the loop + rewriter.setInsertionPointAfter(loopOp); + } + + // Replace the original op with the casted pointer + rewriter.replaceOp(op, castPtr); + return success(); + } + +private: + // Helper to get or insert malloc function declaration + FlatSymbolRefAttr getOrInsertMalloc(PatternRewriter &rewriter, + ModuleOp module) const { + auto context = rewriter.getContext(); + auto mallocName = "malloc"; + if (module.lookupSymbol(mallocName)) { + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } + + // Create malloc function declaration + auto llvmVoidPtr = LLVM::LLVMPointerType::get(context); + auto mallocType = + LLVM::LLVMFunctionType::get(llvmVoidPtr, {rewriter.getI64Type()}, + /*isVarArg=*/false); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module->getLoc(), mallocName, mallocType); + + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } + + // Helper to get or insert memset function declaration + FlatSymbolRefAttr getOrInsertMemset(PatternRewriter &rewriter, + ModuleOp module) const { + auto context = rewriter.getContext(); + auto memsetName = "memset"; + if (module.lookupSymbol(memsetName)) { + return SymbolRefAttr::get(rewriter.getContext(), memsetName); + } + + // Create memset function declaration + auto voidPtrType = LLVM::LLVMPointerType::get(context); + auto memsetType = LLVM::LLVMFunctionType::get( + context, + voidPtrType, // memset returns the destination pointer + ArrayRef{voidPtrType, rewriter.getI8Type(), + rewriter.getI64Type()}, + /*isVarArg=*/false); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module->getLoc(), memsetName, memsetType); + + return SymbolRefAttr::get(rewriter.getContext(), memsetName); + } +}; + +// Conversion pattern for tensor.empty operation +class TensorEmptyOpConversion : public OpConversionPattern { +public: + using 
OpConversionPattern::OpConversionPattern; + + TensorEmptyOpConversion(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the result tensor type + TensorType resultType = op.getType(); + + // Verify we can handle this tensor type + if (!resultType.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, + "dynamic shapes not yet supported"); + } + + // Convert the tensor type to LLVM pointer type + auto llvmPtrType = mlir::dyn_cast( + getTypeConverter()->convertType(resultType)); + + if (!llvmPtrType) { + return rewriter.notifyMatchFailure( + op, "failed to convert tensor type to LLVM pointer type"); + } + + // Get element type + Type elementType = resultType.getElementType(); + + // Create LLVM operations to allocate memory + // 1. Calculate the total allocation size in bytes + auto loc = op.getLoc(); + int64_t totalElements = 1; + for (int64_t dim : resultType.getShape()) { + totalElements *= dim; + } + + auto elementSize = rewriter.create( + loc, rewriter.getI64Type(), + rewriter.getI64IntegerAttr(getElementTypeSize(elementType))); + + auto totalSize = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(totalElements)); + + auto allocSize = rewriter.create(loc, rewriter.getI64Type(), + totalSize, elementSize); + + // 2. Allocate memory using malloc + auto mallocFunc = + getOrInsertMalloc(rewriter, op->getParentOfType()); + auto allocated = rewriter.create(loc, llvmPtrType, mallocFunc, + ArrayRef{allocSize}); + + // Replace the tensor.empty operation with our allocation + rewriter.replaceOp(op, allocated.getResult()); + return success(); + } + +private: + // Helper to get element type size in bytes + int64_t getElementTypeSize(Type type) const { + if (auto floatType = mlir::dyn_cast(type)) { + return floatType.getWidth() / 8; + } else if (auto intType = mlir::dyn_cast(type)) { + return intType.getWidth() / 8; + } + // Default for other types + return 1; + } + + // Helper to get or insert malloc function declaration + FlatSymbolRefAttr getOrInsertMalloc(PatternRewriter &rewriter, + ModuleOp module) const { + auto mallocName = "malloc"; + if (module.lookupSymbol(mallocName)) { + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } + + // Create malloc function declaration + auto llvmVoidPtr = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto mallocType = + LLVM::LLVMFunctionType::get(llvmVoidPtr, {rewriter.getI64Type()}, + /*isVarArg=*/false); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module->getLoc(), mallocName, mallocType); + + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } +}; + +// Convert tt.get_program_id to LLVM call to __get_pid function +// Think this as Tx81 special action. 
It may be split into a separate pass or modeled with a
+// dedicated tx81.get_program_id op in the future.
+struct GetProgramIDConversion
+    : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+  static uint32_t constexpr LAUNCH_GRID_RANK =
+      mlir::triton::getMaxEnumValForProgramIDDim() + 1;
+
+  LogicalResult
+  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get the module for function declarations
+    auto module = op->getParentOfType();
+
+    // Declare the __get_pid runtime function if not already declared
+    // Signature: uint32_t __get_pid(uint32_t);
+    auto i32Ty = rewriter.getI32Type();
+
+    // Types for function declaration
+    SmallVector argTypes = {
+        i32Ty, // x: 0/y: 1/z: 2,
+    };
+
+    // Declare the function
+    Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(),
+                                        "__get_pid", i32Ty, argTypes);
+
+    // Get operands
+    auto axis = (uint32_t)op.getAxis();
+
+    assert(axis < LAUNCH_GRID_RANK && "program_id expects "
+                                      "axis to be either 0, "
+                                      "1, or 2");
+
+    // Convert the axis attribute to a Value
+    Value src = rewriter.create(
+        op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(axis));
+
+    // Create the call to __get_pid
+    auto call = rewriter.create(op.getLoc(), i32Ty,
+                                "__get_pid", // funcPtr,
+                                ArrayRef{src});
+
+    // Replace the op with the result of the call
+    rewriter.replaceOp(op, call.getResult());
+
+    return success();
+  }
+};
+
+// The conversion pass
+class Tx81ToLLVMPass : public Tx81ToLLVMBase {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert();
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+
+    // Setup LLVM lowering options object which should live across the call to
+    // applyFull/PartialConversion.
+    LowerToLLVMOptions options(context);
+    options.useBarePtrCallConv = false;
+
+    // Setup conversion target
+    target.addLegalDialect();
+    // Handle the tx81 op to llvm.call and support kcore load/store op's spm
+    // offset
+    target.addIllegalDialect();
+
+    // Setup rewrite patterns
+    RewritePatternSet patterns(context);
+
+    // NOTE: LLVMTypeConverter should be enough for MLIR core dialects.
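+    // The type converter is only consumed by the call/return conversion
+    // patterns populated below; the Tx81 patterns added to this set build the
+    // LLVM types they need directly from the rewriter.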
+ LLVMTypeConverter llvmTypeConverter(context, options); + + // Add the Tx81 to LLVM conversion patterns + // clang-format off + patterns.add, + NormalConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + ReduceOpConversion, + ReduceOpConversion, + ReduceOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + BinaryVSOpConversion, + BinaryVSOpConversion, + BinaryVSOpConversion, + BinaryVSOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BinaryLogicVVOpConversion, + BinaryLogicVVOpConversion, + BinaryLogicVVOpConversion, + RdmaOpConversion, + WdmaOpConversion, + MaskMoveOpConversion, + GemmOpConversion, + MemsetOpConversion, + GetProgramIDConversion>( + context); + // clang-format on + + // Add call op conversion + populateCallOpTypeConversionPattern(patterns, llvmTypeConverter); + + // Add return op conversion + populateReturnOpTypeConversionPattern(patterns, llvmTypeConverter); + + // Apply the conversion + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> triton::createTx81ToLLVMPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp new file mode 100644 index 000000000..e0efab8d0 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp @@ -0,0 +1,78 @@ +//===--------------------- Tx81ToLLVMPass.cpp -----------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "tx81-to-llvm" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc" + +namespace { + +class Tx81ToLLVMPass : public Tx81ToLLVMBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + + // Setup LLVM lowering options object which should live across the call to + // applyFull/PartialConversion. + LowerToLLVMOptions options(context); + options.useBarePtrCallConv = false; + + // Setup conversion target + target.addLegalDialect(); + target.addIllegalDialect(); + + // Setup rewrite patterns + RewritePatternSet patterns(context); + + // NOTE: LLVMTypeConverter should be enough for MLIR core dialects. + TensorToLLVMTypeConverter converter(context, options); + + triton::populateTx81ToLLVMConversionPatterns(patterns, target, converter); + + // Apply the conversion + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> triton::createTx81ToLLVMPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Dialect/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..4c1e72494 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(TritonTilingExt) +add_subdirectory(TritonStructured) +add_subdirectory(MagicKernel) +add_subdirectory(TsingMicroTx81) diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/MagicKernel/CMakeLists.txt new file mode 100644 index 000000000..41f904167 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(MagicKernelIR + IR/MagicKernelDialect.cpp + Transforms/BufferizableOpInterfaceImpl.cpp + + DEPENDS + MagicKernelTableGen + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp new file mode 100644 index 000000000..d71761179 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp @@ -0,0 +1,33 @@ +//===------------------- MagicKernelDialect.cpp ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" + +using namespace mlir; +using namespace mlir::mk; + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +void MagicKernelDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "magic-kernel/Dialect/IR/MagicKernelOps.cpp.inc" + >(); + // TODO: Add BufferizableOpInterface to all ops that can be bufferized + declarePromisedInterfaces(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "magic-kernel/Dialect/IR/MagicKernelOps.cpp.inc" + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.cpp.inc" diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 000000000..f0c256956 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,122 @@ +//===- BufferizableOpInterfaceImpl.cpp ----------------------------------- ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This file implements mk dialect DestinationStyleOp BufferizableOpInterface. +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +using namespace mlir; +using namespace mlir::bufferization; + +/// Generic conversion for any DestinationStyleOpInterface on tensors. +static LogicalResult +bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, + DestinationStyleOpInterface op, + const BufferizationOptions &options) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + + // Nothing to do. This op is already bufferized. + if (op.hasPureBufferSemantics()) + return success(); + + // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need + // basis. + if (!op.hasPureTensorSemantics()) + return op->emitError() << "op does not have pure tensor semantics"; + + // New input operands for the cloned op. 
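+  // Scalar operands are forwarded unchanged; only tensor operands are swapped
+  // for their bufferized counterparts in the loop below.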
+ SmallVector newInputBuffers; + newInputBuffers.reserve(op.getNumDpsInputs()); + for (OpOperand *opOperand : op.getDpsInputOperands()) { + if (op.isScalar(opOperand)) { + newInputBuffers.push_back(opOperand->get()); + continue; + } + FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); + if (failed(buffer)) + return failure(); + newInputBuffers.push_back(*buffer); + } + + // New output operands for the cloned op. + SmallVector newOutputBuffers; + for (OpResult opResult : op->getOpResults()) { + OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); + FailureOr resultBuffer = + getBuffer(rewriter, opOperand->get(), options); + if (failed(resultBuffer)) + return failure(); + newOutputBuffers.push_back(*resultBuffer); + } + + // Merge input/output operands. + SmallVector newOperands = newInputBuffers; + newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); + + // Set insertion point now that potential alloc/dealloc are introduced. + rewriter.setInsertionPoint(op); + // Clone the op, but use the new operands. Move the existing block into the + // new op. Since the new op does not have any tensor results, it does not + // return anything. + OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{}, + op->getAttrs()); + + Operation *newOp = Operation::create(state); + + // We don't want the rewriter tracks an incomplete operation, so insert new + // operation after op was fully constructed. + rewriter.insert(newOp); + + // Replace the results of the old op with the new output buffers. + replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); + + return success(); +} + +/// Bufferization of mk ops. Replace with a new mk op that operates entirely on +/// memrefs. +template +struct MKOpInterface + : public DstBufferizableOpInterfaceExternalModel, + OpTy> { + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + return bufferizeDestinationStyleOpInterface( + rewriter, cast(op), options); + } +}; + +/// Helper structure that iterates over all mkOps in `OpTys` and registers +/// the `BufferizableOpInterface` with each of them. +template struct MKOpInterfaceHelper { + static void registerOpInterface(MLIRContext *ctx) { + (Ops::template attachInterface>(*ctx), ...); + } +}; + +void mlir::mk::registerBufferizableOpInterfaceExternalModels( + mlir::DialectRegistry ®istry) { + registry.addExtension( + +[](MLIRContext *ctx, mlir::mk::MagicKernelDialect *dialect) { + // TODO: Register all mk ops. 
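+        // Further mk ops gain buffer semantics by being appended to the
+        // template argument list of MKOpInterfaceHelper registered here.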
+ MKOpInterfaceHelper::registerOpInterface(ctx); + }); +} diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonStructured/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/CMakeLists.txt new file mode 100644 index 000000000..27aac38fa --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +add_triton_library(TritonStructuredIR + TritonStructuredOps.cpp + TritonStructuredDialect.cpp + + DEPENDS + TritonStructuredTableGen + + LINK_LIBS PUBLIC + TritonIR + MLIRIR + ) diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp new file mode 100644 index 000000000..2af19b8a2 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp @@ -0,0 +1,22 @@ +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" + +using namespace mlir; +using namespace mlir::tts; + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +void TritonStructuredDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.cpp.inc" + +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp.inc" diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp new file mode 100644 index 000000000..cf55d834a --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -0,0 +1,179 @@ +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include + +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.h.inc" + +using namespace mlir; +using namespace mlir::tts; + +namespace mlir { +namespace tts { + +void MakeTensorPtrOp::build(OpBuilder &b, OperationState &state, Value base, + ArrayRef sizes, + ArrayRef strides, + ArrayRef offsets, + ArrayRef shape, + ArrayRef order) { + SmallVector staticStrides, staticOffsets, staticShape; + SmallVector dynamicStrides, dynamicOffsets, dynamicShape; + + dispatchIndexOpFoldResults(offsets, dynamicOffsets, 
staticOffsets); + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); + + Type resType; + auto basePtr = cast(base.getType()); + auto elemType = basePtr.getPointeeType(); + // non-block pointer + if (order.empty()) { + resType = RankedTensorType::get(sizes, basePtr); + } + // block pointer + else { + resType = triton::PointerType::get(RankedTensorType::get(sizes, elemType), + basePtr.getAddressSpace()); + } + + build(b, state, resType, base, sizes, dynamicStrides, dynamicOffsets, + dynamicShape, b.getDenseI64ArrayAttr(staticStrides), + b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticShape), order); +} + +void LoadOp::build(OpBuilder &b, OperationState &state, Value ptr, + ArrayRef dims, Value other) { + SmallVector staticDims; + SmallVector dynamicDims; + + dispatchIndexOpFoldResults(dims, dynamicDims, staticDims); + + // non-block pointer type + auto ptrTensorType = dyn_cast(ptr.getType()); + // block pointer type + auto tensorPtrType = dyn_cast(ptr.getType()); + + Type resType; + if (ptrTensorType) { + auto ptrType = cast(ptrTensorType.getElementType()); + auto elemType = ptrType.getPointeeType(); + resType = RankedTensorType::get(ptrTensorType.getShape(), elemType); + + } else if (tensorPtrType) { + auto tensorType = cast(tensorPtrType.getPointeeType()); + resType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType()); + } + build(b, state, resType, ptr, dynamicDims, b.getDenseI64ArrayAttr(staticDims), + other); +} + +void StoreOp::build(OpBuilder &b, OperationState &state, Value ptr, Value value, + ArrayRef dims) { + SmallVector staticDims; + SmallVector dynamicDims; + + dispatchIndexOpFoldResults(dims, dynamicDims, staticDims); + + build(b, state, ptr, value, dynamicDims, b.getDenseI64ArrayAttr(staticDims)); +} + +LogicalResult GetStructuredStateOp::verify() { + auto expectedOffsetAndStrideTypes = + getOffsetAndStrideTypes(getContext(), getInput().getType()); + + if (!expectedOffsetAndStrideTypes.has_value()) { + return failure(); + } + + auto [expectedOffsetTypes, expectedStrideTypes] = + *expectedOffsetAndStrideTypes; + + return success(expectedOffsetTypes.size() == getOffsets().size() && + llvm::equal(expectedOffsetTypes, getOffsets().getTypes()) && + expectedStrideTypes.size() == getStrides().size() && + llvm::equal(expectedStrideTypes, getStrides().getTypes())); +} + +void GetStructuredStateOp::build(OpBuilder &b, OperationState &state, + Value val) { + auto type = val.getType(); + + // Builder cannot fail, so we default to empty offset and stride types. + // The invalid op will be rejected by the verifier later. 
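+  // value_or() supplies the empty offset/stride type lists whenever
+  // getOffsetAndStrideTypes() returns std::nullopt for an unsupported type.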
+ auto [offsetTypes, strideTypes] = + getOffsetAndStrideTypes(b.getContext(), type) + .value_or(std::make_pair(SmallVector{}, SmallVector{})); + + build(b, state, val.getType(), offsetTypes, strideTypes, val); +} + +std::optional, SmallVector>> +GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, Type type) { + auto sizes = getOffsetAndStrideSegmentSizes(type); + if (!sizes.has_value()) { + return std::nullopt; + } + return std::make_pair( + SmallVector(sizes->first, IndexType::get(context)), + SmallVector(sizes->second, IndexType::get(context))); +} + +std::optional> +GetStructuredStateOp::getOffsetAndStrideSegmentSizes(Type type) { + int32_t offsetSegmentSize = 0; + int32_t strideSegmentSize = 0; + + if (auto tensorType = llvm::dyn_cast(type)) { + if (tensorType.getElementType().isIntOrIndex()) { + // Tensors of offsets + // Important note: + // We only care about tensor of index / int (in addition to pointer type) + // because only values of int and index type can potentially be part of a + // pointer arithmetic sequence. + offsetSegmentSize = strideSegmentSize = tensorType.getRank(); + } else if (auto ptrType = + dyn_cast(tensorType.getElementType())) { + // Unstructured pointers (tensor>) + // Each tensor of rank k gets k values for its offsets and k values for + // its strides, all of which has Index type. + offsetSegmentSize = strideSegmentSize = tensorType.getRank(); + } + } + // Block pointers (!tt.ptr> or !tt.ptr) + else if (auto ptrType = llvm::dyn_cast(type)) { + if (auto tensorType = + llvm::dyn_cast(ptrType.getPointeeType())) { + // Each tensor of rank k gets k values for its offsets and k values for + // its strides, all of which has Index type. + offsetSegmentSize = strideSegmentSize = tensorType.getRank(); + } else { + // The only relevant state that can be updated in loops for scalar + // pointers are offset. No need to include stride here. + offsetSegmentSize = 1; + } + } else { + return std::nullopt; + } + + return std::make_pair(offsetSegmentSize, strideSegmentSize); +} + +} // namespace tts +} // namespace mlir diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/BufferizableOpInterfaceImpl.cpp b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 000000000..c7f1c8174 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,134 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" + +using namespace mlir; +using namespace linalg; +using namespace mlir::bufferization; + +// +// This file implements the bufferizable interface for TritonTilingExtOps. +// The interface is required for bufferization (i.e: converting from tensors to +// memrefs). +// Since the bufferization semantics of TritonTilingExtOps are identical to +// linalg ops, the implementation was borrowed almost verbatim from +// mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +// with the exception that the code to handle linalg's region has been removed. +// (the original implementation is in an anonymous namespace, so we cannot +// reuse) +// +namespace { + +/// Generic conversion for any DestinationStyleOpInterface on tensors. +static LogicalResult bufferizeTritonTilingExtDestinationStyleOpInterface( + RewriterBase &rewriter, DestinationStyleOpInterface op, + const BufferizationOptions &options) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + + // Nothing to do. This op is already bufferized. + if (op.hasPureBufferSemantics()) + return success(); + + // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need + // basis. + if (!op.hasPureTensorSemantics()) + return op->emitError() << "op does not have tensor semantics"; + + // New input operands for the cloned op. + SmallVector newInputBuffers; + newInputBuffers.reserve(op.getNumDpsInputs()); + for (OpOperand *opOperand : op.getDpsInputOperands()) { + if (op.isScalar(opOperand)) { + newInputBuffers.push_back(opOperand->get()); + continue; + } + FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); + if (failed(buffer)) + return failure(); + newInputBuffers.push_back(*buffer); + } + + // New output operands for the cloned op. + SmallVector newOutputBuffers; + for (OpResult opResult : op->getOpResults()) { + OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); + FailureOr resultBuffer = + getBuffer(rewriter, opOperand->get(), options); + if (failed(resultBuffer)) + return failure(); + newOutputBuffers.push_back(*resultBuffer); + } + + // Merge input/output operands. + SmallVector newOperands = newInputBuffers; + newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); + + // Set insertion point now that potential alloc/dealloc are introduced. + rewriter.setInsertionPoint(op); + // Clone the op, but use the new operands. Move the existing block into the + // new op. Since the new op does not have any tensor results, it does not + // return anything. + clone(rewriter, op, /*resultTypes=*/TypeRange{}, newOperands); + + // Replace the results of the old op with the new output buffers. 
+ replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); + + return success(); +} + +template +struct TritonTilingExtOpInterface + : public DstBufferizableOpInterfaceExternalModel< + TritonTilingExtOpInterface, OpTy> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // Operand is read if it is used in the computation. + return cast(op).isDpsInput(&opOperand); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // Operand is written to if it is not an input/init. + return cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + return bufferizeTritonTilingExtDestinationStyleOpInterface( + rewriter, cast(op), options); + } +}; + +template struct TritonTilingExtOpInterfaceHelper { + static void registerOpInterface(MLIRContext *ctx) { + (Ops::template attachInterface>(*ctx), ...); + } +}; +} // namespace + +void mlir::ttx::registerBufferizableOpInterfaceExternalModels( + DialectRegistry ®istry) { + // clang-format off + registry.addExtension(+[](MLIRContext *ctx, ttx::TritonTilingExtDialect *dialect) { + TritonTilingExtOpInterfaceHelper< + ttx::CumSumOp + >::registerOpInterface(ctx); + }); + // clang-format on +} diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CMakeLists.txt new file mode 100644 index 000000000..b6b07162c --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +add_triton_library(TritonTilingExtIR + BufferizableOpInterfaceImpl.cpp + CumSum.cpp + TritonTilingExtDialect.cpp + + DEPENDS + TritonTilingExtInterfacesIncGen + TritonTilingExtOpsIncGen + + LINK_LIBS PUBLIC + TritonIR + MLIRAffineAnalysis + MLIRFuncDialect + MLIRIR + MLIRLinalgDialect + MLIRLinalgUtils + ) diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CumSum.cpp b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CumSum.cpp new file mode 100644 index 000000000..619b653ea --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CumSum.cpp @@ -0,0 +1,112 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// This file implements cumulative sum (CumSum) using the TilingInterface. Only +// supports tensors of rank 1 & 2 and axis == rank - 1 (i.e: we can split the +// computation of each row and compute them independently). The semantics of +// tiling for other axes are more complex and require working with +// non-contiguous memory. 
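+//
+// Illustrative example (added for clarity, not part of the original header
+// comment): for a rank-2 input with axis == 1, each row is an independent
+// prefix sum, e.g.
+//   input  = [[1, 2, 3],
+//             [4, 5, 6]]
+//   cumsum = [[1, 3,  6],
+//             [4, 9, 15]]
+// so the rows can be tiled and computed independently of one another.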
+//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" + +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "ttx-cumsum" + +using namespace mlir; +using namespace mlir::ttx; + +void ttx::CumSumOp::build(OpBuilder &odsBuilder, OperationState &odsState, + Value input, IntegerAttr axis, Value output, + ArrayRef attributes) { + SmallVector inputs{input}; + SmallVector outputs{output}; + odsState.addOperands(inputs); + odsState.addOperands(outputs); + odsState.addAttribute( + "operand_segment_sizes", + odsBuilder.getDenseI32ArrayAttr({static_cast(inputs.size()), + static_cast(outputs.size())})); + + odsState.addAttribute(getAxisAttrStrName(), axis); + odsState.addAttributes(attributes); + odsState.addTypes(SmallVector{output.getType()}); +} + +mlir::LogicalResult ttx::CumSumOp::verify() { + auto inputType = getInput().getType(); + if (!isa(inputType) && !isa(inputType)) { + return emitOpError( + "CumSum op expects input to be either tensor or memref."); + } + + auto outputType = getOutput().getType(); + if (!isa(outputType) && !isa(outputType)) { + return emitOpError( + "CumSum op expects output to be either tensor or memref."); + } + + if (dyn_cast(inputType).getShape() != + dyn_cast(outputType).getShape()) { + return emitOpError("Input and output types must be the same."); + } + + int64_t rank = getRank(); + if (rank != 1 && rank != 2) { + return emitOpError("CumSum op only takes tensors of rank 1 & 2."); + } + + int64_t axis = getAxis(); + if (axis != rank - 1) { + return emitOpError("CumSum computation only supports axis == rank - 1"); + } + + return success(); +} + +AffineMap ttx::CumSumOp::getInputIndexingMap(MLIRContext *context, + unsigned int index, + ArrayRef sizes) { + assert(index == 0); + return AffineMap::getMultiDimIdentityMap(getRank(), context); +} + +AffineMap ttx::CumSumOp::getOutputIndexingMap(MLIRContext *context, + unsigned int index, + ArrayRef sizes) { + assert(index == 0); + return AffineMap::getMultiDimIdentityMap(getRank(), context); +} + +SmallVector ttx::CumSumOp::getLoopIteratorTypes() { + SmallVector iterators; + iterators.append(getRank() - 1, utils::IteratorType::parallel); + iterators.push_back(utils::IteratorType::reduction); + return iterators; +} + +SmallVector ttx::CumSumOp::getIterationDomain(OpBuilder &b) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(*this); + auto loc = getLoc(); + auto zero = b.getIndexAttr(0); + auto one = b.getIndexAttr(1); + SmallVector iterationDomain; + + // Return the bounds for all dimensions. The caller is responsible for not + // tiling the inner most dimension, otherwise the semantic of the resulting op + // is incorrect. + for (auto i = 0; i < getRank(); i++) { + OpFoldResult upperbound = linalg::createFoldedDimOp(b, loc, getInput(), i); + iterationDomain.push_back(Range{zero, upperbound, one}); + } + return iterationDomain; +} diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp new file mode 100644 index 000000000..da4d97674 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp @@ -0,0 +1,404 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Value.h"
+
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ttx;
+using namespace mlir::linalg;
+
+namespace mlir {
+namespace ttx {
+
+Value getSlice(OpBuilder &b, Location loc, Value source,
+               ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+               ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Value>(source.getType())
+      .Case([&](RankedTensorType t) -> Value {
+        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                                strides);
+      })
+      .Case([&](MemRefType type) -> Value {
+        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
+                                           strides);
+      })
+      .Default([&](Type t) { return nullptr; });
+}
+
+//
+// getTiledImplementation
+//
+// Given an array of offsets and sizes, return the corresponding tiled version
+// of the current op.
+//
+// This method is responsible for creating the extract slice ops for each
+// operand of the op (including input and output operands).
+//
+// As an example, assume we tile a linalg.matmul ins(%0, %1) out(%out)
+//
+// This method then generates:
+//
+// %in_slice_0 = extract_slice from %0
+// %in_slice_1 = extract_slice from %1
+// %out_slice = extract_slice from %out
+// %tile = linalg.matmul ins(%in_slice_0, %in_slice_1) out(%out_slice)
+//
+// To generate these extract slices, we go over each operand and get the
+// corresponding affine map to compute the correct offsets and sizes.
+//
+// Now let's describe how we compute the correct offsets and sizes from
+// an affine map.
+//
+// - Offsets:
+// An affine map describes how to access a tensor (i.e., the indices into a
+// tensor), so getting the offsets (also indices) from an affine map is
+// simply "applying" the sub-map on the offset (calling
+// makeComposedFoldedAffineApply, which also does constant folding
+// automatically).
+//
+// For example:
+// Let's assume we have the following nested loops:
+// for i in range(0, 10):
+//   for j in range(0, 20):
+//     for k in range(0, 30):
+//       dst[i][j][k] = src[i * 2][j + k]
+//
+// Assume that we describe the iteration space based on dst. So:
+// - dst's affine map is (d0, d1, d2) -> (d0, d1, d2)
+// - src's affine map is (d0, d1, d2) -> (d0 * 2, d1 + d2)
+//
+// Now let's say we want to tile the operator with offset (0, 1, 2).
+//
+// For dst, we apply this (0, 1, 2) to its affine map and get (0, 1, 2).
+//
+// For src, we have to plug the offsets into the affine map to get:
+//
+// (0 * 2, 1 + 2) = (0, 3)
+//
+// This is exactly what the implementation does as well.
+// The call to getSubMap gets the i'th result expression, then the call to
+// makeComposedFoldedAffineApply applies the `offsets` array to the i'th result
+// expression in the affine map.
+//
+//
+// - Sizes:
+// Sizes are slightly more complex; notice that there are 3 steps to compute
+// sizes:
+//
+// 1) call linalg::computeTileSizes on the provided `sizes`
+// 2) apply the affine map
+// 3) add 1 to the result
+//
+// The reason for this complexity is that the affine maps describe indices
+// into the iteration space with a half-open interval (i.e., we always go from
+// 0 until right before the upper bound).
So if we simply apply the affine map on the sizes, +// we will have incorrect results. +// +// Consider this snippet again: +// for i in range(0, 16): +// for j in range(0, 32): +// for k in range(0, 64): +// dst[i][j][k] = src[i * 2][j + k] +// +// Assume we want the operator to have a tile size of (16, 32, 64) -- so no +// tiling at all. If we apply the affine map of src (d0, d1, d2) -> (d0 * 2, d1 +// + d2), we have +// +// (16 * 2, 32 + 64) -> (32, 96) +// +// However, consider the second dimension of source: +// - j goes from 0 till 31 inclusive +// - k goes from 0 till 63 inclusive +// +// So the max index of src's second dimension is 31 + 63 = 94. Since index +// starts from 0, this means the second dimension has 95 elements. But the +// formula gives us a tile size of 96!!! The same argument can be applied for +// the first dimension as well, the number of elements is 15 * 2 + 1 = 31, but +// computed tile size is 32. +// +// So simply applying the indexing map to compute tile size is INCORRECT!! +// This happens because the indexing map operates on [0, size), while tile sizes +// are inclusive. +// +// The correct formula is: +// ((d0 - 1) * 2 + 1), (d1 - 1) + (d2 - 1) + 1 which gives +// (15 * 2 + 1, 32 - 1 + 64 - 1 + 1) -> (31, 95) +// +// So again, the steps are: +// - Subtract 1 from the sizes (what linalg::computeTileSizes does) +// - Apply the affine map +// - Add 1 to the result +// +template +FailureOr getTiledImplementation(TritonTilingExtOpTy op, + OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes) { + Location loc = op->getLoc(); + SmallVector valuesToTile = op->getOperands(); + SmallVector tiledValues; + auto oneAttr = b.getI64IntegerAttr(1); + + for (OpOperand &opOperand : op->getOpOperands()) { + unsigned int index = opOperand.getOperandNumber(); + auto val = valuesToTile[index]; + auto type = dyn_cast(val.getType()); + + if (!type) { + tiledValues.push_back(val); + continue; + } + + auto rank = type.getRank(); + SmallVector newOffsets; + SmallVector newSizes; + SmallVector newStrides(rank, oneAttr); + + llvm::SmallVector composedTileSizes = + linalg::computeTileSizes(b, loc, sizes, {}); + + AffineMap map = op.getIndexingMap(b.getContext(), index, sizes); + for (int64_t i = 0; i < rank; i++) { + AffineMap m = map.getSubMap(i); + { + OpFoldResult upperboundClosed = + affine::makeComposedFoldedAffineApply(b, loc, m, composedTileSizes); + AffineExpr s0 = getAffineSymbolExpr(0, b.getContext()); + OpFoldResult size = affine::makeComposedFoldedAffineApply( + b, loc, s0 + 1, upperboundClosed); + newSizes.push_back(size); + } + { + OpFoldResult offset = + affine::makeComposedFoldedAffineApply(b, loc, m, offsets); + newOffsets.push_back(offset); + } + } + + tiledValues.push_back( + getSlice(b, loc, val, newOffsets, newSizes, newStrides)); + } + + SmallVector resultTensorTypes = llvm::to_vector( + llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) { + return tiledValues[opOperand.getOperandNumber()].getType(); + })); + + Operation *tiledOp = clone(b, op, resultTensorTypes, tiledValues); + + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + +// +// getResultTilePosition +// This method returns the resultOffsets and resultSizes through references +// for the tiled operator. While `getTiledImplementation` is responsible for +// generating the extract slice for all operands, `getResultTilePosition` is +// responsible for returning the offsets and sizes which the tiling engine will +// then use to generate the corresponding insert_slice ops. 
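+// For CumSumOp, for instance, the output indexing map is the identity, so the
+// returned offsets and sizes are simply the tile offsets and tile sizes of the
+// iteration space (illustrative note added for clarity, not part of the
+// original comment).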
+//
+// Because the slice we insert back to the output tensor is the same as the
+// slice that we extracted from the output tensor, this method just repeats the
+// offset and size computation in `getTiledImplementation`.
+//
+template <typename TritonTilingExtOpTy>
+LogicalResult getResultTilePosition(TritonTilingExtOpTy op, OpBuilder &b,
+                                    unsigned resultNumber,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVector<OpFoldResult> &resultOffsets,
+                                    SmallVector<OpFoldResult> &resultSizes) {
+  Location loc = op.getLoc();
+
+  AffineMap outputMap =
+      op.getOutputIndexingMap(b.getContext(), resultNumber, sizes);
+
+  Value result = op.getDpsInitOperand(resultNumber)->get();
+  auto rank = dyn_cast<ShapedType>(result.getType()).getRank();
+
+  llvm::SmallVector<OpFoldResult> composedTileSizes =
+      linalg::computeTileSizes(b, loc, sizes, {});
+  for (int64_t i = 0; i < rank; i++) {
+    AffineMap m = outputMap.getSubMap(i);
+    {
+      OpFoldResult upperboundClosed =
+          affine::makeComposedFoldedAffineApply(b, loc, m, composedTileSizes);
+      AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
+      OpFoldResult size = affine::makeComposedFoldedAffineApply(
+          b, loc, s0 + 1, upperboundClosed);
+      resultSizes.push_back(size);
+    }
+    {
+      OpFoldResult offset =
+          affine::makeComposedFoldedAffineApply(b, loc, m, offsets);
+      resultOffsets.push_back(offset);
+    }
+  }
+
+  return success();
+}
+
+// This method is borrowed verbatim from
+// mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+//
+// This is invoked when the current op produces a result that is used
+// as an input to another op that is being tiled. The method essentially handles
+// producing a new op where the result matches the given offsets and sizes.
+// If the method succeeds, the two new operators will be fused in the same loop.
+//
+// As an example, consider the following IR where the linalg.generic is being
+// tiled (unnecessary details omitted for brevity):
+//
+// clang-format: off
+//
+// func.func @some_op_1(
+//   %arg0: tensor<8x2x256x512xbf16>,
+//   %arg1: tensor<8x256x1024xbf16>
+// ) -> tensor<8x256x1024xbf16>
+//   %1 = linalg.init_tensor [8, 256, 1024] : tensor<8x256x1024xbf16>
+//   %2 = linalg.init_tensor [8, 256, 1024] : tensor<8x256x1024xbf16>
+//   %3 = ttx.some_op
+//          ins(%arg0 : tensor<8x2x256x512xbf16>)
+//          outs(%1 : tensor<8x256x1024xbf16>) -> tensor<8x256x1024xbf16>
+//   %4 = linalg.generic
+//          ins(%3, %arg1 : tensor<8x256x1024xbf16>, tensor<8x256x1024xbf16>)
+//          outs(%2 : tensor<8x256x1024xbf16>) {
+//   ^bb0(%arg2: bf16, %arg3: bf16, %arg4: bf16):
+//     %add = arith.addf %arg2, %arg3 : bf16
+//     linalg.yield %add : bf16
+//   } -> tensor<8x256x1024xbf16>
+//   return %4 : tensor<8x256x1024xbf16>
+// }
+//
+// clang-format: on
+//
+// We tile linalg.generic, but one of its inputs is %3 which is the result of
+// ttx.some_op. So the tiling engine will invoke
+// generateResultTileValue of ttx.some_op to determine if it's
+// possible to create a tiled version of it, thereby making it possible to fuse
+// both operators together in a loop.
+template <typename TritonTilingExtOpTy>
+FailureOr<TilingResult>
+generateResultTileValue(TritonTilingExtOpTy op, OpBuilder &b,
+                        unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes) {
+
+  // Check that the indexing map used for the output is a projected
+  // permutation. This could be relaxed with a more general approach that can
+  // map the offsets and sizes from the result to iteration space tiles
+  // (filling in full extent for dimensions not used to access the result).
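+  // Illustrative note (not in the original source): an indexing map is a
+  // projected permutation when every result is a distinct loop dimension with
+  // no arithmetic, e.g. (d0, d1, d2) -> (d2, d0) qualifies, while
+  // (d0, d1) -> (d0 + d1) does not.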
+ AffineMap indexingMap = op.getOutputIndexingMap(b.getContext(), 0, sizes); + if (!indexingMap.isProjectedPermutation()) { + return op.emitOpError( + "unhandled tiled implementation generation when result is not " + "accessed using a permuted projection"); + } + + auto numLoops = op.getLoopIteratorTypes().size(); + SmallVector iterationTileOffsets(numLoops), + iterationTileSizes(numLoops); + if (!indexingMap.isPermutation()) { + SmallVector iterationDomain = op.getIterationDomain(b); + for (auto range : llvm::enumerate(iterationDomain)) { + iterationTileOffsets[range.index()] = range.value().offset; + iterationTileSizes[range.index()] = range.value().size; + } + } + for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) { + assert(resultExpr.value().getKind() == AffineExprKind::DimId); + // HACK: LLVM casting utilities do not work here for out-of-tree builds, + // as there is no template specialization for this cast in the base + // build. + AffineDimExpr affineDimExpr(static_cast( + const_cast(resultExpr.value().getAsOpaquePointer()))); + unsigned dimPosition = affineDimExpr.getPosition(); + iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; + iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; + } + + FailureOr tilingResult = + op.getTiledImplementation(b, iterationTileOffsets, iterationTileSizes); + if (tilingResult->tiledOps.size() != 1) + return op.emitOpError("failed to generate tiled implementation"); + + return TilingResult{ + tilingResult->tiledOps, + SmallVector{tilingResult->tiledValues[resultNumber]}}; +} + +// This method is borrowed directly from linalg.generic's implementation +// in mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +// This marks all operands that are part of the input group to have read +// effect, while all other operands that are part of the output group +// to have both read and write effects. +static void getTritonTilingExtEffectsImpl( + SmallVectorImpl> + &effects, + ValueRange results, ArrayRef inputOperands, + const MutableOperandRange &outputOperands) { + for (auto operand : inputOperands) { + if (!llvm::isa(operand->get().getType())) + continue; + effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } + for (auto &operand : outputOperands) { + if (!llvm::isa(operand.get().getType())) + continue; + + effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } +} + +template +void getEffects( + TritonTilingExtOpTy op, + SmallVectorImpl> + &effects) { + getTritonTilingExtEffectsImpl(effects, op.getOperation()->getResults(), + op.getDpsInputOperands(), + op.getDpsInitsMutable()); +} + +} // namespace ttx +} // namespace mlir + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. 
+void TritonTilingExtDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.cpp.inc" + +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.cpp.inc" + +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOpsDialect.cpp.inc" diff --git a/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/CMakeLists.txt new file mode 100644 index 000000000..58039db5f --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(Tx81IR + IR/Tx81Dialect.cpp + IR/Tx81Ops.cpp + + DEPENDS + Tx81TableGen + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp new file mode 100644 index 000000000..d819a4e17 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp @@ -0,0 +1,30 @@ +//===-------------------------- Tx81Dialect.cpp ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" + +using namespace mlir; +using namespace mlir::tx; + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +void Tx81Dialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tsingmicro-tx81/Dialect/IR/Tx81Enums.cpp.inc" +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.cpp.inc" + +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.cpp.inc" diff --git a/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp new file mode 100644 index 000000000..9db877dce --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp @@ -0,0 +1,10 @@ +//===-------------------------- Tx81Ops.cpp -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h" +using namespace mlir; +using namespace mlir::tx; diff --git a/third_party/tsingmicro/name.conf b/third_party/tsingmicro/name.conf new file mode 100644 index 000000000..1340763be --- /dev/null +++ b/third_party/tsingmicro/name.conf @@ -0,0 +1 @@ +tsingmicro diff --git a/third_party/tsingmicro/python/triton_tsingmicro.cc b/third_party/tsingmicro/python/triton_tsingmicro.cc new file mode 100644 index 000000000..608918898 --- /dev/null +++ b/third_party/tsingmicro/python/triton_tsingmicro.cc @@ -0,0 +1,45 @@ +#include + +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" +#include "triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h" +#include "triton-shared/Conversion/TritonToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToStructured/Passes.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" + +#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h" +#include "magic-kernel/Conversion/LinalgToMK/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h" + +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" + +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/InitAllPasses.h" + +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" + +namespace py = pybind11; +using namespace mlir; + +void init_triton_tsingmicro(py::module &&m) {} diff --git a/third_party/tsingmicro/scripts/build_llvm.sh b/third_party/tsingmicro/scripts/build_llvm.sh new file mode 100755 index 000000000..d76a4ef1c --- /dev/null +++ b/third_party/tsingmicro/scripts/build_llvm.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +if [ -z "${LLVM_PROJECT+x}" ]; then + echo "Please set the environment variable “LLVM_PROJECT”." 1>&2 + exit 1 +fi + +if [ ! -d $LLVM_PROJECT ]; then + echo "Error: $LLVM_PROJECT not exist!" 
1>&2
+  exit 1
+fi
+
+BUILD_TYPE=Release
+
+build_llvm() {
+  mkdir -p $LLVM_PROJECT/build
+  cd $LLVM_PROJECT/build
+  cmake -G Ninja \
+    -DCMAKE_BUILD_TYPE=$BUILD_TYPE \
+    -DLLVM_ENABLE_ASSERTIONS=ON \
+    -DLLVM_ENABLE_PROJECTS="clang;mlir;llvm;lld" \
+    -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU;RISCV" \
+    -DLLVM_USE_LINKER=lld \
+    -DMLIR_ENABLE_BINDINGS_PYTHON=1 \
+    ../llvm
+  ninja
+}
+
+build_llvm
diff --git a/third_party/tsingmicro/scripts/build_tsingmicro.sh b/third_party/tsingmicro/scripts/build_tsingmicro.sh
new file mode 100755
index 000000000..1e093cf75
--- /dev/null
+++ b/third_party/tsingmicro/scripts/build_tsingmicro.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+
+script_path=$(realpath "$0")
+script_dir=$(dirname "$script_path")
+project_dir=$(realpath "$script_dir/../../..")
+
+if [ -z "${WORKSPACE+x}" ]; then
+  WORKSPACE=$(realpath "$project_dir/..")
+fi
+
+TX8_HOME=$WORKSPACE/tx8_deps
+LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64
+
+if [ ! -d $TX8_HOME ] || [ ! -d $LLVM ]; then
+  WORKSPACE="${HOME}/.flagtree/tsingmicro/"
+  TX8_HOME=$WORKSPACE/tx8_deps
+  LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64
+fi
+
+if [ ! -d $TX8_HOME ]; then
+  echo "Error: $TX8_HOME does not exist!" 1>&2
+  exit 1
+fi
+
+if [ ! -d $LLVM ]; then
+  echo "Error: $LLVM does not exist!" 1>&2
+  exit 1
+fi
+
+BUILD_TYPE=Release
+
+export TX8_HOME=$TX8_HOME
+export LLVM_SYSPATH=$LLVM
+export FLAGTREE_BACKEND=tsingmicro
+
+export TRITON_OFFLINE_BUILD=ON
+export TRITON_BUILD_WITH_CLANG_LLD=true
+export TRITON_BUILD_WITH_CCACHE=true
+export TRITON_BUILD_PROTON=OFF
+
+echo "export TX8_HOME=$TX8_HOME"
+echo "export LLVM_SYSPATH=$LLVM_SYSPATH"
+echo "export FLAGTREE_BACKEND=$FLAGTREE_BACKEND"
+
+echo "export TRITON_OFFLINE_BUILD=$TRITON_OFFLINE_BUILD"
+echo "export TRITON_BUILD_WITH_CLANG_LLD=$TRITON_BUILD_WITH_CLANG_LLD"
+echo "export TRITON_BUILD_WITH_CCACHE=$TRITON_BUILD_WITH_CCACHE"
+echo "export TRITON_BUILD_PROTON=$TRITON_BUILD_PROTON"
+
+cd python
+python3 -m pip install . --no-build-isolation -v --verbose
diff --git a/third_party/tsingmicro/scripts/install.sh b/third_party/tsingmicro/scripts/install.sh
new file mode 100755
index 000000000..b0d3346b4
--- /dev/null
+++ b/third_party/tsingmicro/scripts/install.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+apt install -y git
+apt install -y lld
+
+pip uninstall -y triton
+
+pip install gitpython
+pip install torch==2.7.0 torchvision
diff --git a/third_party/tsingmicro/scripts/run_tsingmicro.sh b/third_party/tsingmicro/scripts/run_tsingmicro.sh
new file mode 100755
index 000000000..13e3ed38c
--- /dev/null
+++ b/third_party/tsingmicro/scripts/run_tsingmicro.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+script_path=$(realpath "$0")
+script_dir=$(dirname "$script_path")
+project_dir=$(realpath "$script_dir/../../..")
+
+if [ -z "${WORKSPACE+x}" ]; then
+  WORKSPACE=$(realpath "$project_dir/..")
+fi
+
+TX8_HOME=$WORKSPACE/tx8_deps
+LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64
+
+if [ ! -d $TX8_HOME ] || [ ! -d $LLVM ]; then
+  WORKSPACE="${HOME}/.flagtree/tsingmicro/"
+  TX8_HOME=$WORKSPACE/tx8_deps
+  LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64
+fi
+
+if [ ! -d $TX8_HOME ]; then
+  echo "Error: $TX8_HOME does not exist!" 1>&2
+  exit 1
+fi
+
+if [ ! -d $LLVM ]; then
+  echo "Error: $LLVM does not exist!" 1>&2
1>&2 + exit 1 +fi + +export TX8_HOME=$TX8_HOME +export LLVM_SYSPATH=$LLVM +export LD_LIBRARY_PATH=$TX8_HOME/lib:$LD_LIBRARY_PATH +export TRITON_ALWAYS_COMPILE=1 + +# export TRITON_DUMP_PATH=$project_dir/dump + +echo "export TX8_HOME=$TX8_HOME" +echo "export LLVM_SYSPATH=$LLVM_SYSPATH" +echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH" +echo "export TRITON_ALWAYS_COMPILE=$TRITON_ALWAYS_COMPILE" + +python3 $@