diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml
index 2e504494b..95965d25b 100644
--- a/.github/workflows/ascend-build-and-test.yml
+++ b/.github/workflows/ascend-build-and-test.yml
@@ -29,4 +29,4 @@ jobs:
         shell: bash
         run: |
           source /usr/local/Ascend/ascend-toolkit/set_env.sh
-          python3.9 third_party/ascend/python/tutorials/01-vector-add.py
+          python3.9 third_party/tests/ascend/vector-add.py
diff --git a/.gitignore b/.gitignore
index 8fa02d10d..ff36e6a28 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@ third_party/cambricon/
 third_party/iluvatar/iluvatarTritonPlugin.so
 third_party/triton_shared/
 third_party/xpu/backend/xpu3
+third_party/ascend
 
 # Proton
 python/triton/profiler
diff --git a/python/setup.py b/python/setup.py
index 7c5da7029..c962212eb 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -28,7 +28,7 @@
 
 import pybind11
 
-import setup_helper as helper
+from setup_tools import setup_helper as helper
 
 
 @dataclass
@@ -423,6 +423,7 @@ def build_extension(self, ext):
             "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
             "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
         ]
+        cmake_args += helper.get_backend_cmake_args(build_ext=self)
         if lit_dir is not None:
             cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
         cmake_args.extend(thirdparty_cmake_args)
@@ -487,6 +488,7 @@ def build_extension(self, ext):
         subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
         subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
         subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
+        helper.install_extension(build_ext=self)
 
 
 nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json")
@@ -611,7 +613,7 @@ class plugin_install(install):
     def run(self):
         add_links()
         install.run(self)
-        helper.post_install(self)
+        helper.post_install()
 
 
 class plugin_develop(develop):
@@ -619,7 +621,7 @@ class plugin_develop(develop):
     def run(self):
         add_links()
         develop.run(self)
-        helper.post_install(self)
+        helper.post_install()
 
 
 class plugin_bdist_wheel(bdist_wheel):
@@ -627,7 +629,7 @@ class plugin_bdist_wheel(bdist_wheel):
     def run(self):
         add_links()
         bdist_wheel.run(self)
-        helper.post_install(self)
+        helper.post_install()
 
 
 class plugin_egginfo(egg_info):
@@ -635,7 +637,7 @@ class plugin_egginfo(egg_info):
     def run(self):
         add_links()
         egg_info.run(self)
-        helper.post_install(self)
+        helper.post_install()
 
 
 package_data_tools = helper.get_package_data_tools()
diff --git a/python/setup_tools/__init__.py b/python/setup_tools/__init__.py
new file mode 100644
index 000000000..c3411313f
--- /dev/null
+++ b/python/setup_tools/__init__.py
@@ -0,0 +1,4 @@
+from . import setup_helper
+from . import utils
+
+__all__ = ["setup_helper", "utils"]
diff --git a/python/setup_helper.py b/python/setup_tools/setup_helper.py
similarity index 73%
rename from python/setup_helper.py
rename to python/setup_tools/setup_helper.py
index 6c0959031..35e8581ee 100644
--- a/python/setup_helper.py
+++ b/python/setup_tools/setup_helper.py
@@ -8,35 +8,19 @@
 import urllib.request
 from pathlib import Path
 import hashlib
-from dataclasses import dataclass
 from distutils.sysconfig import get_python_lib
+from . import utils
 
-use_triton_shared = False
-necessary_third_party = ["triton_shared"]
-default_backends = ["nvidia", "amd"]
 extend_backends = []
+default_backends = ["nvidia", "amd"]
 plugin_backends = ["cambricon", "ascend"]
 ext_sourcedir = "triton/_C/"
 flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
 flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
+offline_build = os.getenv("FLAGTREE_OFFLINE_BUILD", "OFF")
 device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}
-
-
-@dataclass
-class FlagTreeBackend:
-    name: str
-    url: str
-    tag: str
-
-
-flagtree_backend_info = {
-    "triton_shared":
-    FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
-                    tag="380b87122c88af131530903a702d5318ec59bb33"),
-    "cambricon":
-    FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
-                    tag="00f51c2e48a943922f86f03d58e29f514def646d"),
-}
+flagtree_backends = utils.flagtree_backends
+backend_utils = utils.activate(flagtree_backend)
 
 set_llvm_env = lambda path: set_env({
     'LLVM_INCLUDE_DIRS': Path(path) / "include",
@@ -45,48 +29,98 @@ class FlagTreeBackend:
 })
 
 
+def install_extension(*args, **kwargs):
+    try:
+        backend_utils.install_extension(*args, **kwargs)
+    except Exception:
+        pass
+
+
+def get_backend_cmake_args(*args, **kwargs):
+    try:
+        return backend_utils.get_backend_cmake_args(*args, **kwargs)
+    except Exception:
+        return []
+
+
 def get_device_name():
     return device_mapping[flagtree_backend]
 
 
 def get_extra_packages():
     packages = []
-    if flagtree_backend == 'ascend':
-        packages = [
-            "triton/triton_patch",
-            "triton/triton_patch/language",
-            "triton/triton_patch/compiler",
-            "triton/triton_patch/runtime",
-        ]
+    try:
+        packages = backend_utils.get_extra_install_packages()
+    except Exception:
+        packages = []
     return packages
 
 
 def get_package_data_tools():
     package_data = ["compile.h", "compile.c"]
-    if flagtree_backend == 'xpu':
-        package_data += ["compile_xpu.h", "compile_xpu.c"]
+    try:
+        package_data += backend_utils.get_package_data_tools()
+    except Exception:
+        pass
     return package_data
 
 
-def post_install(self):
-
-    def get_module(module_path):
-        import importlib.util
-        import os
-        module_path = os.path.abspath(module_path)
-        spec = importlib.util.spec_from_file_location("module", module_path)
-        module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(module)
-        return module
-
-    def ascend():
-        utils = get_module("../third_party/ascend/utils.py")
-        utils.post_install()
-
-    code = f"{flagtree_backend}()"
+def git_clone(lib, lib_path):
+    import git
+    MAX_RETRY = 4
+    print(f"Cloning {lib.name} into {lib_path} ...")
+    retry_count = MAX_RETRY
+    while retry_count:
+        try:
+            repo = git.Repo.clone_from(lib.url, lib_path)
+            if lib.tag is not None:
+                repo.git.checkout(lib.tag)
+            sub_triton_path = Path(lib_path) / "triton"
+            if os.path.exists(sub_triton_path):
+                shutil.rmtree(sub_triton_path)
+            print(f"Successfully cloned {lib.name} into {lib_path}")
+            return True
+        except Exception:
+            retry_count -= 1
+            print(f"\n[{MAX_RETRY - retry_count}] retrying clone of {lib.name} to {lib_path}")
+    return False
+
+
+def dir_rollback(depth, base_path):
+    while depth:
+        base_path = os.path.dirname(base_path)
+        depth -= 1
+    return Path(base_path)
+
+
+def download_flagtree_third_party(name, condition, required=False, hook=None):
+    if not condition:
+        return
+    backend = None
+    for _backend in flagtree_backends:
+        if _backend.name in name:
+            backend = _backend
+            break
+    if backend is None:
+        return
+    base_dir = dir_rollback(3, __file__) / "third_party"
+    prelib_path = Path(base_dir) / name
+    lib_path = Path(base_dir) / backend.name
+
+    if not os.path.exists(prelib_path) and not os.path.exists(lib_path):
+        succ = git_clone(lib=backend, lib_path=prelib_path)
+        if not succ and required:
+            raise RuntimeError(f"Unable to clone {backend.name}; check your network connection")
+        if callable(hook):
+            hook(third_party_base_dir=base_dir, backend=backend)
+    else:
+        print(f'Found third_party {backend.name} at {lib_path}\n')
+
+
+def post_install():
     try:
-        exec(code, globals(), locals())
-    except:  # noqa: E722
+        backend_utils.post_install()
+    except Exception:
         pass
@@ -256,12 +290,12 @@ class CommonUtils:
 
     @staticmethod
     def unlink():
-        cur_path = os.path.dirname(__file__)
+        cur_path = dir_rollback(2, __file__)
         if "editable_wheel" in sys.argv:
             installation_dir = cur_path
         else:
             installation_dir = get_python_lib()
         backends_dir_path = Path(installation_dir) / "triton" / "backends"
         if not os.path.exists(backends_dir_path):
             return
         for name in os.listdir(backends_dir_path):
@@ -279,10 +313,10 @@ def unlink():
     def skip_package_dir(package):
         if 'backends' in package or 'profiler' in package:
             return True
-        if flagtree_backend in ['cambricon']:
-            if package not in ['triton', 'triton/_C']:
-                return True
-        return False
+        try:
+            return backend_utils.skip_package_dir(package)
+        except Exception:
+            return False
 
     @staticmethod
     def get_package_dir(packages):
@@ -296,62 +330,12 @@ def get_package_dir(packages):
             pair = (package, f"{backend_triton_path}{package}")
             connection.append(pair)
         package_dict.update(connection)
-        if flagtree_backend == "ascend":
-            triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch"
-            package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}"
-            package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language"
-            package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
-            package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
+        try:
+            package_dict.update(backend_utils.get_package_dir())
+        except Exception:
+            pass
         return package_dict
 
-    @staticmethod
-    def download_third_party():
-        import git
-        MAX_RETRY = 4
-        global use_triton_shared, flagtree_backend
-        third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party"
-
-        def git_clone(lib, lib_path):
-            global use_triton_shared
-            print(f"Clone {lib.name} into {lib_path} ...")
-            retry_count = MAX_RETRY
-            while (retry_count):
-                try:
-                    repo = git.Repo.clone_from(lib.url, lib_path)
-                    repo.git.checkout(lib.tag)
-                    if lib.name in flagtree_backend_info:
-                        sub_triton_path = Path(lib_path) / "triton"
-                        if os.path.exists(sub_triton_path):
-                            shutil.rmtree(sub_triton_path)
-                    print(f"successfully clone {lib.name} into {lib_path} ...")
-                    return
-                except Exception:
-                    retry_count -= 1
-                    print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")
-
-            print(f"Unable to clone third_party {lib.name}")
-            if lib.name in necessary_third_party:
-                use_triton_shared = False
-                print("\n\ttriton_shared is compiled by default, but for "
-                      "some reason we couldn't download triton_shared\n"
-                      "as third_party (most likely for network reasons), "
-                      "so we couldn't compile triton_shared\n")
-
-        third_partys = []
-        if os.environ.get("USE_TRITON_SHARED", "ON") == "ON" and not flagtree_backend:
-            third_partys.append(flagtree_backend_info["triton_shared"])
-        else:
-            use_triton_shared = False
-        if flagtree_backend in flagtree_backend_info:
-            third_partys.append(flagtree_backend_info[flagtree_backend])
-
-        for lib in third_partys:
-            lib_path = Path(third_party_base_dir) / lib.name
-            if not os.path.exists(lib_path):
-                git_clone(lib=lib, lib_path=lib_path)
-            else:
-                print(f'Found third_party {lib.name} at {lib_path}\n')
-
 
 def handle_flagtree_backend():
     global ext_sourcedir
@@ -360,8 +344,6 @@ def handle_flagtree_backend():
         extend_backends.append(flagtree_backend)
     if "editable_wheel" in sys.argv and flagtree_backend != "ascend":
         ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
-    if use_triton_shared and not flagtree_backend:
-        default_backends.append("triton_shared")
 
 
 def set_env(env_dict: dict):
@@ -373,8 +355,15 @@ def check_env(env_val):
     return os.environ.get(env_val, '') != ''
 
 
-CommonUtils.download_third_party()
+download_flagtree_third_party("triton_shared", condition=(not flagtree_backend))
+
+download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"),
+                              hook=utils.ascend.precompile_hook, required=True)
+
+download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True)
+
 handle_flagtree_backend()
+
 cache = FlagTreeCache()
 
 # iluvatar
diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py
new file mode 100644
index 000000000..75899a929
--- /dev/null
+++ b/python/setup_tools/utils/__init__.py
@@ -0,0 +1,41 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+import importlib.util
+import os
+from . import ascend
+
+
+@dataclass
+class FlagTreeBackend:
+    name: str
+    url: str
+    tag: Optional[str] = None
+
+
+flagtree_backends = (
+    FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
+                    tag="380b87122c88af131530903a702d5318ec59bb33"),
+    FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
+                    tag="00f51c2e48a943922f86f03d58e29f514def646d"),
+    FlagTreeBackend(
+        name="ascend",
+        url="https://gitee.com/ascend/triton-ascend.git",
+    ),
+)
+
+
+def activate(backend, suffix=".py"):
+    if not backend:
+        return None
+    module_path = Path(os.path.dirname(__file__)) / backend
+    module_path = str(module_path) + suffix
+    if not os.path.exists(module_path):
+        return None
+    spec = importlib.util.spec_from_file_location("module", module_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+__all__ = ["ascend"]
diff --git a/third_party/ascend/utils.py b/python/setup_tools/utils/ascend.py
similarity index 59%
rename from third_party/ascend/utils.py
rename to python/setup_tools/utils/ascend.py
index 1b92e492f..bb0184ec6 100644
--- a/third_party/ascend/utils.py
+++ b/python/setup_tools/utils/ascend.py
@@ -1,5 +1,34 @@
 import os
 import shutil
+from pathlib import Path
+
+
+def get_backend_cmake_args(*args, **kwargs):
+    build_ext = kwargs['build_ext']
+    src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt")
+    src_ext_path = os.path.abspath(os.path.dirname(src_ext_path))
+    return [
+        "-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=" + src_ext_path,
+    ]
+
+
+def install_extension(*args, **kwargs):
+    build_ext = kwargs['build_ext']
+    src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt")
+    src_ext_path = os.path.join(os.path.abspath(os.path.dirname(src_ext_path)), "triton-adapter-opt")
+    python_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+    dst_ext_path = os.path.join(python_root_dir, "triton/backends/ascend/triton-adapter-opt")
+    shutil.copy(src_ext_path, dst_ext_path)
+
+
+def get_package_dir():
+    package_dict = {}
+    triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch"
+    package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}"
+    package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language"
+    package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
+    package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
+    return package_dict
 
 
 def insert_at_file_start(filepath, import_lines):
@@ -49,10 +78,10 @@ def append_at_file_end(filepath, import_lines):
 
 def post_install():
     import site
-    install_dir = site.getsitepackages()
+    install_dir = site.getsitepackages()[0]
    install_dir = os.path.join(install_dir, "triton")
     init_path = os.path.join(install_dir, "__init__.py")
-    patched_content = f"""
+    patched_content = """
 import sys
 from .triton_patch.language import _utils as ascend_utils
 sys.modules['triton.language._utils'] = ascend_utils
@@ -69,7 +98,7 @@ def post_install():
 """
     insert_at_file_start(init_path, patched_content)
 
-    content_to_append = f"""
+    content_to_append = """
 from .triton_patch.language.core import dot, gather, insert, subview
 from .triton_patch.language.standard import flip
 from .triton_patch.language.math import umulhi, exp, exp2, log, log2, cos, sin, sqrt, sqrt_rn, rsqrt, div_rn, erf, tanh, floor, ceil
@@ -150,3 +179,56 @@ def get_ascend_patch_package_dir(backends):
     package_dir["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
     package_dir["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
     return package_dir
+
+
+def get_extra_install_packages():
+    return [
+        "triton/triton_patch",
+        "triton/triton_patch/language",
+        "triton/triton_patch/compiler",
+        "triton/triton_patch/runtime",
+    ]
+
+
+def precompile_hook(*args, **kwargs):
+    third_party_base_dir = Path(kwargs['third_party_base_dir'])
+    ascend_path = Path(third_party_base_dir) / "ascend"
+    patch_path = Path(ascend_path) / "triton_patch"
+    project_path = Path(third_party_base_dir) / "triton_ascend"
+    if os.path.exists(ascend_path):
+        shutil.rmtree(ascend_path)
+    if not os.path.exists(project_path):
+        raise RuntimeError(f"{project_path} can't be found. It might be due to a network issue.")
+    ascend_src_path = Path(project_path) / "ascend"
+    patch_src_path = Path(project_path) / "triton_patch"
+    shutil.copytree(ascend_src_path, ascend_path, dirs_exist_ok=True)
+    shutil.copytree(patch_src_path, patch_path, dirs_exist_ok=True)
+    shutil.rmtree(project_path)
+    patched_code = """ set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") """
+    src_code = """set(triton_abs_dir"""
+
+    filepath = Path(patch_path) / "include" / "triton" / "Dialect" / "Triton" / "IR" / "CMakeLists.txt"
+    try:
+        import tempfile
+        with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tmp_file:
+            with open(filepath, 'r') as file:
+                lines = file.readlines()
+                for line in lines:
+                    if src_code in line:
+                        tmp_file.writelines(patched_code)
+                    else:
+                        tmp_file.writelines(line)
+        backup_path = str(filepath) + '.bak'
+        if os.path.exists(backup_path):
+            os.remove(backup_path)
+        shutil.move(filepath, backup_path)
+        shutil.move(tmp_file.name, filepath)
+        print(f"[INFO]: {filepath} is patched")
+        return True
+    except PermissionError:
+        print(f"[ERROR]: No permission to write to {filepath}!")
+    except FileNotFoundError:
+        print(f"[ERROR]: {filepath} does not exist!")
+    except Exception as e:
+        print(f"[ERROR]: Unknown error: {str(e)}")
+    return False
diff --git a/python/setup_tools/utils/cambricon.py b/python/setup_tools/utils/cambricon.py
new file mode 100644
index 000000000..3246fc695
--- /dev/null
+++ b/python/setup_tools/utils/cambricon.py
@@ -0,0 +1,4 @@
+def skip_package_dir(package):
+    if package not in ['triton', 'triton/_C']:
+        return True
+    return False
diff --git a/python/setup_tools/utils/xpu.py b/python/setup_tools/utils/xpu.py
new file mode 100644
index 000000000..92424b1b2
--- /dev/null
+++ b/python/setup_tools/utils/xpu.py
@@ -0,0 +1,2 @@
+def get_package_data_tools():
+    return ["compile_xpu.h", "compile_xpu.c"]
diff --git a/third_party/ascend/.gitignore b/third_party/ascend/.gitignore
deleted file mode 100644
index c557aabf6..000000000
--- a/third_party/ascend/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-triton-adapter-opt
diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt
deleted file mode 100644
index 3321942c7..000000000
--- a/third_party/ascend/CMakeLists.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-add_subdirectory(triton-adapter triton-adapter)
-
-add_triton_plugin(TritonHUAWEI ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cpp)
-
-# Copy triton-adapter-opt to python files
-add_custom_target(COPY_TRITON_ADAPTER_OPT)
-add_custom_command(TARGET COPY_TRITON_ADAPTER_OPT POST_BUILD
-                   COMMAND ${CMAKE_COMMAND} -E copy
-                   $<TARGET_FILE:triton-adapter-opt>
-                   ${TRITON_ROOT_DIR}/python/triton/backends/ascend/triton-adapter-opt
-                   DEPENDS triton-adapter-opt)
-add_dependencies(TritonHUAWEI COPY_TRITON_ADAPTER_OPT)
diff --git a/third_party/ascend/backend/__init__.py b/third_party/ascend/backend/__init__.py
deleted file mode 100644
index 0eec99724..000000000
--- a/third_party/ascend/backend/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
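Note on the refactor above: setup_helper no longer special-cases each backend inline. It resolves FLAGTREE_BACKEND to a module under python/setup_tools/utils/ once at import time, and every hook call degrades to a neutral default when the active backend does not implement it. A minimal, runnable sketch of that dispatch idea follows (illustrative only; the demo_backends directory and the hook shown are assumptions, not FlagTree's real layout):

import importlib.util
from pathlib import Path


def activate(backend, search_dir="demo_backends"):
    # Resolve "<search_dir>/<backend>.py" to an ad-hoc module, or None if absent.
    if not backend:
        return None
    module_path = Path(search_dir) / f"{backend}.py"
    if not module_path.exists():
        return None
    spec = importlib.util.spec_from_file_location(backend, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def get_package_data_tools(backend_utils):
    # Shared defaults, extended only when the backend module defines the hook.
    package_data = ["compile.h", "compile.c"]
    try:
        package_data += backend_utils.get_package_data_tools()
    except (AttributeError, TypeError):  # no backend module, or no such hook
        pass
    return package_data

The design choice here is that adding a new backend means dropping one module into the utils package rather than editing every call site in setup_helper.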
diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py deleted file mode 100644 index dc13b8dad..000000000 --- a/third_party/ascend/backend/compiler.py +++ /dev/null @@ -1,329 +0,0 @@ -from triton.backends.compiler import BaseBackend, GPUTarget -from triton._C.libtriton import ir, passes -from triton.runtime import driver -from triton.runtime.cache import get_dump_manager -from dataclasses import dataclass -import functools -from typing import Any, Union, Tuple, Dict -from types import ModuleType -from pathlib import Path -import tempfile -import os -import subprocess -import hashlib -import ctypes -from typing import Optional - -from triton.backends.ascend.utils import downgrade_llir, _get_llvm_path, _get_mlir_path, _get_triton_adapter_opt_path, \ - _get_kernel_target, _get_npucompiler_path, _is_ascend_sanitizer_enabled - - -# TODO: materialize the concrete min shape -def min_dot_size(target: GPUTarget): - # return lambda lhsType, rhsType: (16, 16, 16) - return lambda lhsType, rhsType: (1, 1, 1) - - -def make_ttir(mod, metadata, opt): - if 'hash' not in metadata: - metadata['hash'] = hashlib.md5(f"{mod}-{metadata}".encode()).hexdigest() - # the same optimize pass for triton-ir as all other backends - pm = ir.pass_manager(mod.context) - pm.enable_debug() - passes.common.add_inliner(pm) - passes.ttir.add_combine(pm) - passes.common.add_canonicalizer(pm) - passes.ttir.add_reorder_broadcast(pm) - passes.common.add_cse(pm) - passes.common.add_licm(pm) - passes.common.add_symbol_dce(pm) - pm.run(mod) - if opt.debug: - dump_manager = get_dump_manager(metadata['hash']) - print(f"Dumping intermediate results to {dump_manager.cache_dir}") - dump_manager.put(str(mod), "kernel.ttir.mlir", binary=False) - - return mod - - -def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): - # use triton_adapter to lower Triton-MLIR to linalg - # Get Triton-MLIR as string - ttir_code = str(mod) - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "kernel.ttir.mlir") - dst_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") - Path(src_path).write_text(ttir_code) - triton_adapter_opt_path = _get_triton_adapter_opt_path() - - cmd_list = [ - triton_adapter_opt_path, src_path, f'--triton-to-linalg=global-kernel=false named-ops={named_ops}', "-o", - dst_path - ] - if _is_ascend_sanitizer_enabled(): - cmd_list += ["--mlir-print-debuginfo"] # pass debug info - - ret = subprocess.run(cmd_list, capture_output=True, check=True) - if opt.debug: - dump_manager = get_dump_manager(metadata['hash']) - dump_manager.put(Path(dst_path).read_text(), "kernel.ttadapter.mlir", binary=False) - - return Path(dst_path).read_text() - - -def linalg_to_llir(linalg: str, metadata, opt): - with tempfile.TemporaryDirectory() as tmpdir: - ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") - llmlir_path = os.path.join(tmpdir, "kernel.llir.mlir") - llir_path = os.path.join(tmpdir, "kernel.ll") - Path(ttadapter_path).write_text(linalg) - mlir_opt_path = _get_mlir_path("bin", "mlir-opt") - # TritonAdapter-MLIR to LLVM-MLIR - subprocess.check_call([ - mlir_opt_path, ttadapter_path, "--convert-linalg-to-affine-loops", "--eliminate-empty-tensors", - "--empty-tensor-to-alloc-tensor", "--one-shot-bufferize=allow-return-allocs-from-loops=true", - "--lower-affine", "--convert-linalg-to-loops", "--convert-scf-to-cf", "--convert-cf-to-llvm", - "--convert-arith-to-llvm", "--convert-math-to-llvm", "--convert-complex-to-llvm", - "--convert-vector-to-llvm", 
"--convert-index-to-llvm", "--memref-expand", "--expand-strided-metadata", - "--finalize-memref-to-llvm", "--convert-func-to-llvm", - # Lowering memrefs creates more affine.apply ops. - # Lowering these affine ops again creates further arith ops, - # so we have to run these two passes again here. - "--lower-affine", "--convert-arith-to-llvm", - # Remove all unrealized casts created - "--reconcile-unrealized-casts", "-o", llmlir_path - ]) - if opt.debug: - dump_manager = get_dump_manager(metadata['hash']) - dump_manager.put(Path(llmlir_path).read_text(), "kernel.llir.mlir", binary=False) - - # LLVM-MLIR to LLVM-IR - mlir_translate_path = _get_mlir_path("bin", "mlir-translate") - subprocess.check_call([mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path]) - if opt.debug: - dump_manager = get_dump_manager(metadata['hash']) - dump_manager.put(Path(llir_path).read_text(), "kernel.ll", binary=False) - - return Path(llir_path).read_text() - - -def llir_to_cpuasm(llir: str, metadata, opt): - # add metadata at final stage - # Note: Compiled Kernel requires to estimate size of shared memory to occupy - # Currently, CPU backend requires no limit on shared memory size - metadata['shared'] = 1 - # We can get a function name (C naming) from - # LLVM-IR by getting the first "define void @". - fn_name = llir.split("define void @")[1].split("(")[0].strip() - metadata['name'] = fn_name + " cpu" - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "kernel.ll") - linked_path = os.path.join(tmpdir, "kernel_linked.ll") - dst_path = os.path.join(tmpdir, "kernel.s") - - llir = downgrade_llir(llir) - if opt.debug: - dump_manager = get_dump_manager(metadata['hash']) - dump_manager.put(llir, "kernel_downgrade.ll", binary=False) - - Path(src_path).write_text(llir) - - linker_path = _get_llvm_path("bin", "llvm-link") - libclc_path = _get_llvm_path("lib", "clc", "libspirv-aarch64--.bc") - subprocess.check_call([linker_path, src_path, libclc_path, "--only-needed", "-S", "-o", linked_path]) - if opt.debug: - dump_manager = get_dump_manager(metadata['hash']) - dump_manager.put(Path(linked_path).read_text(), "kernel_linked.ll", binary=False) - - llc_path = _get_llvm_path("bin", "llc") - subprocess.check_call([llc_path, linked_path, "-o", dst_path]) - if opt.debug: - dump_manager = get_dump_manager(metadata['hash']) - dump_manager.put(Path(dst_path).read_text(), "kernel.s", binary=False) - - # Actually it's text-format assembly. Use read_text(). - return Path(dst_path).read_text() - - -def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): - import re - # Note: Compiled Kernel requires to estimate size of shared memory to occupy - # Currently, NPU backend does not limit on shared memory - metadata['shared'] = 1 - # the mix mode is also encoded into metadata['name'] for runtime to distinguish - metadata['mix_mode'] = re.search(r'mix_mode\s*=\s*"([^"]+)"', linalg).group(1) - metadata['kernel_name'] = re.search(r'func\.func\s+@(\w+)', linalg).group(1) - # Use while space to split kernel_name and mix_mode. - # Check the function load_binary in npu_driver.py. 
- metadata['name'] = metadata['kernel_name'] + " " + metadata['mix_mode'] - # remove the mix_mode attribute - linalg = re.sub(r', mix_mode\s*=\s*"[^"]*"', '', linalg) - with tempfile.TemporaryDirectory() as tmpdir: - ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") - Path(ttadapter_path).write_text(linalg) - bin_file = os.path.join(tmpdir, "kernel") - bin_path = os.path.join(tmpdir, "kernel_reloc.o") - callback_path = os.path.join(tmpdir, "libkernel.so") - multibuffer = metadata['multibuffer'] - _compile_option_list = [ - f"--enable-auto-multi-buffer={multibuffer}", - ] - - if _is_ascend_sanitizer_enabled(): - _compile_option_list += ["--enable-sanitizer=true"] - npu_compiler_path = _get_npucompiler_path() - if (npu_compiler_path.endswith("bishengir-compile")): - _compile_option_list += [ - "--enable-hfusion-compile=true", - "--enable-hivm-compile=true", - "--enable-triton-kernel-compile=true", - ] - cmd_list = [npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file] - ret = subprocess.run(cmd_list, capture_output=True, check=True) - if Path(callback_path).is_file(): - lib = ctypes.CDLL(callback_path) - callback_func = getattr(lib, metadata['kernel_name'] + "_infer_workspace_shape_function") - callback_func.restype = ctypes.c_int64 - callback_func.argtypes = [] - metadata['workspace_size'] = callback_func() - - return Path(bin_path).read_bytes() - - -@dataclass(frozen=True) -class NPUOptions: - debug: bool = False - sanitize_overflow: bool = True - llvm_version: int = 15 - kernel_name: str = "triton_" - - cluster_dims: tuple = (1, 1, 1) - num_warps: int = -1 - num_ctas: int = -1 - num_stages: int = 2 - num_buffers_warp_spec: int = 0 - num_consumer_groups: int = 0 - reg_dec_producer: int = 0 - reg_inc_consumer: int = 0 - - enable_warp_specialization: bool = False - enable_persistent: bool = False - optimize_epilogue: bool = False - enable_fp_fusion: bool = True - allow_fp8e4nv: bool = False - allowed_dot_input_precisions: Tuple[str] = ("ieee", "hf32") - enable_npu_compile: bool = True - max_num_imprecise_acc_default: bool = None - extern_libs: dict = None - - multibuffer: bool = True - - def hash(self): - key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) - return hashlib.md5(key.encode("utf-8")).hexdigest() - - -@dataclass(frozen=True) -class CPUOptions: - debug: bool = False - llvm_version: int = 15 - kernel_name: str = "triton_" - - cluster_dims: tuple = (1, 1, 1) - num_warps: int = -1 - num_ctas: int = -1 - num_stages: int = -1 - - enable_warp_specialization: bool = False - enable_persistent: bool = False - optimize_epilogue: bool = False - enable_fp_fusion: bool = True - allow_fp8e4nv: bool = False - max_num_imprecise_acc_default: bool = None - extern_libs: dict = None - - def hash(self): - key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) - return hashlib.md5(key.encode("utf-8")).hexdigest() - - -class HuaweiBackend(BaseBackend): - - @staticmethod - def supports_target(target: GPUTarget): - return target.backend == 'cpu' or target.backend == 'npu' - - def __init__(self, target: GPUTarget) -> None: - super().__init__(target) - if (target.backend == "cpu"): - self.binary_ext = "cpuasm" - elif (target.backend == "npu"): - self.binary_ext = "npubin" - - def parse_options(self, opts) -> Any: - # TODO: get available targets when building options? 
- if self.target.backend == 'npu': - args = {k: opts[k] for k in NPUOptions.__dataclass_fields__.keys() if k in opts} - options = NPUOptions(**args) - else: - args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} - options = CPUOptions(**args) - return options - - def pack_metadata(self, metadata): - from triton.backends.ascend.utils import TRITON_PROFILER_REGISTERED - # collect necessary metadata to launch kernels - # TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 could set unique name. - # Get this name as the kernel_name to CANN runtime. - # kernel_name is unique to Huawei backend and should not be public. - # CANN runtime limits the length of kernel name <= 50. - # Considering '\n' is appended, thus the real kernel name <= 49. - KERNEL_NAME_MAX_LEN = 49 - kernel_name_orig, mix_mode = metadata.name.split() - if (len(kernel_name_orig) > KERNEL_NAME_MAX_LEN): - kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:] - # import warnings - # # red = "\x1b[31;20m" - # # reset = "\x1b[0m" - # warnings.warn(kernel_name_orig + " is truncated to " + kernel_name) - # warnings.warn("because '" + kernel_name_orig + "' exceeds torchnpu profiler's length limit < 50") - else: - kernel_name = kernel_name_orig - return { - "kernel_name": kernel_name, - "hash": metadata.hash, - "debug": metadata.debug, - "profiler_registered": TRITON_PROFILER_REGISTERED, - } - - def get_codegen_implementation(self): - # Note: a dict of functions is required to generate vendor-specific code piecies - # e.g. convert custom types like fp8e4b15 - codegen_fns = {"min_dot_size": min_dot_size(self.target)} - return codegen_fns - - def load_dialects(self, ctx): - pass - - def add_stages(self, stages, options): - if self.target.backend == 'npu': - stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) - if options.enable_npu_compile: - stages["ttadapter"] = lambda src, metadata: ttir_to_linalg(src, metadata, options, named_ops=True) - stages["npubin"] = lambda src, metadata: linalg_to_bin_enable_npu_compile(src, metadata, options) - else: - pass - else: - stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) - stages["ttadapter"] = lambda src, metadata: ttir_to_linalg(src, metadata, options) - stages["llir"] = lambda src, metadata: linalg_to_llir(src, metadata, options) - stages["cpuasm"] = lambda src, metadata: llir_to_cpuasm(src, metadata, options) - - @functools.lru_cache() - def hash(self): - # TODO fetch compiler version - version_key = self.target - return str(version_key) - - def get_module_map(self) -> Dict[str, ModuleType]: - return {} diff --git a/third_party/ascend/backend/cpu_driver.py b/third_party/ascend/backend/cpu_driver.py deleted file mode 100644 index ec86ec01c..000000000 --- a/third_party/ascend/backend/cpu_driver.py +++ /dev/null @@ -1,185 +0,0 @@ -from triton.runtime.cache import get_cache_manager, get_dump_manager -from pathlib import Path -import tempfile -import os -import sysconfig -import subprocess -import importlib -from triton.backends.ascend.utils import _get_llvm_path - - -# TODO: temporarily fake CPUUtils class -class CPUUtils(object): - - def __new__(cls): - if not hasattr(cls, 'instance'): - cls.instance = super(CPUUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - pass - - def get_device_properties(self, device): - # temperoarily added properties to avoid triton-compiler complain - # fetch available memory at runtime - return {"max_shared_mem": 1} - - def load_binary(self, name, kernel, shared, device): - # TODO 
(temperoarily fake function) load a binary from binary object to device - # return value are: (mod, funcptr/handle, n_regs, n_spills) - return None, kernel, 0, 0 - - -class CPULauncher(object): - - def __init__(self, src, metadata): - kernel_name = metadata.name.split()[0] - signature = src.signature - constants = src.constants - launcher_src = generate_cpu_wrapper_src(constants, signature, kernel_name) - self.launch = compile_module(launcher_src) - - def __call__(self, *args, **kwargs): - self.launch(*args, **kwargs) - - -class CPUDriver: - - def __init__(self): - self.utils = CPUUtils() - self.launcher_cls = CPULauncher - super().__init__() - - def get_current_target(self): - # TODO: do we rely on CPU arch? - return ("cpu", "arm-64") - - def get_current_device(self): - """ - Get current device - """ - # TODO: dummy device-getter for cpu backend - return 0 - - def set_current_device(self, device): - """ - Set current device as the given device - """ - # TODO: dummy device-setter for cpu backend - return - - def get_current_stream(self, device): - """ - Get stream for current device - """ - # TODO: dummy stream api for cpu backend. - return 0 - - -# the template is from triton-adapter HEAD. Wrapping the generated kernel assembly into a python module -def generate_cpu_wrapper_src(constants, signature, kernel_name): - - def _ty_to_cpp(ty): - if ty[0] == '*': - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - def _extracted_ty(ty): - if ty[0] == '*': - return "PyObject*" - return { - 'i1': 'int32_t', - 'i32': 'int32_t', - 'i64': 'int64_t', - 'u32': 'uint32_t', - 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', - 'fp32': 'float', - 'f32': 'float', - 'fp64': 'double', - }[ty] - - def _format_of(ty): - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "uint32_t": "I", - "int32_t": "i", - "uint64_t": "K", - "int64_t": "L", - }[ty] - - def _generate_launcher(constants, signature, kernel_name): - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - format = "iiiOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) - # to be filled - return f""" - """ - - launcher_src = _generate_launcher(constants, signature, kernel_name) - return launcher_src - - -def compile_module(launcher_src): - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() - # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install - # path changes to include 'local'. This change is required to use triton with system-wide python. - if scheme == 'posix_local': - scheme = 'posix_prefix' - py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - - def launch(gridX, gridY, gridZ, stream, cu_function, packed_metadata, launch_metadata, launch_enter_hook, - launch_exit_hook, *args): - # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries. - # Let's compile a kernel every time. 
- kernel_name = packed_metadata["kernel_name"] - cache = get_cache_manager(packed_metadata["hash"]) - filename = f"{kernel_name}_cpu_launcher.so" - cache_path = cache.get_file(filename) - if cache_path is None: - asm_src = cu_function - with tempfile.TemporaryDirectory() as tmpdir: - asm_src_path = os.path.join(tmpdir, "kernel.s") - launcher_src_path = os.path.join(tmpdir, "main.cxx") - if packed_metadata["debug"]: - dump_manager = get_dump_manager(packed_metadata["hash"]) - dump_manager.put(launcher_src, "kernel_cpu_launcher.cxx", binary=False) - so_path = os.path.join(tmpdir, "kernel.so") - Path(asm_src_path).write_bytes(asm_src) - Path(launcher_src_path).write_text(launcher_src) - # Compile it together. - subprocess.check_call([ - _get_llvm_path("bin", "clang++"), launcher_src_path, asm_src_path, f"-I{py_include_dir}", - f"-I{Path(__file__).resolve().parent}", "-shared", "-fPIC", "-o", so_path - ]) - - with open(so_path, "rb") as f: - cache_path = cache.put(f.read(), filename, binary=True) - - # Load and launch the compiled kernel. - spec = importlib.util.spec_from_file_location("__triton_adapter_ref_cpu_kernel_launcher", cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod.launch(gridX, gridY, gridZ, launch_enter_hook, launch_exit_hook, packed_metadata, *args) - - return launch diff --git a/third_party/ascend/backend/device_print.h b/third_party/ascend/backend/device_print.h deleted file mode 100644 index b56910f3a..000000000 --- a/third_party/ascend/backend/device_print.h +++ /dev/null @@ -1,274 +0,0 @@ -#ifndef TRITON_DEVICE_PRINT_H -#define TRITON_DEVICE_PRINT_H - -#include "experiment/runtime/runtime/rt.h" -#include "stdio.h" - -#define LogBufferPaddingBytes 64 -#define BlockMaxSize 16 * 1024 -#define VerifyBorder(nextField, maxBuf) \ - if (nextField > maxBuf) { \ - printf("\nWARNING: out of bound! 
try best to print\n"); \ - return; \ - } -#define __gm__ - -namespace TTAscDebug { - -enum NodeTy { END, NORMAL, FLOAT, INT, CHAR, STRING, POINTER }; - -struct PrintPayloadData { - __gm__ char *LogWholeRegion; - unsigned BlockNum; - size_t LogBufferSize; - PrintPayloadData() - : LogWholeRegion((__gm__ char *)nullptr), LogBufferSize(0), BlockNum(0) {} -}; - -struct DebugTunnelData { - PrintPayloadData PrintData; - DebugTunnelData() {} -}; - -void PrintFormatString(int8_t *&buf, int8_t *maxbuf) { - VerifyBorder((buf + sizeof(short)), maxbuf); - short len = *(short *)buf; - buf += sizeof(len); - VerifyBorder((buf + len), maxbuf); - printf((const char *)buf); - buf += len; -} - -template -void PrintFormatString(int8_t *&buf, int8_t *maxbuf, T param) { - VerifyBorder((buf + sizeof(short)), maxbuf); - short len = *(short *)buf; - buf += sizeof(len); - VerifyBorder((buf + len), maxbuf); - printf((const char *)buf, param); - buf += len; -} - -void AnalyzeSerializedData(int8_t *buf, int logSize, int maxSize) { - int8_t *bufEndAddr = buf + logSize; - int8_t *maxbuf = buf + maxSize; - while (buf < bufEndAddr) { - VerifyBorder((buf + sizeof(int8_t)), maxbuf); - int8_t type = *(int8_t *)buf; - while (type != NodeTy::END) { - buf += sizeof(type); - switch (type) { - default: - break; - case NodeTy::NORMAL: { - PrintFormatString(buf, maxbuf); - break; - } - case NodeTy::FLOAT: { - VerifyBorder((buf + sizeof(float)), maxbuf); - float param = *(float *)buf; - buf += sizeof(param); - PrintFormatString(buf, maxbuf, param); - break; - } - case NodeTy::INT: { - VerifyBorder((buf + sizeof(long long int)), maxbuf); - long long int param = *(long long int *)buf; - buf += sizeof(param); - PrintFormatString(buf, maxbuf, param); - break; - } - case NodeTy::STRING: { - VerifyBorder((buf + sizeof(short)), maxbuf); - short strlen = *(short *)buf; - buf += sizeof(strlen); - VerifyBorder((buf + strlen), maxbuf); - char *param = reinterpret_cast(buf); - buf += strlen; - PrintFormatString(buf, maxbuf, param); - break; - } - case NodeTy::CHAR: { - VerifyBorder((buf + sizeof(char)), maxbuf); - char param = *(char *)buf; - buf += sizeof(param); - PrintFormatString(buf, maxbuf, param); - break; - } - case NodeTy::POINTER: { - VerifyBorder((buf + 8), maxbuf); - void *param = *(void **)buf; - buf += sizeof(param); - PrintFormatString(buf, maxbuf, param); - break; - } - } - VerifyBorder((buf + sizeof(int8_t)), maxbuf); - type = *(int8_t *)buf; - } - buf += 1; - } -} - -void OnHostInitialize(PrintPayloadData *PrintData, unsigned BlockNum) { - PrintData->LogBufferSize = BlockMaxSize; - PrintData->BlockNum = BlockNum; - int WholeSize = - (PrintData->LogBufferSize + LogBufferPaddingBytes) * PrintData->BlockNum; - - void *Hbm_PrintPayloadData_start_addr = NULL; - // Not sure how to use the module_id param of rtMalloc - uint16_t ModuleId = 0; - rtError_t error = - rtMalloc(reinterpret_cast(&Hbm_PrintPayloadData_start_addr), - WholeSize, RT_MEMORY_HBM, ModuleId); - if (error != RT_ERROR_NONE) { - printf("ERROR:The memory for the printing function on the device side " - "fails to be allocated."); - printf("As a result, the printing function fails!\n"); - return; - } - PrintData->LogWholeRegion = (__gm__ char *)Hbm_PrintPayloadData_start_addr; -} - -void OnHostFinish(PrintPayloadData *PrintData, rtStream_t Stream) { - if (!PrintData->LogWholeRegion) { - return; - } - std::size_t WholeSize = - (PrintData->LogBufferSize + LogBufferPaddingBytes) * PrintData->BlockNum; - char *hostMemOut2; - // Not sure how to use the module_id param of 
rtMalloc - uint16_t ModuleId = 0; - rtError_t error = rtMallocHost(reinterpret_cast(&hostMemOut2), - WholeSize, ModuleId); - if (error != RT_ERROR_NONE) { - printf("ERROR:The memory for the printing function on the device side " - "fails to be allocated."); - printf("As a result, the printing function fails!\n"); - return; - } - error = rtMemcpyAsync(hostMemOut2, WholeSize, PrintData->LogWholeRegion, - WholeSize, RT_MEMCPY_DEVICE_TO_HOST, Stream); - if (error != RT_ERROR_NONE) { - printf("ERROR: The memory copy of the device print on fails,"); - printf("and the printing function is invalid!\n"); - return; - } - error = rtStreamSynchronize(Stream); - if (error != RT_ERROR_NONE) { - printf("ERROR: Synchronous waiting for the device print failed.\n"); - printf("The printing function is invalid!\n"); - return; - } - char *outRaw2 = static_cast(hostMemOut2); - const char *Line = "-------------------------------------------------------"; - // Precheck if any print data is ready - for (int B = 0; B < PrintData->BlockNum; B++) { - char *Log = - (outRaw2 + (PrintData->LogBufferSize + LogBufferPaddingBytes) * B); - size_t LogSize = *reinterpret_cast(Log); - if (LogSize > 0 && LogSize <= PrintData->LogBufferSize) { - printf("LogBufferSize of each core is : %zu Bytes\n", - PrintData->LogBufferSize); - printf("%s\n", Line); - printf("----------------------HiIPU " - "Print----------------------\n"); - printf("%s\n", Line); - break; - } - } - - for (int B = 0; B < PrintData->BlockNum; B++) { - char *Log = - (outRaw2 + (PrintData->LogBufferSize + LogBufferPaddingBytes) * B); - size_t LogSize = *reinterpret_cast(Log); - if (LogSize < 0 || LogSize > PrintData->LogBufferSize) { - printf(" LOG SIZE ERROR !!! \n"); - printf(" log size needed = %zu ", LogSize); - printf(" , buf size = %zu\n", PrintData->LogBufferSize); - LogSize = PrintData->LogBufferSize; - continue; - } - if (LogSize == 0) { - continue; - } - printf("==> Block %d, LogSize = %zu Bytes\n", B, LogSize); - int8_t *Buf = reinterpret_cast(Log + LogBufferPaddingBytes); - AnalyzeSerializedData(Buf, LogSize, PrintData->LogBufferSize); - printf("\n"); - printf("%s\n", Line); - } - error = rtFree(PrintData->LogWholeRegion); - if (error != RT_ERROR_NONE) { - printf("ERROR: The memory free of the device print fails\n"); - return; - } - error = rtFreeHost(hostMemOut2); - if (error != RT_ERROR_NONE) { - printf("ERROR: The memory free of the device print fails\n"); - return; - } -} - -DebugTunnelData *Open(unsigned BlockNum) { - DebugTunnelData debugTunnelDataForHost; - OnHostInitialize(&(debugTunnelDataForHost.PrintData), BlockNum); - void *Hbm_PrintPayloadData_start_addr = NULL; - // Not sure how to use the module_id param of rtMalloc - uint16_t ModuleId = 0; - rtError_t error = - rtMalloc(reinterpret_cast(&Hbm_PrintPayloadData_start_addr), - sizeof(debugTunnelDataForHost), RT_MEMORY_HBM, ModuleId); - if (error != RT_ERROR_NONE) { - printf("ERROR: The memory for the printing function on the device side " - "fails to be allocated."); - printf("As a result, the printing function fails!\n"); - return nullptr; - } - if (Hbm_PrintPayloadData_start_addr == nullptr) { - printf("WARNING: failed to allocate DebugTunnelData memory\n"); - return nullptr; - } - error = rtMemcpy(Hbm_PrintPayloadData_start_addr, - sizeof(debugTunnelDataForHost), &debugTunnelDataForHost, - sizeof(debugTunnelDataForHost), RT_MEMCPY_HOST_TO_DEVICE); - if (error != RT_ERROR_NONE) { - printf("ERROR: The memory copy of the device print on fails, "); - printf("and the printing function is 
invalid!\n"); - return nullptr; - } - return reinterpret_cast(Hbm_PrintPayloadData_start_addr); -} - -void Close(DebugTunnelData *DTData, rtStream_t Stream) { - if (!DTData) { - return; - } - DebugTunnelData debugTunnelDataForHost; - rtError_t error = rtStreamSynchronize(Stream); - if (error != RT_ERROR_NONE) { - printf("ERROR: Synchronous waiting for the device print failed.\n"); - printf("The printing function is invalid!\n"); - } - error = - rtMemcpy(&debugTunnelDataForHost, sizeof(debugTunnelDataForHost), DTData, - sizeof(debugTunnelDataForHost), RT_MEMCPY_DEVICE_TO_HOST); - if (error != RT_ERROR_NONE) { - printf("ERROR: The memory copy of the device print on fails, "); - printf("and the printing function is invalid!\n"); - return; - } - OnHostFinish(&(debugTunnelDataForHost.PrintData), Stream); - - error = rtFree(DTData); - if (error != RT_ERROR_NONE) { - printf("ERROR: The memory free of the device print fails, "); - printf("and the device print is invalid!\n"); - return; - } -} - -} // namespace TTAscDebug - -#endif diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py deleted file mode 100644 index f00a1e3f2..000000000 --- a/third_party/ascend/backend/driver.py +++ /dev/null @@ -1,504 +0,0 @@ -from pathlib import Path -import tempfile -import os -import subprocess -import sysconfig -from typing import Optional -import functools -import hashlib -from triton.runtime.cache import get_cache_manager, get_dump_manager -from triton.backends.driver import DriverBase -from triton.backends.compiler import GPUTarget -from triton.backends.ascend.utils import _build_npu_ext, _check_cxx11_abi - - -class NPUUtils(object): - - def __new__(cls): - if not hasattr(cls, 'instance'): - cls.instance = super(NPUUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - src = Path(os.path.join(dirname, "npu_utils.cpp")).read_text() - key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - fname = "npu_utils.so" - cache_path = cache.get_file(fname) - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "npu_utils.cpp") - with open(src_path, "w") as f: - f.write(src) - so = _build_npu_ext("npu_utils", src_path, tmpdir) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), fname, binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location("npu_utils", cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - self.npu_utils_mod = mod - - def load_binary(self, name, kernel, shared, device): - fnname, mix_mode = name.split() - return self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode) - - @functools.lru_cache() - def get_device_properties(self, device): - # temperoarily added "max_shared_mem" properties to avoid triton-compiler complain - # fetch available memory at runtime - num_aic = self.get_aicore_num() - num_aiv = num_aic * 2 - return {"max_shared_mem": 1, "num_aicore": num_aic, "num_vectorcore": num_aiv} - - @functools.lru_cache() - def get_arch(self): - # temporarily return empty arch descriptor - return self.npu_utils_mod.get_arch() - - @functools.lru_cache() - def get_aicore_num(self): - # temporarily return empty arch descriptor - return self.npu_utils_mod.get_aicore_num() - - -class NPULauncher(object): - - def __init__(self, src, metadata): - debug_mode = metadata.debug - workspace_size = int(metadata.workspace_size) \ - 
if hasattr(metadata, 'workspace_size') else -1 - constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} - wrapper_src = generate_npu_wrapper_src(constants, signature, \ - workspace_size) - so_launcher_path = make_npu_launcher_stub(wrapper_src, debug_mode) - # initialize launcher - import importlib.util - spec = importlib.util.spec_from_file_location("__triton_launcher", so_launcher_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - self.launch = getattr(mod, "launch") - - def __call__(self, *args, **kwargs): - profiler_registered = self.launch(*args, **kwargs) - import triton - triton.backends.ascend.utils.TRITON_PROFILER_REGISTERED = True if profiler_registered == 1 else False - - -class NPUDriver(DriverBase): - - def __init__(self): - self.utils = NPUUtils() - self.launcher_cls = NPULauncher - super().__init__() - - @classmethod - def is_active(cls): - - def test_npucompiler(): - from triton.backends.ascend.utils import _get_bisheng_path - npucompiler = _get_bisheng_path() - targets = subprocess.check_output([npucompiler, "-print-targets"]).decode().strip().split() - return "hiipu64" in targets - - try: - return test_npucompiler() - except Exception as e_npucompiler: - import warnings - red = "\x1b[31;20m" - reset = "\x1b[0m" - warnings.warn(red + str(e_npucompiler) + reset) - return False - - def get_current_target(self): - backend = "npu" - arch = self.utils.get_arch() - warp_size = 0 - return GPUTarget(backend, arch, warp_size) - - def get_current_device(self): - """ - Get current device - """ - import torch - import torch_npu - return torch.npu.current_device() - - def set_current_device(self, device): - """ - Set current device as the given device - """ - import torch - import torch_npu - return torch.npu.set_device(device) - - def get_current_stream(self, device: Optional[int] = None) -> int: - """ - Get stream for current device - """ - # According to torch_npu, the content of a torch.npu.Stream is essentilly an rtStream_t - # TODO: use CANN API instead of torchnpu - import torch - import torch_npu - if device is None: - device = self.get_current_device() - return torch.npu.current_stream(device).npu_stream - - def get_benchmarker(self): - from triton.testing import do_bench - return do_bench - - def get_device_interface(self): - import torch - return torch.npu - - def get_empty_cache_for_benchmark(self): - import torch - cache_size = 192 * 1024 * 1024 - return torch.empty(cache_size // 4, dtype=torch.int, device='npu') - - -def make_npu_launcher_stub(src, debug=False): - """ - Generate the launcher stub to launch the kernel - """ - # try to get cached file - so_cache_key = hashlib.sha256(src.encode("utf-8")).hexdigest() - so_cache_manager = get_cache_manager(so_cache_key) - # append the cxx11_abi value to the launcher name to avoid - # linking to a launcher with wrong cxx11_abi. 
- use_cxx11_abi = _check_cxx11_abi() - name = f"launcher_cxx11abi{use_cxx11_abi}" - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so_name = f"{name}{suffix}" - - if debug: - dump_manager = get_dump_manager(so_cache_key) - print(f"Dumping {name}.cxx to {dump_manager.cache_dir}") - dump_manager.put(src, f"{name}.cxx", binary=False) - - cache_path = so_cache_manager.get_file(so_name) - if cache_path is not None: - return cache_path - - with tempfile.TemporaryDirectory() as tmpdir: - if debug: - so_cache_manager.put(src, f"{name}.cxx", binary=False) - src_path = os.path.join(tmpdir, f"{name}.cxx") - with open(src_path, "w") as f: - f.write(src) - so = _build_npu_ext(name, src_path, tmpdir, kernel_launcher="torch") - if debug: - with open(so, "rb") as f: - return dump_manager.put(f.read(), so_name, binary=True) - with open(so, "rb") as f: - return so_cache_manager.put(f.read(), so_name, binary=True) - - -# the template is from triton-adapter HEAD. Wrapping the generated kernel binary into a python module -def generate_npu_wrapper_src(constants, signature, workspace_size): - import os - - def _ty_to_cpp(ty): - if ty[0] == '*': - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - def _extracted_ty(ty): - if ty[0] == '*': - return "PyObject*" - return { - 'i1': 'int32_t', - 'i32': 'int32_t', - 'i64': 'int64_t', - 'u32': 'uint32_t', - 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', - 'fp32': 'float', - 'f32': 'float', - 'fp64': 'double', - }[ty] - - def _format_of(ty): - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "uint32_t": "I", - "int32_t": "i", - "uint64_t": "K", - "int64_t": "L", - }[ty] - - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - """ - args: - int gridX, gridY, gridZ; - rtStream_t stream; - const void *functon; - PyObject* packed_metadata, *launch_metadata; - PyObject* launch_enter_hook, *launch_exit_hook; - *args_expand - """ - format = "iiiKKOOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) - - grid_info = {'X': 'i32', 'Y': 'i32', 'Z': 'i32'} - - enable_device_print = os.getenv("TRITON_DEVICE_PRINT", 'false').lower() in ('true', '1') - - return f""" -#include -#include -#include -#include - -#define PY_SSIZE_T_CLEAN -#include -#include -#include "experiment/runtime/runtime/rt.h" -{'#include "device_print.h"' if enable_device_print else ''} - -extern "C" {{ - - typedef int (* callback)(unsigned int type, void* data, unsigned int len); - extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api); - extern unsigned long int MsprofSysCycleTime(); - extern int MsprofRegisterCallback(unsigned int moduleId, callback handle); - static unsigned int __MsprofFlagL0 = 0; - static unsigned int __MsprofFlagL1 = 0; - - int ProfCtrlHandle(unsigned int CtrlType, void* CtrlData, unsigned int DataLen) {{ - if ((CtrlData == nullptr) || (DataLen == 0U)) {{ - return 1; - }} - - if (CtrlType == 1) {{ - MsprofCommandHandle* handle = (MsprofCommandHandle *)(CtrlData); - if (handle->type >= 6) // 6 is not used here - return 1; - if (handle->type == 1) {{ // init - 0 , start - 1 - __MsprofFlagL0 = ((0x00000800ULL & handle->profSwitch) == 0x00000800ULL) ? 1 : 0; - __MsprofFlagL1 = ((0x00000002ULL & handle->profSwitch) == 0x00000002ULL) ? 
1 : 0; - }} - }} - return 0; - }} -}} - -typedef struct _DevicePtrInfo {{ - void *dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); - if(!ptr_info.dev_ptr) - return ptr_info; - Py_DECREF(ret); // Thanks ChatGPT! - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - return ptr_info; -}} - -static void _launch(const char* kernelName, const void* func, rtStream_t stream, int gridX, int gridY, int gridZ, int *profilerRegistered, {arg_decls}) {{ - // only 1D parallelization is supported for NPU - // Pointer type becomes flattend 1-D Memref tuple: base_ptr, data_ptr, offset, shape, stride - // base_ptr offset shape and stride are not used, arbitrarily set for now - std::string name = ""; - name.append(kernelName); - if (!(*profilerRegistered)) {{ - MsprofRegisterCallback(8, ProfCtrlHandle); // 8 - CCE defined in msprof headerfile slog.h - *profilerRegistered = 1; - }} - auto launch_call = [=]() {{ - uint32_t blockNum = gridX * gridY * gridZ; - {'TTAscDebug::DebugTunnelData *DTData = TTAscDebug::Open(blockNum);' if enable_device_print else ''} - rtError_t ret; - void *ffts_addr = NULL; - uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len); - if (ret != RT_ERROR_NONE) {{ - return ret; - }} - // stub argument for workspace - void *workspace_addr = NULL; - {f''' - uint16_t ModuleId = 0; - uint64_t totalWorkSpaceSize = {workspace_size} * blockNum; - ret = rtMalloc(reinterpret_cast(&workspace_addr), - totalWorkSpaceSize, RT_MEMORY_HBM, ModuleId); - if (ret != RT_ERROR_NONE) {{ - return ret; - }} - ''' if workspace_size > 0 else ''} - struct __attribute__((packed)) {{ - void* ffts_addr __attribute__((aligned(8))); - void* workspace_addr __attribute__((aligned(8))); - {' '.join(f'{_ty_to_cpp(ty)} arg{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8})));' for i, ty in signature.items() if i not in constants)} - {' '.join(f'{_ty_to_cpp(ty)} grid{mark} __attribute__((aligned(4)));' for mark, ty in grid_info.items())} - {'void* DTData __attribute__((aligned(8)));' if enable_device_print else ''} - }} args = {{ - static_cast(ffts_addr), - static_cast(workspace_addr), - {', '.join(f'static_cast<{_ty_to_cpp(ty)}>(arg{i})' for i, ty in signature.items() if i not in constants)}, - {', '.join(f'static_cast<{_ty_to_cpp(ty)}>(grid{mark})' for mark, ty in grid_info.items())} - {', static_cast(DTData)' if enable_device_print else ''} - }}; - unsigned long int beginTime = 0; - unsigned long int endTime = 0; - unsigned long int opName = 0; - unsigned int threadId = 0; - char* kernelName = const_cast(name.c_str()); - size_t length = name.length(); - // FIXME: to avoid bug in msprof, currently we disable these checks - // if 
-    unsigned long int beginTime = 0;
-    unsigned long int endTime = 0;
-    unsigned long int opName = 0;
-    unsigned int threadId = 0;
-    char* kernelName = const_cast<char *>(name.c_str());
-    size_t length = name.length();
-    // FIXME: to avoid a bug in msprof, these checks are currently disabled
-    // if (__MsprofFlagL0 || __MsprofFlagL1) {{
-    {{
-      beginTime = MsprofSysCycleTime();
-    }}
-    ret = rtKernelLaunch(func, blockNum, static_cast<void *>(&args), sizeof(args), NULL, stream);
-    {'TTAscDebug::Close(DTData, stream);' if enable_device_print else ''}
-    // FIXME: to avoid a bug in msprof, these checks are currently disabled
-    // if (__MsprofFlagL0 || __MsprofFlagL1) {{
-    {{
-      endTime = MsprofSysCycleTime();
-      opName = MsprofGetHashId(kernelName, length);
-      threadId = (unsigned int)(syscall(SYS_gettid));
-      MsprofApi info;
-      info.magicNumber = 0x5a5a;  // MSPROF_REPORT_DATA_MAGIC_NUM
-      info.level = 10000;  // MSPROF_REPORT_NODE_LEVEL
-      info.type = 5;  // MSPROF_REPORT_NODE_LAUNCH_TYPE
-      info.threadId = threadId;
-      info.reserve = 0;
-      info.beginTime = beginTime;
-      info.endTime = endTime;
-      info.itemId = opName;
-      MsprofReportApi(0, &info);
-    }}
-    // FIXME: to avoid a bug in msprof, this check is currently disabled
-    // if (__MsprofFlagL1) {{
-    {{
-      MsprofCompactInfo nodeBasicInfo;
-      nodeBasicInfo.magicNumber = 0x5a5a;  // MSPROF_REPORT_DATA_MAGIC_NUM
-      nodeBasicInfo.level = 10000;  // MSPROF_REPORT_NODE_LEVEL
-      nodeBasicInfo.type = 0;  // MSPROF_REPORT_NODE_BASIC_INFO_TYPE
-      nodeBasicInfo.threadId = threadId;
-      nodeBasicInfo.timeStamp = endTime;
-      nodeBasicInfo.data.nodeBasicInfo.opName = opName;
-      nodeBasicInfo.data.nodeBasicInfo.taskType = 0;  // MSPROF_GE_TASK_TYPE_AI_CORE
-      nodeBasicInfo.data.nodeBasicInfo.opType = opName;
-      nodeBasicInfo.data.nodeBasicInfo.blockDim = gridX;
-      MsprofReportCompactInfo(0, &nodeBasicInfo, sizeof(MsprofCompactInfo));
-    }}
-    return ret;
-  }};
-  at_npu::native::OpCommand cmd;
-  cmd.Name(name.c_str())
-      .SetCustomHandler(launch_call)
-      .Run();
-}}
-
-static PyObject* launch(PyObject* self, PyObject* args) {{
-  int gridX, gridY, gridZ;
-  rtStream_t stream;
-  const void *function;
-  PyObject *packedMetadata = NULL;
-  PyObject *launch_metadata = NULL;
-  PyObject *launch_enter_hook = NULL;
-  PyObject *launch_exit_hook = NULL;
-  {' '.join([f"{_extracted_ty(ty)} _arg{i};" for i, ty in signature.items()])}
-  if (!PyArg_ParseTuple(
-        args, \"{format}\",
-        &gridX, &gridY, &gridZ, &stream, &function,
-        &packedMetadata, &launch_metadata,
-        &launch_enter_hook, &launch_exit_hook
-        {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}
-      )
-    ) {{
-    return NULL;
-  }}
-
-  if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{
-    return NULL;
-  }}
-
-  // get kernel_name
-  PyObject *kernelNameObj = PyDict_GetItemString(packedMetadata, "kernel_name");
-  const char *kernelName = PyUnicode_AsUTF8(kernelNameObj);
-  PyObject *profilerRegisteredObj = PyDict_GetItemString(packedMetadata, "profiler_registered");
-  int profilerRegistered = PyObject_IsTrue(profilerRegisteredObj);
-  // raise exception asap
-  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  _launch(kernelName, function, stream, gridX, gridY, gridZ, &profilerRegistered, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0] == "*" else f"_arg{i}" for i, ty in signature.items())});
-  if (PyErr_Occurred()) {{
-    return NULL;
-  }}
-  if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{
-    return NULL;
-  }}
-
-  return Py_BuildValue("I", profilerRegistered);
-}}
-
-static PyMethodDef ModuleMethods[] = {{
-  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
-  {{NULL, NULL, 0, NULL}}  // sentinel
-}};
-
-static struct PyModuleDef ModuleDef = {{
-  PyModuleDef_HEAD_INIT,
-  \"__triton_launcher\",
-  NULL,  // documentation
-  -1,  // size
-  ModuleMethods
-}};
-
-PyMODINIT_FUNC PyInit___triton_launcher(void) {{
-  PyObject *m = PyModule_Create(&ModuleDef);
-  if (m == NULL) {{
-    return NULL;
-  }}
-  PyModule_AddFunctions(m, ModuleMethods);
-  return m;
-}}
-"""
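
End to end, the driver writes this generated source to a temporary directory, compiles it with `_build_npu_ext`, and imports the resulting `__triton_launcher` extension. A hedged sketch of that last step; the kernel name, grid, and argument values are placeholders, and `stream`/`func_ptr` would come from torch_npu and `npu_utils.load_kernel_binary` in practice:

```python
import importlib.util

def load_and_launch(so_path, stream, func_ptr, x_tensor, n_elements):
    # Import the freshly built launcher extension from its .so path
    spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
    launcher = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(launcher)
    metadata = {"kernel_name": "vector_add", "profiler_registered": False}
    # launch(gridX, gridY, gridZ, stream, func, packed_metadata,
    #        launch_metadata, enter_hook, exit_hook, *kernel_args)
    # returns the possibly updated profiler_registered flag
    return launcher.launch(32, 1, 1, stream, func_ptr, metadata,
                           None, None, None, x_tensor, n_elements)
```
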
diff --git a/third_party/ascend/backend/name.conf b/third_party/ascend/backend/name.conf
deleted file mode 100644
index 3fd20dbae..000000000
--- a/third_party/ascend/backend/name.conf
+++ /dev/null
@@ -1 +0,0 @@
-huawei
diff --git a/third_party/ascend/backend/npu_utils.cpp b/third_party/ascend/backend/npu_utils.cpp
deleted file mode 100644
index bcfc61c50..000000000
--- a/third_party/ascend/backend/npu_utils.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-#define PY_SSIZE_T_CLEAN
-#include
-
-#include <memory>
-#include <string>
-#include <tuple>
-#include <unordered_map>
-
-#include "experiment/runtime/runtime/rt.h"
-
-// Use a map to differentiate same-name functions from different binaries
-static std::unordered_map<std::string, size_t> registered_names;
-static std::unordered_map<std::string, std::unique_ptr<size_t>> func_stubs;
-
-static std::tuple<void *, void *>
-registerKernel(const char *name, const void *data, size_t data_size, int shared,
-               int device, const char *kernel_mode_str) {
-  rtError_t rtRet;
-
-  rtDevBinary_t devbin;
-  devbin.data = data;
-  devbin.length = data_size;
-  const std::string kernel_mode{kernel_mode_str};
-  if (kernel_mode == "aiv")
-    devbin.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
-  else
-    devbin.magic = RT_DEV_BINARY_MAGIC_ELF;
-  devbin.version = 0;
-
-  rtRet = rtSetDevice(device);
-  if (rtRet != RT_ERROR_NONE) {
-    printf("rtSetDevice failed, 0x%x\n", rtRet);
-    return {NULL, NULL};
-  }
-
-  void *devbinHandle = NULL;
-  rtRet = rtDevBinaryRegister(&devbin, &devbinHandle);
-  if (rtRet != RT_ERROR_NONE) {
-    printf("rtDevBinaryRegister failed, 0x%x\n", rtRet);
-    return {NULL, NULL};
-  }
-
-  std::string stubName = name;
-  stubName += "_" + std::to_string(registered_names[name]);
-  registered_names[name]++;
-  auto registered = func_stubs.emplace(stubName, std::make_unique<size_t>(0));
-  void *func_stub_handle = registered.first->second.get();
-  rtRet = rtFunctionRegister(devbinHandle, func_stub_handle, stubName.c_str(),
-                             (void *)name, 0);
-  if (rtRet != RT_ERROR_NONE) {
-    printf("rtFunctionRegister failed (stubName = %s), 0x%x\n", stubName.c_str(),
-           rtRet);
-    exit(1);
-    return {NULL, NULL};
-  }
-
-  return std::make_tuple(devbinHandle, func_stub_handle);
-}
-
-static PyObject *loadKernelBinary(PyObject *self, PyObject *args) {
-  const char *name;        // kernel name
-  const char *data;        // binary pointer
-  Py_ssize_t data_size;    // binary size
-  int shared;              // shared memory (meaningless now)
-  int device;              // device ID
-  const char *kernel_mode; // kernel mode
-
-  if (!PyArg_ParseTuple(args, "ss#iis", &name, &data, &data_size, &shared,
-                        &device, &kernel_mode)) {
-    return NULL;
-  }
-
-  auto [module_handle, func_handle] =
-      registerKernel(name, data, data_size, shared, device, kernel_mode);
-
-  uint64_t mod = reinterpret_cast<uint64_t>(module_handle);
-  uint64_t func = reinterpret_cast<uint64_t>(func_handle);
-  if (PyErr_Occurred()) {
-    return NULL;
-  }
-
-  return Py_BuildValue("(KKii)", mod, func, 0, 0);
-}
-
-static PyObject *getArch(PyObject *self, PyObject *args) {
-  char name[64] = {'\0'};
-
-  rtError_t rtRet = rtGetSocVersion(name, 64);
-
-  if (rtRet != RT_ERROR_NONE) {
-    printf("rtGetSocVersion failed, 0x%x\n", rtRet);
-    return NULL;
-  }
-  if (PyErr_Occurred()) {
-    return NULL;
-  }
-  return Py_BuildValue("s", name);
-}
-
-static PyObject *getAiCoreNum(PyObject *self, PyObject *args) {
-  uint32_t aiCoreCnt;
-
-  rtError_t rtRet = rtGetAiCoreCount(&aiCoreCnt);
-
-  if (rtRet != RT_ERROR_NONE) {
-    printf("rtGetAiCoreCount failed, 0x%x\n", rtRet);
-    return NULL;
-  }
-  if (PyErr_Occurred()) {
-    return NULL;
-  }
-  return Py_BuildValue("I", aiCoreCnt);
-}
-
-static PyMethodDef NpuUtilsMethods[] = {
-    {"load_kernel_binary", loadKernelBinary, METH_VARARGS,
-     "Load NPU kernel binary into NPU driver"},
-    {"get_arch", getArch, METH_VARARGS, "Get soc version of NPU"},
-    {"get_aicore_num", getAiCoreNum, METH_VARARGS, "Get the number of AI cores"},
-    // sentinel
-    {NULL, NULL, 0, NULL}};
-
-static PyModuleDef ModuleDef = {
-    PyModuleDef_HEAD_INIT, "npu_utils",
-    "Utilities for fetching NPU device info and preparing kernel binary", -1,
-    NpuUtilsMethods};
-
-PyMODINIT_FUNC PyInit_npu_utils(void) {
-  PyObject *m = PyModule_Create(&ModuleDef);
-  if (m == NULL) {
-    return NULL;
-  }
-
-  PyModule_AddFunctions(m, NpuUtilsMethods);
-  return m;
-}
diff --git a/third_party/ascend/backend/utils.py b/third_party/ascend/backend/utils.py
deleted file mode 100644
index 826ae3a8c..000000000
--- a/third_party/ascend/backend/utils.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
-import re
-import os
-from pathlib import Path
-import functools
-import sysconfig
-import shutil
-import subprocess
-
-TRITON_PROFILER_REGISTERED = False
-
-
-def downgrade_llir(llir):
-    llir = _downgrade_mem_attrs(llir)
-    llir = _downgrade_stacksaverestore_intrinsics(llir)
-    return llir
-
-
-def _downgrade_mem_attrs(llir: str):
-    memory_pattern = r"memory\([^()]*\)"
-
-    def replace_mem_attr(m):
-        attrs = m[0][7:-1].split(",")
-        if len(attrs) == 0:
-            return "readnone"
-        loc_map = {"argmem": 1, "inaccessiblemem": 2, "other": 4}
-        loc_attr = 0
-        rw_map = {"readwrite": 3, "write": 2, "read": 1, "none": 0}
-        rw_attr = 0
-        for attr_pair in attrs:
-            pair = attr_pair.split(":")
-            assert len(pair) <= 2
-            if len(pair) == 1:
-                rw = rw_map[pair[0].strip()]
-                loc = loc_map["other"]  # all locations
-            else:
-                rw = rw_map[pair[1].strip()]
-                loc_str = pair[0].strip()
-                if loc_str == "argmem" or loc_str == "inaccessiblemem":
-                    loc = loc_map[loc_str]
-                else:
-                    loc = loc_map["other"]
-            if rw > 0:
-                loc_attr = loc_attr | loc
-                rw_attr = rw_attr | rw
-        rev_rw_map = {0: "readnone", 1: "readonly", 2: "writeonly"}
-        if rw_attr in rev_rw_map:
-            rw_attr_str = rev_rw_map[rw_attr]
-        else:
-            rw_attr_str = ""
-        rev_loc_map = {1: "argmemonly", 2: "inaccessiblememonly", 3: "inaccessiblemem_or_argmemonly"}
-        if loc_attr in rev_loc_map:
-            loc_attr_str = rev_loc_map[loc_attr]
-        else:
-            loc_attr_str = ""
-        return rw_attr_str + " " + loc_attr_str
-
-    return re.sub(memory_pattern, replace_mem_attr, llir)
-
-
-def _downgrade_stacksaverestore_intrinsics(llir: str):
-    llir = re.sub(r"llvm\.stacksave\.\w+", "llvm.stacksave", llir)
-    llir = re.sub(r"llvm\.stackrestore\.\w+", "llvm.stackrestore", llir)
-    return llir
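
These two rewrites retarget LLVM IR emitted by a newer LLVM at the older LLVM used by the NPU compiler: `memory(...)` attributes are folded back into the legacy `readnone`/`readonly`/`argmemonly` family, and the address-space-suffixed `llvm.stacksave.*`/`llvm.stackrestore.*` intrinsic names lose their suffix. A worked example of the attribute folding (the input IR line is illustrative):

```python
# read(1) | write(2) -> readwrite(3): no legacy rw spelling exists, so it is dropped;
# argmem(1) | inaccessiblemem(2) -> 3 -> "inaccessiblemem_or_argmemonly"
line = "declare void @f(ptr) memory(argmem: read, inaccessiblemem: write)"
print(_downgrade_mem_attrs(line))
# -> "declare void @f(ptr)  inaccessiblemem_or_argmemonly"
```
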
-
-
-def _get_triton_adapter_opt_path() -> str:
-    path = os.path.dirname(__file__)
-    path = os.path.join(path, "triton-adapter-opt")
-    return path
-
-
-def _get_mlir_path(path: str, *paths) -> str:
-    root_path = os.getenv("MLIR_ROOT", "")
-    if root_path == "":
-        raise EnvironmentError("MLIR_ROOT is not set.")
-    return os.path.join(root_path, path, *paths)
-
-
-def _get_llvm_path(path: str, *paths) -> str:
-    root_path = os.getenv("LLVM_ROOT", "")
-    if root_path == "":
-        raise EnvironmentError("LLVM_ROOT is not set.")
-    return os.path.join(root_path, path, *paths)
-
-
-def _get_npucompiler_path() -> str:
-    npu_compiler_path = shutil.which("bishengir-compile")
-    if npu_compiler_path is None:
-        # os.getenv never returns None here because of the "" default,
-        # so test for the empty string instead
-        npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "")
-        if not npu_compiler_root:
-            raise EnvironmentError("Couldn't find executable bishengir-compile or TRITON_NPU_COMPILER_PATH.")
-        npu_compiler_path = os.path.join(npu_compiler_root, "npuc")
-    return npu_compiler_path
-
-
-def _get_bisheng_path() -> str:
-    bisheng_path = shutil.which("bisheng")
-    if bisheng_path is None:
-        npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "")
-        if not npu_compiler_root:
-            raise EnvironmentError("Couldn't find executable bisheng or TRITON_NPU_COMPILER_PATH")
-        bisheng_path = os.path.join(npu_compiler_root, "ccec")
-    return bisheng_path
-
-
-@functools.lru_cache(None)
-def _get_ascend_path() -> Path:
-    path = os.getenv("ASCEND_HOME_PATH", "")
-    if path == "":
-        raise EnvironmentError("ASCEND_HOME_PATH is not set, source /set_env.sh first")
-    return Path(path)
-
-
-def _is_ascend_sanitizer_enabled() -> bool:
-    return os.getenv("TRITON_ENABLE_SANITIZER", 'false').lower() in ('true', '1')
-
-
-def _build_npu_ext(obj_name: str, src_path, src_dir, *, kernel_launcher=None) -> str:
-    suffix = sysconfig.get_config_var('EXT_SUFFIX')
-    so_path = os.path.join(src_dir, f"{obj_name}{suffix}")
-
-    cxx = os.environ.get("CC")
-    if cxx is None:
-        clangxx = shutil.which("clang++")
-        gxx = shutil.which("g++")
-        cxx = clangxx if clangxx is not None else gxx
-        if cxx is None:
-            raise RuntimeError("Failed to find C++ compiler")
-    cc_cmd = [cxx, src_path]
-    # disable all warnings
-    cc_cmd += ["-w"]
-    # find the python library
-    if hasattr(sysconfig, 'get_default_scheme'):
-        scheme = sysconfig.get_default_scheme()
-    else:
-        scheme = sysconfig._get_default_scheme()
-    # 'posix_local' is a custom scheme on Debian. However, starting with Python 3.10, the
-    # default install path changes to include 'local'. This change is required to use
    # triton with a system-wide Python.
-    if scheme == 'posix_local':
-        scheme = 'posix_prefix'
-    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
-    cc_cmd += [f"-I{py_include_dir}"]
-    # device_print.h
-    cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"]
-    # find the ascend library
-    asc_path = _get_ascend_path()
-    cc_cmd += [
-        f"-I{os.path.join(asc_path, 'include')}",
-        f"-I{os.path.join(asc_path, 'include/experiment')}",
-        f"-I{os.path.join(asc_path, 'include/experiment/msprof')}",
-        f"-L{os.path.join(asc_path, 'lib64')}",
-        "-lruntime",
-        "-lascendcl",
-    ]
-
-    if kernel_launcher == "torch":
-        import torch
-        import torch_npu
-        torch_path = os.path.dirname(os.path.realpath(torch.__file__))
-        torch_npu_path = os.path.dirname(os.path.realpath(torch_npu.__file__))
-        use_cxx11_abi = _check_cxx11_abi()
-        cc_cmd += [
-            f"-I{os.path.join(torch_path, 'include')}",
-            f"-I{os.path.join(torch_npu_path, 'include')}",
-            f"-L{os.path.join(torch_npu_path, 'lib')}",
-            "-ltorch_npu",
-            f"-D_GLIBCXX_USE_CXX11_ABI={use_cxx11_abi}",
-        ]
-
-    cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", so_path]
-
-    # check_call raises CalledProcessError on a non-zero exit code, so reaching
-    # the return means compilation succeeded
-    try:
-        subprocess.check_call(cc_cmd)
-    except subprocess.CalledProcessError as err:
-        raise RuntimeError("Failed to compile " + src_path) from err
-    return so_path
-
-
-def _get_kernel_target(metadata: dict):
-    if "target" not in metadata:
-        raise Exception("No target provided!")
-    sub_target = metadata["target"].arch
-    assert isinstance(sub_target, str)
-    if sub_target.startswith('Ascend910B'):
-        mix_mode = metadata["mix_mode"]
-        if mix_mode.lower().strip("_").startswith("aiv"):
-            return "ascend_910b_vec", "c220-vec", "aiv"
-        elif mix_mode.lower().strip("_").startswith("aic"):
-            return "ascend_910b_cube", "c220-cube", "aic"
-        else:
-            return "ascend_910b", "c220", "mix"
-    elif sub_target.startswith('Ascend910'):
-        return "ascend_910", "c100", "mix"
-    else:
-        raise NotImplementedError(f"NPU subtarget {sub_target} not supported yet")
-
-
-def _check_cxx11_abi():
-    import torch
-    return 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
diff --git a/third_party/ascend/language/ascend/__init__.py b/third_party/ascend/language/ascend/__init__.py
deleted file mode 100644
index 229b57d87..000000000
--- a/third_party/ascend/language/ascend/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from . import libdevice
-
-__all__ = ["libdevice"]
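
This re-export is what lets kernels reach the NPU math externs (defined in the `libdevice` module deleted next) through the backend's language package. A usage sketch, assuming the package is importable as `triton.language.extra.ascend`, which depends on the actual install layout:

```python
import triton
import triton.language as tl
# assumption: the ascend backend exposes this module under triton.language.extra
from triton.language.extra.ascend import libdevice

@triton.jit
def tanh_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # dispatches to __hmf_tanhf or __hmf_tanhDh based on the dtype key
    tl.store(y_ptr + offs, libdevice.tanh(x), mask=mask)

# e.g.: tanh_kernel[(triton.cdiv(n, 1024),)](x, y, n, BLOCK=1024)
```
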
diff --git a/third_party/ascend/language/ascend/libdevice.py b/third_party/ascend/language/ascend/libdevice.py
deleted file mode 100644
index db22bf7cc..000000000
--- a/third_party/ascend/language/ascend/libdevice.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from triton.language import core
-
-
-@core.extern
-def reciprocal(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_recipf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_recipDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def log1p(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_log1pf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_log1pDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def relu(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_reluf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_reluDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def isinf(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_isinf", core.dtype("int1")),
-            (core.dtype("fp16"), ): ("__hmf_isinf", core.dtype("int1")),
-            (core.dtype("bf16"), ): ("__hmf_isinf", core.dtype("int1")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def tan(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_tanf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_tanDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def atan(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_atanf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_atanDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def tanh(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_tanhf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def ilogb(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_ilogbf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_ilogbDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def ldexp(arg0, arg1, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0, arg1], {
-            (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_ldexpf", core.dtype("fp32")),
-            (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_ldexpDh", core.dtype("fp16")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def pow(arg0, arg1, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0, arg1], {
-            (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_powf", core.dtype("fp32")),
-            (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_powf", core.dtype("fp16")),
-            (core.dtype("bf16"), core.dtype("bf16")): ("__hmf_powf", core.dtype("bf16")),
-            (core.dtype("int64"), core.dtype("int64")): ("__hmf_powi", core.dtype("int64")),
-            (core.dtype("int32"), core.dtype("int32")): ("__hmf_powi", core.dtype("int32")),
-            (core.dtype("int16"), core.dtype("int16")): ("__hmf_powi", core.dtype("int16")),
-            (core.dtype("int8"), core.dtype("int8")): ("__hmf_powi", core.dtype("int8")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def isnan(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_isnan", core.dtype("int1")),
-            (core.dtype("fp16"), ): ("__hmf_isnan", core.dtype("int1")),
-            (core.dtype("bf16"), ): ("__hmf_isnan", core.dtype("int1")),
-        }, is_pure=True, _builder=_builder)
-
-
-@core.extern
-def flip(arg0, arg1=None, _builder=None):
-    if arg1 is None:
-        return core.extern_elementwise(
-            "", "", [arg0], {
-                (core.dtype("bf16"), ): ("__hmf_flipDhb", core.dtype("bf16")),
-                (core.dtype("fp16"), ): ("__hmf_flipDh", core.dtype("fp16")),
-                (core.dtype("fp32"), ): ("__hmf_flipf", core.dtype("fp32")),
-                (core.dtype("int8"), ): ("__hmf_flipi8", core.dtype("int8")),
-                (core.dtype("int16"), ): ("__hmf_flipi16", core.dtype("int16")),
-                (core.dtype("int32"), ): ("__hmf_flipi32", core.dtype("int32")),
-                (core.dtype("uint32"), ): ("__hmf_flipui32", core.dtype("uint32")),
-                (core.dtype("int64"), ): ("__hmf_flipi64", core.dtype("int64")),
-            }, is_pure=True, _builder=_builder)
-
-    return core.extern_elementwise(
-        "", "", [arg0, arg1], {
-            (core.dtype("bf16"), core.dtype("int32")): ("__hmf_flipDhb", core.dtype("bf16")),
-            (core.dtype("fp16"), core.dtype("int32")): ("__hmf_flipDh", core.dtype("fp16")),
-            (core.dtype("fp32"), core.dtype("int32")): ("__hmf_flipf", core.dtype("fp32")),
-            (core.dtype("int8"), core.dtype("int32")): ("__hmf_flipi8", core.dtype("int8")),
-            (core.dtype("int16"), core.dtype("int32")): ("__hmf_flipi16", core.dtype("int16")),
-            (core.dtype("int32"), core.dtype("int32")): ("__hmf_flipi32", core.dtype("int32")),
-            (core.dtype("uint32"), core.dtype("int32")): ("__hmf_flipui32", core.dtype("uint32")),
-            (core.dtype("int64"), core.dtype("int32")): ("__hmf_flipi64", core.dtype("int64")),
-        }, is_pure=True, _builder=_builder)
diff --git a/third_party/ascend/triton-adapter/CMakeLists.txt b/third_party/ascend/triton-adapter/CMakeLists.txt
deleted file mode 100644
index fcc2348b7..000000000
--- a/third_party/ascend/triton-adapter/CMakeLists.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-option(TRITON_ADAPTER_BUILD_CPU_BACKEND "Build triton-adapter CPU backend" ON)
-
-set(TRITON_ADAPTER_SOURCE_DIR ".")
-set(TRITON_ADAPTER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
-
-include_directories(./include)
-include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files
-add_subdirectory(include)
-add_subdirectory(lib)
-add_subdirectory(tools)
-
-if (TRITON_ADAPTER_BUILD_CPU_BACKEND)
-  add_triton_plugin(TritonAdapter triton_adapter.cc LINK_LIBS TritonToLinalg)
-endif()
diff --git a/third_party/ascend/triton-adapter/include/CMakeLists.txt b/third_party/ascend/triton-adapter/include/CMakeLists.txt
deleted file mode 100644
index 64ac15761..000000000
--- a/third_party/ascend/triton-adapter/include/CMakeLists.txt
+++ /dev/null
@@ -1 +0,0 @@
-add_subdirectory(TritonToLinalg)
diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h
deleted file mode 100644
index 0bf121049..000000000
--- a/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h
+++ /dev/null
@@ -1,317 +0,0 @@
-#ifndef TRITON_ADAPTER_ARGMINMAXCONVERTER_H
-#define TRITON_ADAPTER_ARGMINMAXCONVERTER_H
-
-#include "Utils/Utils.h"
-
-#include "triton/Dialect/Triton/IR/Dialect.h"
-
-#include "ConversionPatterns.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
-#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Transforms/DialectConversion.h" - -#define DEBUG_TYPE "triton-to-linalg" - -#include "llvm/Support/Debug.h" - -namespace TTOpConverters { -using namespace mlir; -using namespace triton; - -template -class ArgMinMaxBaseConverter : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchTieBreakResult(Value currValue, Value currIndex, - Value reduceValue, Value reduceIndex, - mlir::Block::iterator &it, - Value &tileBreakValue) const { - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto eqCmpOp = dyn_cast(*it); - if (eqCmpOp) { - if (eqCmpOp.getPredicate() != arith::CmpFPredicate::OEQ || - currValue != eqCmpOp.getLhs() || reduceValue != eqCmpOp.getRhs()) { - return failure(); - } - } - - auto eqCmpIOp = dyn_cast(*it++); - if (eqCmpIOp) { - if (eqCmpIOp.getPredicate() != arith::CmpIPredicate::eq || - currValue != eqCmpIOp.getLhs() || reduceValue != eqCmpIOp.getRhs()) { - return failure(); - } - } - - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto sltCmpOp = dyn_cast(*it++); - if (!sltCmpOp || sltCmpOp.getPredicate() != arith::CmpIPredicate::slt || - currIndex != sltCmpOp.getLhs() || reduceIndex != sltCmpOp.getRhs()) { - return failure(); - } - - // matching: %13 = arith.andi %11, %12 : i1 - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto andOp = dyn_cast(*it++); - - Value cmpOp; - if (eqCmpOp) - cmpOp = eqCmpOp; - else - cmpOp = eqCmpIOp; - - if (!andOp || andOp.getLhs() != cmpOp || andOp.getRhs() != sltCmpOp) { - return failure(); - } - - tileBreakValue = andOp; - return success(); - } - - LogicalResult matchShouldUpdateValue(Value currValue, Value currIndex, - Value reduceValue, Value reduceIndex, - mlir::Block::iterator &it, - Value &shouldUpdate) const { - Value tieResult; - if (failed(matchTieBreakResult(currValue, currIndex, reduceValue, - reduceIndex, it, tieResult))) { - LLVM_DEBUG(llvm::dbgs() << "Tie break result match failed\n"); - return failure(); - } - - Value comparisonResult; - if (failed(T::matchComparisonResult(currValue, currIndex, reduceValue, - reduceIndex, it, comparisonResult))) { - LLVM_DEBUG(llvm::dbgs() << "Comparison result match failed\n"); - return failure(); - } - - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto orOp = dyn_cast(*it++); - if (!orOp || orOp.getLhs() != comparisonResult || - orOp.getRhs() != tieResult) { - return failure(); - } - - shouldUpdate = orOp; - return success(); - } - - Value getInitTensor(ConversionPatternRewriter &rewriter, - ArrayRef shape, Value fillValue, - Location loc) const { - Value initTensor = - rewriter.create(loc, shape, fillValue.getType()); - return rewriter - .create(loc, ValueRange{fillValue}, - ValueRange{initTensor}) - .result(); - } - -public: - ArgMinMaxBaseConverter(MLIRContext *context) : OpConversionPattern(context) {} - - LogicalResult match(triton::ReduceOp op) const override final { - if (op.getBody()->getNumArguments() != 4) { - return failure(); - } - - auto block = op.getBody(); - auto ops = block->without_terminator(); - - Value currValue = block->getArgument(0); - Value currIndex = block->getArgument(1); - Value reduceValue = block->getArgument(2); - Value reduceIndex = block->getArgument(3); - - auto opsIt = ops.begin(); - Value shouldUpdate; - if (failed(matchShouldUpdateValue(currValue, currIndex, reduceValue, - reduceIndex, opsIt, shouldUpdate))) { - return failure(); - } - - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << 
"\n"); - auto valueSelectOp = dyn_cast(*opsIt++); - if (!valueSelectOp || valueSelectOp.getCondition() != shouldUpdate || - currValue != valueSelectOp.getTrueValue() || - reduceValue != valueSelectOp.getFalseValue()) { - return failure(); - } - - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); - auto indexSelectOp = dyn_cast(*opsIt++); - if (indexSelectOp) { - if (indexSelectOp.getCondition() != shouldUpdate || - currIndex != indexSelectOp.getTrueValue() || - reduceIndex != indexSelectOp.getFalseValue()) { - return failure(); - } - } else { - return failure(); - } - if (!indexSelectOp || indexSelectOp.getCondition() != shouldUpdate || - currIndex != indexSelectOp.getTrueValue() || - reduceIndex != indexSelectOp.getFalseValue()) { - return failure(); - } - - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); - auto termOp = dyn_cast(*opsIt++); - if (!(termOp && termOp == block->getTerminator() && - termOp.getOperands() == - ArrayRef{valueSelectOp, indexSelectOp})) { - return failure(); - } - return success(); - } - - void rewrite(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override final { - auto loc = op.getLoc(); - auto elemTypes = op.getElementTypes(); - - auto valueType = elemTypes[0]; - // tl.argmin reorder - auto block = op.getBody(); - if (isa(valueType)) { - arith::CmpFOp cmpFOp; - block->walk([&](arith::CmpFOp cmpOp) { - auto pred = cmpOp.getPredicate(); - if (pred == arith::CmpFPredicate::OEQ || - pred == arith::CmpFPredicate::ONE || - pred == arith::CmpFPredicate::UEQ || - pred == arith::CmpFPredicate::UNE) { - return WalkResult::advance(); - } else if (pred == arith::CmpFPredicate::OGT || - pred == arith::CmpFPredicate::OLT || - pred == arith::CmpFPredicate::UGT || - pred == arith::CmpFPredicate::ULT) { - cmpFOp = cmpOp; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - cmpFOp->moveBefore(block, block->getOperations().begin()); - } else if (isa(valueType)) { - arith::CmpIOp cmpIOp; - block->walk([&](arith::CmpIOp cmpOp) { - auto pred = cmpOp.getPredicate(); - if (pred == arith::CmpIPredicate::eq || - pred == arith::CmpIPredicate::ne) { - return WalkResult::advance(); - } else if (pred == arith::CmpIPredicate::sgt || - pred == arith::CmpIPredicate::slt || - pred == arith::CmpIPredicate::ugt || - pred == arith::CmpIPredicate::ult) { - if (cmpOp.getLhs() == block->getArgument(0) && - cmpOp.getRhs() == block->getArgument(2)) { - cmpIOp = cmpOp; - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - cmpIOp->moveBefore(block, block->getOperations().begin()); - } - - TypedAttr valueAttr; - if (isa(valueType)) { - valueAttr = rewriter.getFloatAttr(valueType, T::getBaseReductionValue()); - } else if (isa(valueType)) { - // TODO: support other type of int - valueAttr = - rewriter.getIntegerAttr(valueType, T::getBaseReductionIntValue()); - } - - auto valuesAccBaseVal = - rewriter.create(loc, valueType, valueAttr); - - auto indexType = elemTypes[1]; - auto indicesAccBaseVal = rewriter.create( - loc, indexType, rewriter.getIntegerAttr(indexType, -1)); - - auto valueResultType = dyn_cast(op.getType(0)); - const auto isScalarReduce = valueResultType == nullptr; - SmallVector reductionResultShape{ - isScalarReduce ? 
SmallVector{} - : SmallVector(valueResultType.getShape())}; - - SmallVector outputs{ - getInitTensor(rewriter, reductionResultShape, valuesAccBaseVal, loc), - getInitTensor(rewriter, reductionResultShape, indicesAccBaseVal, loc)}; - - auto linalgOp = rewriter.create( - loc, adaptor.getOperands(), outputs, - SmallVector{adaptor.getAxis()}, - [&](OpBuilder &b, Location loc, ValueRange inputs) { - assert(inputs.size() == 4); - - auto tritonReduceBlock = op.getBody(); - IRMapping mapping; - mapping.map(tritonReduceBlock->getArguments(), inputs); - - for (auto &op : tritonReduceBlock->without_terminator()) { - b.clone(op, mapping); - } - - auto tritonYield = tritonReduceBlock->getTerminator(); - auto results = - llvm::map_to_vector(tritonYield->getOperands(), [&](Value val) { - return mapping.lookup(val); - }); - b.create(loc, results); - }); - - // before we rewrite the argmax reduce op, we know it has return value - // so addReduceWithIndexAttrIfNeeded won't fail - // but ignoring it will lead to compiling failure - auto logicalResult = addReduceWithIndexAttrIfNeeded(rewriter, linalgOp); - - if (isScalarReduce) { - SmallVector reduceResults{ - rewriter.create( - loc, valueType, linalgOp.getResults()[0], ValueRange{}), - rewriter.create( - loc, indexType, linalgOp.getResults()[1], ValueRange{})}; - rewriter.replaceOp(op, reduceResults); - } else { - rewriter.replaceOp(op, linalgOp); - } - } -}; - -class ArgMinConverter : public ArgMinMaxBaseConverter { -public: - static LogicalResult matchComparisonResult(Value currValue, Value currIndex, - Value reduceValue, - Value reduceIndex, - mlir::Block::iterator &it, - Value &comparisonResult); - - static float getBaseReductionValue(); - - static int8_t getBaseReductionIntValue(); - - ArgMinConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} -}; - -class ArgMaxConverter : public ArgMinMaxBaseConverter { -public: - static LogicalResult matchComparisonResult(Value currValue, Value currIndex, - Value reduceValue, - Value reduceIndex, - mlir::Block::iterator &it, - Value &comparisonResult); - - static float getBaseReductionValue(); - - static int8_t getBaseReductionIntValue(); - - ArgMaxConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} -}; - -} // namespace TTOpConverters - -#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h deleted file mode 100644 index c3ac76a2e..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h +++ /dev/null @@ -1,237 +0,0 @@ -#ifndef TRITON_ANALYSIS_BLOCKPTRANALYSIS_H -#define TRITON_ANALYSIS_BLOCKPTRANALYSIS_H - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Value.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" - -#include -namespace mlir { - -class ConversionPatternRewriter; - -namespace triton { - -enum class MemAccVal { Undefined = 0, StrucMemAcc = 1, UnstrucMemAcc = 2 }; - -struct MemAccType { - - MemAccVal value; - - explicit constexpr MemAccType(MemAccVal v = MemAccVal::Undefined) - : value(v) {} - - constexpr operator MemAccVal() const { return value; } - explicit operator bool() = delete; - - constexpr bool isUndefined() const { return value == MemAccVal::Undefined; } - constexpr bool isStructured() const { - return value == 
MemAccVal::StrucMemAcc; - } - constexpr bool isUnstructured() const { - return value == MemAccVal::UnstrucMemAcc; - } - - void merge(MemAccType &other) { - this->value = (this->value > other.value) ? this->value : other.value; - } - - std::string_view toString() const { - static constexpr std::string_view names[] = {"Undefined", "StrucMemAcc", - "UnstrucMemAcc"}; - return names[static_cast(value)]; - } -}; - -class BlockData { -public: - SmallVector &getOffsetsRef(); - SmallVector &getSizesRef(); - SmallVector &getStridesRef(); - Value &getSourceRef(); - Value &getScalarRef(); - Type &getResElemTyRef(); - MemAccType &getMemAccTypeRef(); - - SmallVector getOffsets() const; - SmallVector getSizes() const; - SmallVector getStrides() const; - Type getResElemTy() const; - OpFoldResult getOffset(int) const; - OpFoldResult getSize(int) const; - OpFoldResult getStride(int) const; - Value getScalar() const; - Value getSource() const; - MemAccType getMemAccType() const; - - bool isScalar() const; - bool isEmpty() const; - bool hasSource() const; - bool hasResElemTy() const; - void removeSource(); - - int64_t getRank() const; - MemRefType getResultMemrefType(int64_t offset, ArrayRef resultShape, - bool DynamicStrides = false) const; - - void addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter); - void mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter); - void divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter); - - memref::ReinterpretCastOp createCastOp(ArrayRef resultShape, - const Location &loc, - OpBuilder &builder) const; - - void setResElemTy(const Type &); - void setSource(const Value &); - void setScalar(const Value &); - void setOffsets(const SmallVector &); - void setStrides(const SmallVector &); - void setSizes(const SmallVector &); - void setMemAccTy(const MemAccType &); - void setMemAccVal(const MemAccVal); - - void dump() const; - -private: - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - Value source; - Value scalar; - Type resElemTy; - MemAccType memAccTy; - - OpFoldResult inferBlockOffset(const Location &loc, OpBuilder &builder) const; -}; - -class BlockDataParser { -public: - using IndexMapSet = std::map>; - - static Value getScalarMemRef(Value ptr, Value memref, const Location &loc, - ConversionPatternRewriter &rewriter); - - static void parse(Value operand, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseAdd(arith::AddIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseMul(arith::MulIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseDiv(arith::DivSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseRem(arith::RemSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void - parseUnrealizedCast(UnrealizedConversionCastOp op, BlockData &data, - const Location &loc, ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void - parseMakeRange(triton::MakeRangeOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const 
llvm::SmallDenseMap &known); - - static void - parseExpandDims(triton::ExpandDimsOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseBitcast(triton::BitcastOp op, BlockData &data, - const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseExtSI(arith::ExtSIOp op, BlockData &data, - const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void - parseBroadcast(triton::BroadcastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseSplat(triton::SplatOp op, BlockData &data, - const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void - parseConstSplat(arith::ConstantOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void - parseMakeTensorPtr(triton::MakeTensorPtrOp op, BlockData &data, - const Location &loc, ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseAddPtr(triton::AddPtrOp op, BlockData &data, - const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void - parseReinterpretCast(memref::ReinterpretCastOp op, BlockData &data, - const Location &loc, ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void parseReduce(triton::ReduceOp op, BlockData &data, - const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - - static void rewriteAddPtr(triton::AddPtrOp op, - triton::AddPtrOp::Adaptor &adaptor, - ConversionPatternRewriter &rewriter, - llvm::SmallDenseMap &known); - - static void rewriteAdvanceOp(triton::AdvanceOp op, - ConversionPatternRewriter &rewriter, - llvm::SmallDenseMap &known); - - static void - rewriteYieldOp(scf::YieldOp op, ConversionPatternRewriter &rewriter, - const IndexMapSet &levelToBlockArgIndex, const int level, - const llvm::SmallDenseMap &known); - - static void rewriteForOp(scf::ForOp op, ConversionPatternRewriter &rewriter, - IndexMapSet &levelToBlockArgIndex, const int level, - llvm::SmallDenseMap &known); - - static void rewriteAddPtrToUnstrucMemAcc(triton::AddPtrOp op, - triton::AddPtrOp::Adaptor &adaptor, - ConversionPatternRewriter &rewriter, - BlockData &data); -}; - -template -void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known); - -} // namespace triton - -} // namespace mlir - -#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt b/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt deleted file mode 100644 index d62f670bb..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg) -add_public_tablegen_target(TritonToLinalgConversionPassIncGen) diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h deleted file mode 100644 index b1c4c7601..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h +++ 
/dev/null @@ -1,111 +0,0 @@ -#ifndef CONVERSIONPATTERNS_H -#define CONVERSIONPATTERNS_H - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" - -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MathExtras.h" - -#include -#include -#include - -using namespace mlir; -using namespace triton; - -//===----------------------------------------------------------------------===// -// Utilities -//===----------------------------------------------------------------------===// - -static Value getScalarValue(Value operand, Location loc, - ConversionPatternRewriter &rewriter) { - SmallVector ops; - - auto reconstructScalarValue = [&](Value src) { - for (auto op = ops.rbegin(); op != ops.rend(); ++op) { - src = TypeSwitch(*op) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return rewriter.create(loc, resType, src); - }) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return rewriter.create(loc, resType, src); - }) - .Default([](Operation *op) { - llvm_unreachable("unsupported op in generating "); - return nullptr; - }); - } - return src; - }; - - while (true) { - if (!dyn_cast(operand.getType())) { - return reconstructScalarValue(operand); - } else if (auto op = operand.getDefiningOp()) { - if (auto attr = dyn_cast(op.getValue())) { - if (!attr.isSplat()) { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load " - "produced by unsupported instruction"; - return nullptr; - } - auto elemValue = attr.getSplatValue(); - auto constOp = arith::ConstantOp::materialize( - rewriter, elemValue, attr.getElementType(), op.getLoc()); - return reconstructScalarValue(constOp.getResult()); - } - } else if (auto op = operand.getDefiningOp()) { - operand = op.getSrc(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load produced " - "by unsupported instruction"; - return nullptr; - } - } - return nullptr; -} - -static SmallVector getNParallelLoopsAttrs(unsigned n) { - return SmallVector(n, utils::IteratorType::parallel); -} - -// for IntLike and FloatLike types -static std::optional getBitWidth(Type a) { - if (auto type = dyn_cast(a)) { - auto elementType = type.getElementType(); - if (elementType.isIntOrFloat()) { - return type.getElementType().getIntOrFloatBitWidth(); - } - return std::nullopt; - } - - if (a.isIntOrFloat()) { - return a.getIntOrFloatBitWidth(); - } - return std::nullopt; -} -#endif // CONVERSIONPATTERNS_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h deleted file mode 100644 index 33166ea0e..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h +++ /dev/null @@ 
-1,38 +0,0 @@ -#ifndef TRITON_ADAPTER_FUNCTIONCONVERTER_H -#define TRITON_ADAPTER_FUNCTIONCONVERTER_H - -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -namespace FunctionConverter { -using namespace mlir; -using namespace triton; - -class GetProgramIDConverter - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static uint32_t constexpr LAUNCH_GRID_RANK = - getMaxEnumValForProgramIDDim() + 1; - -public: - LogicalResult - matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class GetNumProgramsConverter - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static uint32_t constexpr LAUNCH_GRID_RANK = - getMaxEnumValForProgramIDDim() + 1; - -public: - LogicalResult - matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; -} // namespace FunctionConverter -#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h deleted file mode 100644 index 3a2701af0..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h +++ /dev/null @@ -1,196 +0,0 @@ -#ifndef TRITON_ADAPTER_LOADSTORECONVERTER_H -#define TRITON_ADAPTER_LOADSTORECONVERTER_H - -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Dialect/Arith/Utils/Utils.h" - -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -namespace LoadStoreConverter { - -using namespace mlir; -using namespace triton; - -class AddPtrConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class LoadConverter : public OpConversionPattern { -private: - LogicalResult toTensorAndReplace(triton::LoadOp &op, - RankedTensorType &tensorType, - memref::AllocOp &allocOp, - const Location &loc, - ConversionPatternRewriter &rewriter) const; - - LogicalResult checkModifiedByAddPtrConverter(triton::LoadOp &op) const; - - LogicalResult - continueModifyFromAddPtrConverter(triton::LoadOp &op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const; - -public: - explicit LoadConverter(MLIRContext *context); - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -// tempate class's impl must in header file -template -class LoadStoreCanonicalizer : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - Value ptrVal = op.getPtr(); - Type ptrTy = ptrVal.getType(); - auto ptrDefOp = ptrVal.getDefiningOp(); - if (isa(ptrVal)) - return failure(); - - if (!isTensorPointerType(ptrTy) && - !isa_and_nonnull(ptrDefOp)) { - if (isa(ptrDefOp)) { - auto castOp = cast(ptrDefOp); - auto castSrc = 
castOp.getSrc(); - if (!isa(castSrc)) { - auto castSrcDefOp = castSrc.getDefiningOp(); - if (isa(castSrcDefOp)) { - return rewriter.notifyMatchFailure( - op, "BitcastCanonicalizer handles addptr->bitcast->load!"); - } - } - } - - Type zeroTy = getI32SameShape(ptrTy); - Value zeroVal = - createScalarOrSplatConstant(rewriter, op.getLoc(), zeroTy, 0); - Value addptrVal = rewriter.create(op.getLoc(), ptrTy, - ptrVal, zeroVal); - rewriter.modifyOpInPlace( - op, [&]() { op->replaceUsesOfWith(ptrVal, addptrVal); }); - return success(); - } - return failure(); - } -}; - -class ScalarStoreCanonicalizer : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(triton::StoreOp op, - PatternRewriter &rewriter) const override; -}; - -class StoreConverter : public OpConversionPattern { -public: - explicit StoreConverter(MLIRContext *context); - - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ScalarAtomicRMWCanonicalizer - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(triton::AtomicRMWOp op, - PatternRewriter &rewriter) const override; -}; - -class AtomicRMWConverter : public OpConversionPattern { -private: - Value createAtomicBinaryOps(OpBuilder &builder, Location loc, - triton::AtomicRMWOp op, Type elementType, - Value lhs, Value rhs) const { - auto rmwOp = op.getAtomicRmwOp(); - - // it has been confirmed in AtomicRMWConverter::matchAndRewrite - // that the ptr of op is of MemRefType - Value binaryOp; - if (rmwOp == triton::RMWOp::FADD) { - binaryOp = builder.create(loc, lhs, rhs); - } else if (rmwOp == triton::RMWOp::ADD) { - binaryOp = builder.create(loc, lhs, rhs); - } else if (rmwOp == triton::RMWOp::XOR) { - binaryOp = builder.create(loc, lhs, rhs); - } else if (rmwOp == triton::RMWOp::OR) { - binaryOp = builder.create(loc, lhs, rhs); - } else if (rmwOp == triton::RMWOp::AND) { - binaryOp = builder.create(loc, lhs, rhs); - } else if (rmwOp == triton::RMWOp::MAX) { - // Max/Min only support f32/i32 for now - // Other type is not supported because of semantic.py - if (isa(elementType)) { - binaryOp = builder.create(loc, lhs, rhs); - } else { - binaryOp = builder.create(loc, lhs, rhs); - } - } else if (rmwOp == triton::RMWOp::MIN) { - if (isa(elementType)) { - binaryOp = builder.create(loc, lhs, rhs); - } else { - binaryOp = builder.create(loc, lhs, rhs); - } - } else { - op.emitOpError("unsupported atomic RMW operation: "); - llvm_unreachable( - "Not implemented. 
Support fadd, add, max, min for now !"); - } - return binaryOp; - } - - // used when handling scalar - // to verify whether we need to handle this scalar - bool isConstantMaskTrue(Value mask) const { - if (auto denseAttr = - mask.getDefiningOp()->getAttrOfType("value")) { - auto eleType = denseAttr.getType().getElementType(); - if (isa(eleType) && - cast(eleType).getWidth() == 1) { - auto values = denseAttr.getValues(); - return values[0]; - } - } - return false; - } - - DenseSet softwareAtomicKinds = { - triton::RMWOp::AND, triton::RMWOp::OR, triton::RMWOp::XOR}; - -public: - explicit AtomicRMWConverter(MLIRContext *context); - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class AtomicMaxMinCanonicalizer : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(triton::AtomicRMWOp op, - PatternRewriter &rewriter) const override; -}; - -} // namespace LoadStoreConverter -#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h deleted file mode 100644 index 5df57dbbc..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h +++ /dev/null @@ -1,133 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation, Meta Platforms. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_ANALYSIS_MASKANALYSIS_H -#define TRITON_ANALYSIS_MASKANALYSIS_H - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include - -namespace mlir { - -// this class helps build Operations -class OpBuilder; - -namespace triton { -// use to decode the pattern in a mask used for load and store - -class MaskState { -public: - OpFoldResult start; - OpFoldResult end; - SmallVector dims; - SmallVector offsets; - OpFoldResult scalar; - - int64_t getRank() const { - assert(dims.size() == offsets.size() && "dims and offsets rank mismatch!"); - return dims.size(); - } - - bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; } - - bool isMask() const { - return !start && !end && !scalar && dims.size() != 0 && offsets.size() != 0; - } - - // parse value recursively - LogicalResult parse(Value operand, const Location &loc, OpBuilder &builder); - - tensor::ExtractSliceOp getExtractSlice(Value source, const Location &loc, - OpBuilder &builder) const; - - tensor::InsertSliceOp getInsertSlice(Value source, Value dest, - const Location &loc, - OpBuilder &builder) const; - - memref::SubViewOp getSubview(Value source, const Location &loc, - OpBuilder &builder) const; - - std::pair - getSideBySideSubviews(Value block1, Value block2, const Location &loc, - OpBuilder &builder) const; - - std::pair - getStackedSubviews(Value block1, Value block2, const Location &loc, - OpBuilder &builder) const; - - void eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter); - -private: - // Utility functions - LogicalResult addStateScalar(const MaskState &state, - const OpFoldResult scalar, const Location &loc, - OpBuilder &builder); - - LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, - const 
Location &loc, OpBuilder &builder); - - LogicalResult divStateScalar(const MaskState &state, - const OpFoldResult scalar, const Location &loc, - OpBuilder &builder); - - LogicalResult divStates(const MaskState &lhsState, const MaskState &rhsState, - const Location &loc, OpBuilder &builder); - - LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, - const Location &loc, OpBuilder &builder); - - // Helper functions to parse values to populate MaskState - - LogicalResult parseConstant(arith::ConstantOp constOp, const Location &loc, - OpBuilder &builder); - - // Operand is an integer scalar - LogicalResult parseIntScalar(Value scalar, const Location &loc, - OpBuilder &builder); - - // TODO - LogicalResult parseAdd(arith::AddIOp addOp, const Location &loc, - OpBuilder &builder); - - // operand is the result of divsi - LogicalResult parseDiv(arith::DivSIOp divOp, const Location &loc, - OpBuilder &builder); - - // Operand is the result of andi - LogicalResult parseAnd(arith::AndIOp andOp, const Location &loc, - OpBuilder &builder); - - // Operand is the result of cmpi - LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location &loc, - OpBuilder &builder); - - // Operand is the result of make_range - LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location &loc, - OpBuilder &builder); - - // Operand is the result of broadcast - LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp, - const Location &loc, OpBuilder &builder); - - // Operand is the result of splat - LogicalResult parseSplat(triton::SplatOp splatOp, const Location &loc, - OpBuilder &builder); - - // Operand is the result of expand_dims - LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, - const Location &loc, OpBuilder &builder); -}; - -} // namespace triton - -} // namespace mlir - -#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h deleted file mode 100644 index d3623a5cb..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H -#define TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H - -#include "TritonToLinalgPass.h" - -namespace mlir { -namespace triton { - -#define GEN_PASS_REGISTRATION -#include "TritonToLinalg/Passes.h.inc" - -} // namespace triton -} // namespace mlir - -#endif // TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td deleted file mode 100644 index 6ae983b57..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef TRITON_TO_LINALG_CONVERSION_PASSES -#define TRITON_TO_LINALG_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> { - let summary = "Convert Triton to Linalg dialect"; - let constructor = "triton::createTritonToLinalgPass()"; - let options = [ - Option<"globalKernel", "global-kernel", - "bool", /*default*/"true", - "generate a global kernel">, - Option<"namedOps", "named-ops", - "bool", /*default*/"false", - "use linalg named ops instead of linalg.generic"> - ]; -} - -#endif // TRITON_TO_LINALG_CONVERSION_PASSES diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h 
b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h deleted file mode 100644 index b078e4ff4..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h +++ /dev/null @@ -1,378 +0,0 @@ -#ifndef TRITON_ADAPTER_TRITONOPCONVERTER_H -#define TRITON_ADAPTER_TRITONOPCONVERTER_H - -#include "TritonToLinalg/BlockPtrAnalysis.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/ADT/SmallVector.h" - -#define DEBUG_TYPE "triton-to-linalg" - -#include "llvm/Support/Debug.h" - -namespace TTOpConverters { -using namespace mlir; -using namespace triton; - -/* -Convert `tt.precise_div` operation to `arith.divf` operation. -tensor_x / tensor_y - -```ttir - %11 = tt.precise_divf %7, %10 : tensor<100xf32> -``` - -converts to: - -```mlir - %11 = arith.divf %7, %10 : tensor<100xf32> -``` -*/ -struct PreciseDivConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/* - * Rewrite arith.select with contiguouse mask to - * tensor.extract_slice/insert_slice. - */ -class SelectConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(arith::SelectOp op, - PatternRewriter &rewriter) const override; -}; - -/* - * Move tt.bitcast to a previous location if tt.bitcast is not directly applied - * on function arguments - */ -class BitcastCanonicalizer : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(triton::BitcastOp bitcastOp, - PatternRewriter &rewriter) const override; -}; - -template -class ScalarMathCanonicalizer : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MathOp op, - PatternRewriter &rewriter) const override { - if (op->getNumResults() != 1) { - return rewriter.notifyMatchFailure( - op, "ScalarMathCanonicalizer expects single scalar output."); - } - if (!op->getResult(0).getType().isIntOrIndexOrFloat()) { - return rewriter.notifyMatchFailure( - op, "ScalarMathCanonicalizer handles scalar load scene."); - } - if (auto linalgOp = op->template getParentOfType()) { - return rewriter.notifyMatchFailure( - op, "ScalarMathCanonicalizer handles op not within tt.reduce."); - } - auto loc = op.getLoc(); - llvm::SmallVector inputs; - for (auto input : op->getOperands()) { - auto blkTy = RankedTensorType::get({(int64_t)1}, input.getType()); - auto inputSplat = rewriter.create(loc, blkTy, input); - inputs.push_back(inputSplat.getResult()); - } - auto blkOp = rewriter.create(loc, inputs); - Value offset = - rewriter.create(loc, rewriter.getIndexAttr(0)); - auto extractOp = - rewriter.create(loc, blkOp.getResult(), offset); - rewriter.replaceOp(op, extractOp); - return success(); - } -}; - -class DenseConstantConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class MakeRangeConverter : public OpConversionPattern { -public: - using 
OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class SplatConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ReshapeConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ExpandDimsConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ClampFConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class BroadcastConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ReduceConverter : public OpConversionPattern { -public: - explicit ReduceConverter(MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - -private: - llvm::SmallVector getRedOps(triton::ReduceOp redOp) const; - - bool isReductionOpSupported(Operation *redOp) const; - - arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, - Operation *redOp, - Type constantType) const; - - bool requiresF32Conversion(const Type elemType, Operation *redOp) const; - - Value getRedElement(Value lhs, Value rhs, const Location loc, - Operation *redOp, OpBuilder &b, - const bool convertLhsToF32Precision) const; - - LogicalResult - convertToLinalgReduce(triton::ReduceOp op, - typename triton::ReduceOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const; - - LogicalResult - convertToLinalgReduceExtended(ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const; - -public: - LogicalResult - matchAndRewrite(triton::ReduceOp op, - typename triton::ReduceOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ExternElementwiseClOpConverter - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class UnrealizedCastConverter - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class JoinConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class SplitConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, - 
ConversionPatternRewriter &rewriter) const override; -}; - -class CatConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class GatherConverter : public OpConversionPattern { -private: - static constexpr llvm::StringRef gatherFuncNameBase = "triton_gather"; - -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class YieldConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class LoopConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class AdvanceConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class MakeTensorPtrConverter - : public OpConversionPattern { -private: - using OpConversionPattern::OpConversionPattern; - - void populateVectorAsIndex(SmallVector &vec, - Operation::operand_range ops, - ConversionPatternRewriter &rewriter, - Location loc) const; - - memref::ReinterpretCastOp - createRedundantOp(triton::MakeTensorPtrOp op, - ConversionPatternRewriter &rewriter, BlockData &data) const; - - OpFoldResult - accumulatePotentialOffsetOnBase(triton::MakeTensorPtrOp op, Value base, - OpFoldResult offset, - ConversionPatternRewriter &rewriter) const; - -public: - explicit MakeTensorPtrConverter(MLIRContext *context) - : OpConversionPattern(context) {} - - LogicalResult - matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class TransposeConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class BitcastConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class TritonMulhiuiConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class TritonPreciseSqrtConverter - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class AssertConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(triton::AssertOp op, - PatternRewriter &rewriter) const override; -}; - -class DevicePrintConverter : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - -private: - static 
constexpr llvm::StringRef printFuncNameBase = "triton_print"; - static constexpr llvm::StringRef prefixAttrName = "prefix"; - static constexpr llvm::StringRef hexAttrName = "hex"; - -public: - LogicalResult - matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -struct MatmulConverter : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -} // end of namespace TTOpConverters - -#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h deleted file mode 100644 index 6fc57aaad..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h +++ /dev/null @@ -1,71 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H -#define TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -#define GEN_PASS_CLASSES -#include "../../include/TritonToLinalg/Passes.h.inc" - -namespace mlir { -namespace triton { - -std::unique_ptr> createTritonToLinalgPass(); - -} // namespace triton -} // namespace mlir - -namespace { - -using namespace mlir; -using namespace triton; -const std::string globalKernelAttr = "global_kernel"; -const std::string kernelMixModeName = "mix_mode"; - -class TritonTypeConverter : public mlir::TypeConverter { -public: - explicit TritonTypeConverter(); -}; - -class TritonToLinalgPass : public TritonToLinalgBase { - - static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; - static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = - LAUNCH_GRID_RANK * 2; - -private: - // Grid construction: num_programs has 3 dimensions, program_id has 3 - // dimensions. - // Remember that 'xxxOp' is usually a pointer, so we can change the target - // memory without passing a reference argument. - void addProgramInfo(triton::FuncOp func, bool globalKernel); - - void convertTTFunc(triton::FuncOp func, const bool existDot); - - void addDynamicLegal(ConversionTarget &target, - TritonTypeConverter &tritonTypeConverter); - - void - populateTritonToLinalgCanonicalizationPatterns(RewritePatternSet &patterns); - - void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns, - unsigned int launchGridRank); - -public: - void getDependentDialects(DialectRegistry &registry) const override; - - void runOnOperation() override; -}; -} // namespace - -#endif // TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h deleted file mode 100644 index e2727fa4c..000000000 --- a/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h +++ /dev/null @@ -1,128 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. 
-// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_ANALYSIS_USEANALYSIS_H -#define TRITON_ANALYSIS_USEANALYSIS_H - -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" - -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -namespace mlir { -namespace triton { - -enum class UseType { - Undefined, // Initial state - DataUse, // value used for tensor computation only - MetaUse, // value used for metadata only - MixUse // value used for both tensor computation and metadata -}; - -struct UseInfo : public dataflow::AbstractSparseLattice { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UseInfo) - using AbstractSparseLattice::AbstractSparseLattice; - - // Lattice state transfer function - ChangeResult meetUseType(const UseType &other) { - if (other == UseType::Undefined) { - return ChangeResult::NoChange; - } - - switch (type) { - case UseType::Undefined: - type = other; - return ChangeResult::Change; - case UseType::DataUse: - case UseType::MetaUse: - if (type == other) { - return ChangeResult::NoChange; - } else { - type = UseType::MixUse; - return ChangeResult::Change; - } - case UseType::MixUse: - return ChangeResult::NoChange; - default: - llvm_unreachable("bad type"); - } - } - - ChangeResult meet(const AbstractSparseLattice &other) override { - auto rhs = reinterpret_cast(&other); - return meetUseType(rhs->type); - } - - void print(raw_ostream &os) const override { - switch (type) { - case UseType::DataUse: - os << "DataUse"; - break; - case UseType::MetaUse: - os << "MetaUse"; - break; - case UseType::MixUse: - os << "MixUse"; - break; - default: - os << "Undefined"; - } - } - - UseType type = UseType::Undefined; -}; - -class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis { -public: - using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; - -#if LLVM_VERSION_MAJOR >= 20 - LogicalResult visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override; -#else - void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override; -#endif - - void visitBranchOperand(OpOperand &operand) override { return; } - - void visitCallOperand(OpOperand &operand) override { return; } - - void setToExitState(UseInfo *lattice) override { - lattice->type = UseType::Undefined; - } - -private: - void propagateUse(UseInfo *lattice, const UseType &type) { - auto changed = lattice->meetUseType(type); - propagateIfChanged(lattice, changed); - } - - void propagateResults(UseInfo *lattice, ArrayRef results) { - auto changed = ChangeResult::NoChange; - for (auto result : results) { - changed |= lattice->meet(*result); - } - propagateIfChanged(lattice, changed); - } -}; - -class MetaUseEraser : public RewritePattern { -public: - MetaUseEraser(MLIRContext *context); - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final; -}; - -LogicalResult runUseAnalysis(triton::FuncOp &funcOp); - -} // namespace triton - -} // namespace mlir - -#endif // TRITON_CONVERSION_TRITONTOAFFINE_TRITONUSEANALYSIS_H diff --git a/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h b/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h deleted file mode 100644 index 8e445d7ee..000000000 --- a/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h +++ /dev/null @@ -1,71 +0,0 @@ -#pragma once - -#include "TritonToLinalg/BlockPtrAnalysis.h" -#include 
"TritonToLinalg/MaskAnalysis.h" -#include "TritonToLinalg/UseAnalysis.h" -#include "Utils/Utils.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MathExtras.h" - -#include -#include -#include -#include -#include - -namespace mlir { -namespace triton { - -enum class IndexMode : int { EVEN_MODE = 0, ODD_MODE = 1 }; - -MemRefType expandInterleaveMemRefType(MemRefType originType); - -std::pair -recountReinterpretCastOffset(OpFoldResult originOffset, Builder &builder); - -LogicalResult -DeinterleaveStatusOptimization(triton::LoadOp op, - triton::LoadOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter); - -LogicalResult DeinterleaveStatusWithMaskOptimization( - triton::LoadOp op, triton::LoadOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter, MaskState &mstate, - memref::AllocOp originAllocOp); - -LogicalResult -InterleaveStatusOptimization(SmallVector materializeVec); - -LogicalResult -InterleaveStatusWithMaskOptimization(SmallVector materializeVec); - -} // namespace triton -} // namespace mlir diff --git a/third_party/ascend/triton-adapter/include/Utils/Utils.h b/third_party/ascend/triton-adapter/include/Utils/Utils.h deleted file mode 100644 index fb713d45f..000000000 --- a/third_party/ascend/triton-adapter/include/Utils/Utils.h +++ /dev/null @@ -1,148 +0,0 @@ -#ifndef TRITONNPU_UTILS_UTILS_H -#define TRITONNPU_UTILS_UTILS_H - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/ArrayRef.h" - -#include -#include - -namespace mlir { - -namespace ConverterUtils { - -Value getTransposedValue(Value source, const Location loc, - ConversionPatternRewriter &rewriter, - llvm::ArrayRef order); - -SmallVector getNParallelLoopsAttrs(unsigned n); - -Value getScalarValue(Value operand, Location loc, - ConversionPatternRewriter &rewriter); - -memref::SubViewOp makeSubViewOp(Value src, - llvm::SmallVectorImpl &sizes, - const Location &loc, - ConversionPatternRewriter &rewriter); - -void getShapeInfo(Value val, llvm::SmallVectorImpl &shapes, - ConversionPatternRewriter &rewriter); - -SmallVector -getBoundarySizes(llvm::ArrayRef boundaryCheck, Value ptr, - Value adaptorPtr, const Location &loc, - ConversionPatternRewriter &rewriter); - -SmallVector getBroadcastDims(RankedTensorType src, - RankedTensorType 
dst); - -SmallVector getUnbroadcastDims(RankedTensorType src, - RankedTensorType dst); - -} // namespace ConverterUtils - -class ConversionPatternRewriter; - -namespace triton { - -enum class IndirectLoadInterfaceOpType { Undefined = 0, Load = 1, Calc = 2 }; - -// Traceback from rootOp to find the targetOp with the specified condition -mlir::Operation * -findFirstMatchingOperandDef(mlir::Operation *rootOp, - const std::function &condFn); - -void traverseBackwardUpdateOperandChainIf( - Operation *op, std::function conditionFn, - std::function actionFn, OpBuilder &builder); - -void traverseBackwardUpdateOperandChainIf( - Operation *rootOp, std::function conditionFn, - std::function actionFn); - -void traverseForwardUpdateUserChainIf( - Operation *op, std::function conditionFn, - std::function stopFn, - std::function actionFn, OpBuilder &builder, - llvm::SmallPtrSet &stopOps); - -void traverseForwardUpdateUserChainIf( - Operation *rootOp, std::function conditionFn, - std::function stopFn, - std::function actionFn, - llvm::SmallPtrSet &stopOps); - -// UseAnalysis will tag operations whose results are used only as meta-data -// with "MetaUse" tag. -bool isMetaUse(Operation *op); - -bool isMixUse(Operation *op); - -IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op); - -bool opIsIndirectLoad(Operation *op); - -bool opIsIndirectCalc(Operation *op); - -scf::ForOp createNestedLoops( - OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims, - ValueRange LBs, ValueRange UBs, ValueRange steps, SmallVector &ivs, - ValueRange initArgs, - function_ref &, ValueRange)> - bodyBuilder); - -ModuleOp getModuleOpFromOperation(Operation *op); - -} // namespace triton - -class OpBuilder; - -std::optional makeIntAttr(const OpFoldResult &ofr); - -bool hasConstantZero(const OpFoldResult &ofr); - -Value opFoldResultToIndex(const OpFoldResult &ofr, const Location &loc, - OpBuilder &b); - -SmallVector opFoldResultToIndex(ArrayRef ofrs, - const Location &loc, OpBuilder &b); - -Value createConstIntOp(const Location &loc, OpBuilder &b, int64_t value); - -OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b); - -OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b); - -OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b); - -OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const Value &rhs, - const Location &loc, OpBuilder &b); - -OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b); - -OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b); - -OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b); - -OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b); - -LogicalResult -addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, - linalg::ReduceOp reduceOp); - -} // namespace mlir - -#endif // TRITONNPU_UTILS_UTILS_H diff --git a/third_party/ascend/triton-adapter/lib/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/CMakeLists.txt deleted file mode 100644 index cbf0d9d7e..000000000 --- a/third_party/ascend/triton-adapter/lib/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(TritonToLinalg) -add_subdirectory(Utils) diff 
--git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp deleted file mode 100644 index 9bf7f7ca2..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include "TritonToLinalg/ArgMinMaxConverter.h" - -namespace TTOpConverters { -using namespace mlir; -using namespace triton; - -// ArgMinConverter functions -LogicalResult ArgMinConverter::matchComparisonResult( - Value currValue, Value currIndex, Value reduceValue, Value reduceIndex, - mlir::Block::iterator &it, Value &comparisonResult) { - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - - auto cmpOp = dyn_cast(*it); - auto cmpIOp = dyn_cast(*it++); - if (!cmpOp && !cmpIOp) - return failure(); - - if (cmpOp) { - if (cmpOp.getPredicate() != arith::CmpFPredicate::OLT || - currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { - return failure(); - } - comparisonResult = cmpOp; - } - - if (cmpIOp) { - if (cmpIOp.getPredicate() != arith::CmpIPredicate::slt || - currValue != cmpIOp.getLhs() || reduceValue != cmpIOp.getRhs()) { - return failure(); - } - comparisonResult = cmpIOp; - } - - return success(); -} - -float ArgMinConverter::getBaseReductionValue() { - return std::numeric_limits::infinity(); -} - -int8_t ArgMinConverter::getBaseReductionIntValue() { return 127; } - -// ArgMaxConverter functions -LogicalResult ArgMaxConverter::matchComparisonResult( - Value currValue, Value currIndex, Value reduceValue, Value reduceIndex, - mlir::Block::iterator &it, Value &comparisonResult) { - auto cmpOp = dyn_cast(*it); - auto cmpIOp = dyn_cast(*it++); - if (!cmpOp && !cmpIOp) - return failure(); - - if (cmpOp) { - if (cmpOp.getPredicate() != arith::CmpFPredicate::OGT || - currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { - return failure(); - } - comparisonResult = cmpOp; - } - - if (cmpIOp) { - if (cmpIOp.getPredicate() != arith::CmpIPredicate::sgt || - currValue != cmpIOp.getLhs() || reduceValue != cmpIOp.getRhs()) { - return failure(); - } - comparisonResult = cmpIOp; - } - - return success(); -} - -float ArgMaxConverter::getBaseReductionValue() { - return -std::numeric_limits::infinity(); -} - -int8_t ArgMaxConverter::getBaseReductionIntValue() { return -128; } - -} // namespace TTOpConverters diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp deleted file mode 100644 index 26b8658e3..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp +++ /dev/null @@ -1,1404 +0,0 @@ -#include "TritonToLinalg/BlockPtrAnalysis.h" -#include "Utils/Utils.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include - -#define DEBUG_TYPE "triton-block-ptr-analysis" - -namespace mlir { -namespace triton { - -// MemAccType selectMaxMemAccTy(const MemAccType &v1, 
const MemAccType &v2) { -// return (v1 > v2) ? v1 : v2; -// } - -namespace { -void assertLegalUnrealizedCast(UnrealizedConversionCastOp op) { - assert(op && op.getInputs().size() == 3 && - op.getInputs()[0].getDefiningOp() && - op.getInputs()[1].getDefiningOp() && - op.getInputs()[1].getDefiningOp()); -} -} // namespace - -SmallVector &BlockData::getOffsetsRef() { return this->offsets; } - -SmallVector &BlockData::getSizesRef() { return this->sizes; } - -SmallVector &BlockData::getStridesRef() { return this->strides; } - -Value &BlockData::getSourceRef() { return this->source; } - -Value &BlockData::getScalarRef() { return this->scalar; } - -SmallVector BlockData::getOffsets() const { - return this->offsets; -} - -SmallVector BlockData::getSizes() const { return this->sizes; } - -SmallVector BlockData::getStrides() const { - return this->strides; -} - -OpFoldResult BlockData::getOffset(int index) const { - return this->offsets[index]; -} - -OpFoldResult BlockData::getSize(int index) const { return this->sizes[index]; } - -OpFoldResult BlockData::getStride(int index) const { - return this->strides[index]; -} - -Value BlockData::getScalar() const { return this->scalar; } - -Value BlockData::getSource() const { return this->source; } - -MemAccType BlockData::getMemAccType() const { return this->memAccTy; }; - -MemAccType &BlockData::getMemAccTypeRef() { return this->memAccTy; }; - -bool BlockData::isScalar() const { return this->scalar != nullptr; } - -bool BlockData::isEmpty() const { - return !(this->getRank() || this->source || this->scalar); -} - -bool BlockData::hasSource() const { return this->source != nullptr; } - -void BlockData::removeSource() { this->source = nullptr; }; - -bool BlockData::hasResElemTy() const { return this->resElemTy != nullptr; } - -Type &BlockData::getResElemTyRef() { return this->resElemTy; } - -Type BlockData::getResElemTy() const { return this->resElemTy; } - -int64_t BlockData::getRank() const { - assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); - return this->offsets.size(); -} - -void BlockData::setResElemTy(const Type &Ty) { this->resElemTy = Ty; } - -void BlockData::setScalar(const Value &scalar) { this->scalar = scalar; } - -void BlockData::setSource(const Value &src) { this->source = src; } - -void BlockData::setOffsets(const SmallVector &offsets) { - this->offsets = offsets; -} - -void BlockData::setStrides(const SmallVector &strides) { - this->strides = strides; -} - -void BlockData::setSizes(const SmallVector &szs) { - this->sizes = szs; -} - -void BlockData::setMemAccTy(const MemAccType &v) { this->memAccTy = v; } - -void BlockData::setMemAccVal(const MemAccVal v) { this->memAccTy.value = v; } - -OpFoldResult BlockData::inferBlockOffset(const Location &loc, - OpBuilder &builder) const { - OpFoldResult retOffset = builder.getIndexAttr(0); - for (auto ofr : offsets) { - retOffset = addOpFoldResult(retOffset, ofr, loc, builder); - } - return retOffset; -} - -MemRefType BlockData::getResultMemrefType(int64_t offset, - ArrayRef resultShape, - bool DynamicStrides) const { - SmallVector staticStrides; - if (DynamicStrides) { - staticStrides.append(this->strides.size(), ShapedType::kDynamic); - } else { - SmallVector dynamicStrides; - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - } - - auto elementType = - dyn_cast(this->source.getType()).getElementType(); - auto layout = - StridedLayoutAttr::get(this->source.getContext(), offset, staticStrides); - return MemRefType::get(resultShape, elementType, layout); 
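-  // Editor's illustration (not part of the original file): with resultShape = -  // {4, 256}, offset = 4, and strides {256, 1}, the strided layout built above -  // yields a type such as -  //   memref<4x256xf32, strided<[256, 1], offset: 4>> -  // while DynamicStrides instead produces strided<[?, ?], offset: 4>.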
-} - -void BlockData::addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter) { - assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); - // When both left block and right block have source, it is indirect load. - assert(!(lBlock.hasSource() && rBlock.hasSource())); - this->source = - lBlock.hasSource() ? lBlock.getSourceRef() : rBlock.getSourceRef(); - - assert(!rBlock.hasResElemTy()); - if (lBlock.hasResElemTy()) { - this->resElemTy = lBlock.getResElemTyRef(); - } - - if (lBlock.isScalar() && rBlock.isScalar()) { - auto addOp = rewriter.create(loc, lBlock.getScalarRef(), - rBlock.getScalarRef()); - this->scalar = addOp.getResult(); - } else if (lBlock.getRank() == 0) { - this->scalar = - lBlock.isScalar() ? lBlock.getScalarRef() : rBlock.getScalarRef(); - } - - for (const auto &[lOffset, rOffset] : - llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { - this->offsets.push_back(addOpFoldResult(lOffset, rOffset, loc, rewriter)); - } - - for (const auto &[lStride, rStride] : - llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { - this->strides.push_back(addOpFoldResult(lStride, rStride, loc, rewriter)); - } - - this->sizes = lBlock.getSizesRef(); - - this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); - this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); - // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), - // rBlock.getMemAccType())); -} - -void BlockData::mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter) { - assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); - - assert(!(lBlock.hasSource() && rBlock.hasSource())); - - assert( - (lBlock.isScalar() ^ rBlock.isScalar()) && - "Currently only support one and only one scalar in function mulBlock()"); - BlockData *lb = &lBlock; - BlockData *rb = &rBlock; - if (lb->isScalar()) { - std::swap(lb, rb); - } - - Value rScalar = rb->getScalarRef(); - for (const auto &lOffset : lb->getOffsetsRef()) { - this->offsets.push_back(mulOpFoldResult(lOffset, rScalar, loc, rewriter)); - } - - for (const auto &lStride : lb->getStridesRef()) { - this->strides.push_back(mulOpFoldResult(lStride, rScalar, loc, rewriter)); - } - - this->sizes = lb->getSizesRef(); - - this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); - this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); - // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), - // rBlock.getMemAccType())); -} - -void BlockData::divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter) { - assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); - - assert(!(lBlock.hasSource() && rBlock.hasSource())); - - for (const auto &[lOffset, rOffset] : - llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { - this->offsets.push_back(divOpFoldResult(lOffset, rOffset, loc, rewriter)); - } - - for (const auto &[lStride, rStride] : - llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { - this->strides.push_back(divOpFoldResult(lStride, rStride, loc, rewriter)); - } - - this->sizes = lBlock.getSizesRef(); - - this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); - this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); - // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), - // rBlock.getMemAccType())); -} - -memref::ReinterpretCastOp BlockData::createCastOp(ArrayRef resultShape, - const Location &loc, - OpBuilder &builder) const { - OpFoldResult resultOffset = 
this->inferBlockOffset(loc, builder); - SmallVector staticOffset; - SmallVector dynamicOffset; - dispatchIndexOpFoldResult(resultOffset, dynamicOffset, staticOffset); - - auto resultType = this->getResultMemrefType(staticOffset[0], resultShape); - - return builder.create( - loc, resultType, this->source, resultOffset, this->sizes, this->strides); -} - -void BlockData::dump() const { - llvm::outs() << "[INFO][BEG] BlockData info\n"; - llvm::outs() << "offsets has " << offsets.size() << " items\n"; - int cnt = 0; - for (auto it = offsets.begin(); it != offsets.end(); it++) { - llvm::outs() << "offsets[" << cnt++ << "] = " << *it << "\n"; - } - llvm::outs() << "sizes has " << sizes.size() << " items\n"; - cnt = 0; - for (auto it = sizes.begin(); it != sizes.end(); it++) { - llvm::outs() << "sizes[" << cnt++ << "] = " << *it << "\n"; - } - llvm::outs() << "strides has " << strides.size() << " items\n"; - cnt = 0; - for (auto it = strides.begin(); it != strides.end(); it++) { - llvm::outs() << "strides[" << cnt++ << "] = " << *it << "\n"; - } - llvm::outs() << "source = " << source << "\n"; - llvm::outs() << "scalar = " << scalar << "\n"; - llvm::outs() << "resElemTy = " << resElemTy << "\n"; - llvm::outs() << "memAccTy = " << memAccTy.toString() << "\n"; - llvm::outs() << "[INFO][END] BlockData info\n"; -} - -Value BlockDataParser::getScalarMemRef(Value ptr, Value memref, - const Location &loc, - ConversionPatternRewriter &rewriter) { - assert(isa(ptr.getType()) && "expect a scalar pointer"); - if (ptr.getDefiningOp()) { - if (auto castOp = memref.getDefiningOp()) { - return castOp.getResult(); - } else { - llvm_unreachable("pointer value is defined by an unexpected op"); - } - } - - assert(isa(ptr) && - "pointer should be produced by addptr or block argument"); - BlockData data; - data.setSource(memref); - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - data.getSizesRef().push_back(rewriter.getIndexAttr(1)); - data.getStridesRef().push_back(rewriter.getIndexAttr(1)); - auto castOp = data.createCastOp(SmallVector(1, 1), loc, rewriter); - return castOp.getResult(); -} - -void BlockDataParser::parse( - Value operand, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - if (known.find(operand) != known.end()) { - return data = known.lookup(operand), void(); - } - - if (isa(operand.getType())) { - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), operand); - return data.setScalar(castOp.getResult()), void(); - } - - if (isa(operand.getType())) { - auto remappedPtr = rewriter.getRemappedValue(operand); - assert(remappedPtr); - if (auto op = operand.getDefiningOp()) { - if (auto addPtrOp = dyn_cast(op)) { - parseAddPtr(addPtrOp, data, loc, rewriter, known); - } else if (auto makeTensorPtrOp = dyn_cast(op)) { - parseMakeTensorPtr(makeTensorPtrOp, data, loc, rewriter, known); - } else if (auto bitcastOp = dyn_cast(op)) { - parseBitcast(bitcastOp, data, loc, rewriter, known); - } else { - llvm_unreachable("Unexpected operand defining operation,A scalar " - "pointer can only be produced by AddPtrOp or a block"); - } - } else { - data.setSource(remappedPtr); - } - return; - } - - // not a scalar pointer - if (auto addOp = operand.getDefiningOp()) { - parseAdd(addOp, data, loc, rewriter, known); - } else if (auto mulOp = operand.getDefiningOp()) { - parseMul(mulOp, data, loc, rewriter, known); - } else if (auto addPtrOp = operand.getDefiningOp()) { - parseAddPtr(addPtrOp, data, loc, rewriter, known); - } else if (auto 
constOp = operand.getDefiningOp()) { - parseConstSplat(constOp, data, loc, rewriter, known); - } else if (auto broadcastOp = operand.getDefiningOp()) { - parseBroadcast(broadcastOp, data, loc, rewriter, known); - } else if (auto splatOp = operand.getDefiningOp()) { - parseSplat(splatOp, data, loc, rewriter, known); - } else if (auto expandDimsOp = - operand.getDefiningOp()) { - parseExpandDims(expandDimsOp, data, loc, rewriter, known); - } else if (auto remOp = operand.getDefiningOp()) { - parseRem(remOp, data, loc, rewriter, known); - } else if (auto bitcastOp = operand.getDefiningOp()) { - parseBitcast(bitcastOp, data, loc, rewriter, known); - } else if (auto extsiOp = operand.getDefiningOp()) { - parseExtSI(extsiOp, data, loc, rewriter, known); - } else if (auto divOp = operand.getDefiningOp()) { - parseDiv(divOp, data, loc, rewriter, known); - } else if (auto makeRangeOp = operand.getDefiningOp()) { - parseMakeRange(makeRangeOp, data, loc, rewriter, known); - } else if (auto reduceOp = operand.getDefiningOp()) { - parseReduce(reduceOp, data, loc, rewriter, known); - } else if (auto loadOp = operand.getDefiningOp()) { - parseIndirectLoad(loadOp, data, loc, rewriter, known); - } else if (auto castOp = operand.getDefiningOp()) { - parseIndirectLoad(castOp, data, loc, rewriter, known); - } else { - operand.dump(); - llvm_unreachable("encountered AddPtrOp produced by unsupported operation"); - } -} - -void BlockDataParser::parseAdd( - arith::AddIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - BlockData lBlock, rBlock; - parse(op.getLhs(), lBlock, loc, rewriter, known); - parse(op.getRhs(), rBlock, loc, rewriter, known); - data.addBlock(lBlock, rBlock, loc, rewriter); -} - -void BlockDataParser::parseMul( - arith::MulIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - BlockData lBlock, rBlock; - parse(op.getLhs(), lBlock, loc, rewriter, known); - parse(op.getRhs(), rBlock, loc, rewriter, known); - data.mulBlock(lBlock, rBlock, loc, rewriter); -} - -void BlockDataParser::parseDiv( - arith::DivSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - BlockData lBlock, rBlock; - parse(op.getLhs(), lBlock, loc, rewriter, known); - parse(op.getRhs(), rBlock, loc, rewriter, known); - data.divBlock(lBlock, rBlock, loc, rewriter); -} - -// TODO: support modulo -void BlockDataParser::parseRem( - arith::RemSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(false && "Address expression with modulo is not supported yet, it " - "shall be analyzed during linearization."); -} - -void BlockDataParser::parseUnrealizedCast( - UnrealizedConversionCastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assertLegalUnrealizedCast(op); - - auto originBlock = op.getInputs()[2]; - if (known.contains(originBlock)) { - data = known.at(originBlock); - } else { - parseAddPtr(originBlock.getDefiningOp(), data, loc, - rewriter, known); - } -} - -void BlockDataParser::parseMakeRange( - triton::MakeRangeOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - auto shape = dyn_cast(op.getType()).getShape(); - - auto start = op.getStart(); - auto end = op.getEnd(); - auto 
stride = (end >= start) && (end - start <= shape[0]); - assert(stride == 1 && - "make_range op should always return a tensor of stride 1"); - - data.getOffsetsRef().push_back(rewriter.getIndexAttr(start)); - data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); - data.getStridesRef().push_back(rewriter.getIndexAttr(stride)); -} - -void BlockDataParser::parseExpandDims( - triton::ExpandDimsOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - parse(op.getSrcMutable().get(), data, loc, rewriter, known); - auto resShape = dyn_cast(op.getResult().getType()).getShape(); - auto axis = op.getAxis(); - - assert(resShape[axis] == 1 && - "The destination shape of the changed dimension should be 1"); - - data.getOffsetsRef().insert(data.getOffsetsRef().begin() + axis, - rewriter.getIndexAttr(0)); - data.getSizesRef().insert(data.getSizesRef().begin() + axis, - rewriter.getIndexAttr(1)); - data.getStridesRef().insert(data.getStridesRef().begin() + axis, - rewriter.getIndexAttr(0)); -} - -void BlockDataParser::parseBitcast( - triton::BitcastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - parse(op.getSrc(), data, loc, rewriter, known); - - auto resType = op.getResult().getType(); - mlir::Type resElemPointeeTy; - if (auto resShapedTy = dyn_cast(resType)) { - auto resElemTy = resShapedTy.getElementType(); - resElemPointeeTy = - dyn_cast(resElemTy).getPointeeType(); - } else { - resElemPointeeTy = dyn_cast(resType).getPointeeType(); - } - data.setResElemTy(resElemPointeeTy); -} - -void BlockDataParser::parseExtSI( - arith::ExtSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - parse(op.getIn(), data, loc, rewriter, known); -} - -void BlockDataParser::parseBroadcast( - triton::BroadcastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - auto src = op.getSrcMutable().get(); - auto dst = op.getResult(); - assert(isa(src.getType()) && - "tt.broadcast's input should be a tensor"); - - auto srcShape = dyn_cast(src.getType()).getShape(); - auto dstShape = dyn_cast(dst.getType()).getShape(); - assert(srcShape.size() == dstShape.size() && - "rank of source should be equal to destination"); - - parse(src, data, loc, rewriter, known); - - auto &blockSizes = data.getSizesRef(); - for (const auto &[idx, src_dst] : - llvm::enumerate(llvm::zip(srcShape, dstShape))) { - const auto &[srcAxis, dstAxis] = src_dst; - if (srcAxis == dstAxis) { - continue; - } - assert(srcAxis < dstAxis && - "srcShape of broadcastOp must be smaller than dstShape."); - blockSizes[idx] = rewriter.getIndexAttr(dstAxis); - } -} - -void BlockDataParser::parseSplat( - triton::SplatOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - auto src = op.getSrc(); - auto dst = op.getResult(); - auto dstShape = dyn_cast(dst.getType()).getShape(); - - parse(src, data, loc, rewriter, known); - - if (isa(src.getType()) || - isa(src.getType())) { - if (!data.isEmpty()) { - data.getOffsetsRef().clear(); - data.getSizesRef().clear(); - data.getStridesRef().clear(); - } - for (auto dstAxis : dstShape) { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
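-      // Editor's note (not part of the original file): e.g. splatting a -      // scalar into tensor<8x16xf32> leaves offsets [0, 0], sizes [8, 16], -      // strides [0, 0] once this loop finishes. - 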
data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis)); - data.getStridesRef().push_back(rewriter.getIndexAttr(0)); - } - } else { - auto srcType = dyn_cast(src.getType()); - assert(srcType.getRank() == 1 && data.getRank() == 1 && - "splat MemRef source should have rank 1"); - assert(srcType.getShape()[0] == 1 && - makeIntAttr(data.getSizesRef()[0]).value() == 1 && - "splat MemRef source should have size 1"); - data.getStridesRef()[0] = rewriter.getIndexAttr(0); - - for (const auto &[idx, dstAxis] : llvm::enumerate(dstShape)) { - if (idx == 0) { - data.getSizesRef()[idx] = rewriter.getIndexAttr(dstAxis); - continue; - } - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis)); - data.getStridesRef().push_back(rewriter.getIndexAttr(0)); - } - } - if (data.isScalar()) { - data.getOffsetsRef()[0] = data.getScalarRef(); - } -} - -void BlockDataParser::parseConstSplat( - arith::ConstantOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - auto attr = dyn_cast(op.getValue()); - auto elementType = attr.getElementType(); - assert(attr.isSplat() && isa(elementType)); - - auto val = attr.getValues()[0].getValue(); - auto constAttr = rewriter.getIndexAttr(val.getSExtValue()); - auto constOp = arith::ConstantOp::materialize(rewriter, constAttr, - rewriter.getIndexType(), loc); - data.setScalar(constOp); - - auto resType = dyn_cast(op.getResult().getType()); - size_t loopLimit = resType.getShape().size(); - for (auto i = 0; i < loopLimit; i++) { - if (i == 0) { - data.getOffsetsRef().push_back(constOp.getResult()); - } else { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - } - data.getSizesRef().push_back(rewriter.getIndexAttr(resType.getShape()[i])); - data.getStridesRef().push_back(rewriter.getIndexAttr(0)); - } -} - -void BlockDataParser::parseMakeTensorPtr( - triton::MakeTensorPtrOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - auto remappedValue = rewriter.getRemappedValue(op); - if (auto castOp = remappedValue.getDefiningOp()) { - parseReinterpretCast(castOp, data, loc, rewriter, known); - } else { - llvm_unreachable("the value should be mapped to memref.reinterpret_cast"); - } -} - -void BlockDataParser::parseAddPtr( - triton::AddPtrOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - BlockData ptrBlock, offsetBlock; - parse(op.getPtr(), ptrBlock, op.getLoc(), rewriter, known); - parse(op.getOffset(), offsetBlock, op.getLoc(), rewriter, known); - - assert(ptrBlock.hasSource() && - "Ptr field should provide source/base pointer"); - // If offset has a source, the offset comes from tl.load (TODO: and other ops) - if (offsetBlock.hasSource()) { - ptrBlock.setMemAccTy(offsetBlock.getMemAccType()); - offsetBlock.removeSource(); - } - - // handle for loop & scalar - if (ptrBlock.getRank() == 1 && offsetBlock.getRank() == 0) { - offsetBlock.getSizesRef().push_back(rewriter.getIndexAttr(1)); - offsetBlock.getOffsetsRef().push_back(offsetBlock.getScalarRef()); - offsetBlock.getStridesRef().push_back(rewriter.getIndexAttr(0)); - } - - assert(ptrBlock.getRank() == offsetBlock.getRank() && - "ptr and offset should have same rank"); - LLVM_DEBUG({ - auto &os = llvm::dbgs(); - os << "[parseAddPtr][BEG] =========================\n"; - os << 
"[parseAddPtr] op is " << op << "\n"; - for (int i = 0; i < ptrBlock.getRank(); i++) { - os << "ptrBlock.getOffsetsRef()[" << i - << "] = " << ptrBlock.getOffsetsRef()[i] << "\n"; - os << "ptrBlock.getSizesRef()[" << i - << "] = " << ptrBlock.getSizesRef()[i] << "\n"; - os << "ptrBlock.getStridesRef()[" << i - << "] = " << ptrBlock.getStridesRef()[i] << "\n"; - os << "offsetBlock.getOffsetsRef()[" << i - << "] = " << offsetBlock.getOffsetsRef()[i] << "\n"; - os << "offsetBlock.getSizesRef()[" << i - << "] = " << offsetBlock.getSizesRef()[i] << "\n"; - os << "offsetBlock.getStridesRef()[" << i - << "] = " << offsetBlock.getStridesRef()[i] << "\n"; - } - os << "[parseAddPtr][END] -------------------------\n"; - }); - data.addBlock(ptrBlock, offsetBlock, op.getLoc(), rewriter); -} - -void BlockDataParser::parseReinterpretCast( - memref::ReinterpretCastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - data.setOffsets(op.getMixedOffsets()); - data.setSizes(op.getMixedSizes()); - data.setStrides(op.getMixedStrides()); - data.setSource(op.getSource()); - - assert(data.getOffsetsRef().size() == 1); - size_t loopLimit = data.getSizesRef().size(); - for (size_t i = 1; i < loopLimit; i++) { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - } - - loopLimit = data.getStridesRef().size(); - for (size_t i = 0; i < loopLimit; i++) { - auto strideIntAttr = makeIntAttr(data.getStridesRef()[i]); - auto sizeIntAttr = makeIntAttr(data.getSizesRef()[i]); - assert(sizeIntAttr); - if (sizeIntAttr.value() == 1 && strideIntAttr) { - data.getStridesRef()[i] = rewriter.getIndexAttr(0); - } - } -} - -void BlockDataParser::parseReduce( - triton::ReduceOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - - const std::string scenarioMessages = - "PtsAnalysis supports indirectly block load in the following scenario\n" - "B = tl.load(Aptr + Aoffset) # B is 1D tensor\n" - "s = tl.min(B) # s is a scalar\n" - "D = tl.load(Cptr + s + Coffset) # s is used as the scalar offset\n"; - - auto reduce_src = op->getOperand(0); - BlockData srcBlock; - parse(reduce_src, srcBlock, loc, rewriter, known); - if (!srcBlock.hasSource()) { - llvm_unreachable(scenarioMessages.c_str()); - } - if (!isa(srcBlock.getSource().getDefiningOp())) { - llvm_unreachable(scenarioMessages.c_str()); - } - - auto reduce_result = op->getResult(0); - auto shaped_ty = dyn_cast(reduce_result.getType()); - auto shape = shaped_ty.getShape(); - auto ops = llvm::map_to_vector(op.getBody()->without_terminator(), - [](Operation &op) { return &op; }); - // Support only the case: scalar = tl.load(1D tensor) - if (shape.size() != 1 || op.getAxis() != 0 || ops.size() != 1 || - !isa(ops.front())) { - llvm_unreachable(scenarioMessages.c_str()); - } - - auto castOp = rewriter.create( - loc, RankedTensorType::get(shape, rewriter.getIndexType()), - reduce_result); - auto offset = castOp.getResult(); - if (data.isEmpty()) { - data.getOffsetsRef().push_back(offset); - data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); - data.getStridesRef().push_back(rewriter.getIndexAttr(1)); - } else { - llvm_unreachable("parseReduce with offset already setup not yet supported"); - } -} - -template -void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - // FIXME: assume single result of operation - auto opRes 
= op->getResult(0); - auto opResTy = opRes.getType(); - std::vector resShape; - if (auto shapedResTy = dyn_cast(opResTy)) { - // For now, we consider this is UnstrucMemAcc because we have no other info. - // Visiting other ops may change the type due to more info. - data.setMemAccVal(MemAccVal::UnstrucMemAcc); - resShape = shapedResTy.getShape().vec(); - } else { - // scalar load means this is used as offset. It is StrucMemAcc. - data.setMemAccVal(MemAccVal::StrucMemAcc); - resShape.push_back(1); - } - for (auto &s : resShape) { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - data.getSizesRef().push_back(rewriter.getIndexAttr(s)); - data.getStridesRef().push_back(rewriter.getIndexAttr(1)); - } - // set the source in BlockData so that we know an indirect-load op exists in - // the chain. - data.setSource(opRes); -} - -void BlockDataParser::rewriteAddPtr( - triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, - ConversionPatternRewriter &rewriter, - llvm::SmallDenseMap &known) { - auto insertPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPoint(op); - - BlockData data; - parseAddPtr(op, data, op.getLoc(), rewriter, known); - - if (data.getMemAccTypeRef().isUnstructured()) { - // TODO: Based on more info, try to create a performant IR - rewriteAddPtrToUnstrucMemAcc(op, adaptor, rewriter, data); - LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); - return; - } - - if (data.getSizesRef().size() == 0) { - data.getSizesRef().push_back(rewriter.getIndexAttr(1)); - data.getStridesRef().push_back(rewriter.getIndexAttr(0)); - data.getOffsetsRef().push_back(data.getScalarRef()); - } - - ArrayRef resultShape; - SmallVector staticShape1(1, 1); // sz 1 value 1 - if (auto shapedType = dyn_cast(op.getResult().getType())) { - resultShape = shapedType.getShape(); - } else { - assert(data.getRank() == 1); - resultShape = staticShape1; - } - - known[op.getResult()] = data; - - auto infered_size = 1; - for (int i = data.getSizesRef().size() - 1; i >= 0; i--) { - auto strideInt = makeIntAttr(data.getStridesRef()[i]); - auto sizeInt = makeIntAttr(data.getSizesRef()[i]); - assert(sizeInt); - if (sizeInt.value() == 1 && strideInt && strideInt.value() == 0) { - data.getStridesRef()[i] = rewriter.getIndexAttr(infered_size); - } - infered_size *= sizeInt.value(); - } - - if (data.hasResElemTy()) { - auto memrefType = dyn_cast(data.getSourceRef().getType()) - .cloneWith(std::nullopt, data.getResElemTyRef()); - UnrealizedConversionCastOp castOp = - rewriter.create( - op.getLoc(), memrefType, data.getSourceRef()); - data.setSource(castOp.getOutputs()[0]); - } - - // no module handle - memref::ReinterpretCastOp castOp = - data.createCastOp(resultShape, op.getLoc(), rewriter); - Value src = castOp.getResult(); - LLVM_DEBUG({ - llvm::dbgs() << "cast MemRefType:\n"; - castOp.getOperation()->print(llvm::dbgs(), - OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); - - data.setSource(src); - rewriter.replaceOp(op, src); - rewriter.restoreInsertionPoint(insertPoint); -} - -void BlockDataParser::rewriteAdvanceOp( - triton::AdvanceOp op, ConversionPatternRewriter &rewriter, - llvm::SmallDenseMap &known) { - OpBuilder::InsertionGuard insertionGuard{rewriter}; - rewriter.setInsertionPoint(op); - auto loc = op.getLoc(); - - BlockData blockData; - parse(op.getOperand(0), blockData, loc, rewriter, known); - - auto incrementOffsets = op.getOffsets(); - - SmallVector newOffsets; - for (const auto [increment, offset, stride] : - llvm::zip(incrementOffsets, 
blockData.getOffsetsRef(), - blockData.getStridesRef())) { - Value offsetValue; - if (auto offsetIntAttr = makeIntAttr(offset)) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); - offsetValue = constOp.getResult(); - } else { - offsetValue = offset.get(); - } - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), increment); - auto mulOp = rewriter.create(loc, castOp.getResult(), - stride.get()); - auto addOp = - rewriter.create(loc, mulOp.getResult(), offsetValue); - newOffsets.push_back(addOp.getResult()); - } - - blockData.getOffsetsRef().clear(); - - for (auto offset : newOffsets) { - blockData.getOffsetsRef().push_back(offset); - } - - SmallVector scalarShape(1, 1); - ArrayRef resultShape; - auto pointerType = cast(op.getResult().getType()); - if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { - resultShape = shapedType.getShape(); - } else { - // scalar pointer, should produce a one dimensional memref - resultShape = scalarShape; - assert(blockData.getRank() == 1); - } - - auto newOp = blockData.createCastOp(resultShape, loc, rewriter); - - rewriter.replaceOp(op, newOp.getResult()); - - known[newOp.getResult()] = blockData; -} - -void BlockDataParser::rewriteYieldOp( - scf::YieldOp op, ConversionPatternRewriter &rewriter, - const IndexMapSet &levelToBlockArgIndex, const int level, - const llvm::SmallDenseMap &known) { - // any inserted instruction should be before this yield - OpBuilder::InsertionGuard insertionGuard{rewriter}; - rewriter.setInsertionPoint(op); - - auto adaptor = scf::YieldOp::Adaptor(op); - - SmallVector initArgState; - SmallVector operands(adaptor.getOperands()); - // Track the second chunks of modulo pointers so that we can append them to - // the yield results - SmallVector moduloSecondChunks; - - // For each of the init arg that we added additional Values in for loop, we - // need to add corresponding Values as yield operands. The loop below gathers - // BlockData for those values. - for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { - if (auto mappedV = rewriter.getRemappedValue(v)) { - // If this value is a tensor of pointers produced by AddPtrOp, - // we should have already converted to a ReinterpretCastOp without - // layout information for the normal cases, or to an - // UnrealizedConversionCastOp for the split pointer case. - if (v.getDefiningOp() || - v.getDefiningOp() || - v.getDefiningOp()) { - if (auto castOp = mappedV.getDefiningOp()) { - assertLegalUnrealizedCast(castOp); - auto castInputs = castOp.getInputs(); - v = castOp.getResult(0); - operands[i] = castInputs[0]; - moduloSecondChunks.push_back(castInputs[1]); - } else if (auto castOp = - mappedV.getDefiningOp()) { - v = castOp; - } else { - llvm_unreachable("mapped value defined by an unexpected op"); - } - } else { - // If this value is not a tensor of pointers, we will use the - // mapped value, and rely on the conversion will happen later - // automatically when we legalize loop body. 
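- - // Editor's example (not part of the original file): a non-pointer - // iter_arg such as an f32 accumulator tensor takes this branch and - // simply flows through as its remapped value.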
- - // TODO: - // The scenario where a value is a tensor of pointers but not - // produced by AddPtrOp is not supported - if (isa(mappedV.getType()) && - isa( - dyn_cast(mappedV.getType()).getElementType())) - llvm_unreachable("unsupported scenario where a value is a tensor of " - "pointers but not produced by AddPtrOp"); - v = mappedV; - } - } - - if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) - continue; - auto thisSet = levelToBlockArgIndex.find(level)->second; - if (thisSet.find(i) == thisSet.end()) - continue; - - auto reintCastOp = v.getDefiningOp(); - auto unrealizedCastOp = v.getDefiningOp(); - - // assert condition deleted: (unrealizedCastOp && - // unrealizedCastOp->hasAttr(ModuloState::WraparoundAttr)) - assert( - reintCastOp || - (isa(v.getType()) && - isa(dyn_cast(v.getType()).getElementType()))); - - BlockData state; - if (reintCastOp) { - parseReinterpretCast(reintCastOp, state, op.getLoc(), rewriter, known); - } else if (unrealizedCastOp) { - assertLegalUnrealizedCast(unrealizedCastOp); - parseUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, - known); - } else { - parse(v, state, op.getLoc(), rewriter, known); - } - initArgState.push_back(state); - } - - // For each of the BlockData recorded in the last step, extract value - // that correspond to offset and stride for each dimension and append - // them to yield operands. - for (auto state : initArgState) { - for (auto s : state.getOffsetsRef()) { - // offsets can be IntAttr zeroes, since reinterpret_cast collapses - // them for the input memref, and the for loop may not update - // offsets other than offsets[0]. Create constants Values for those - // zeroes. - if (auto sIntAttr = makeIntAttr(s)) { - assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); - operands.push_back(constOp.getResult()); - } else { - operands.push_back(s.get()); - } - } - - for (auto s : state.getStridesRef()) { - assert(!makeIntAttr(s) && "BlockData strides for yield within for " - "loop not expected to be " - "attribute."); - operands.push_back(s.get()); - } - } - - for (auto chunk : moduloSecondChunks) { - operands.push_back(chunk); - } - - // Yield is a terminator op that must be at the end of the function - rewriter.setInsertionPointAfter(op); - auto newOp = rewriter.replaceOpWithNewOp(op, operands); - assert(op->getNumResults() == 0); - - LLVM_DEBUG({ - llvm::dbgs() << "new yield:"; - newOp.getOperation()->print(llvm::dbgs(), - OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); -} - -namespace { - -struct ModuloChunkInitArg { - Value reinterpretCast = nullptr; - // where in the init args is the first chunk placed - size_t initArgIndex = -1; -}; - -} // namespace - -void BlockDataParser::rewriteForOp( - scf::ForOp op, ConversionPatternRewriter &rewriter, - IndexMapSet &levelToBlockArgIndex, const int level, - llvm::SmallDenseMap &known) { - SmallVector newInitArgs; - - SmallVector, 5> initArgIndexState; - SmallVector, 5> knownPtrsTmp; - - // If we have a load op that uses a modulo pointer, we need to insert both of - // the memref chunks to the init args. We reuse the sizes from the original - // memrefs. This data structure keeps track of where these additional init - // args should be inserted. - // - // As an example, if we have a 2D memrefs being split, we first put the first - // chunk in the order as it appears. 
Then, once all of the original init args - // are processed, we insert their offsets and strides, and finally the second - // chunk. - SmallVector, BlockData>, - 6> - moduloStates; - - // Amongst the init args, track the indices that map to the first chunk of a - // modulo pair. This is used to distinguish between the normal - // reinterpret_casts whose return types need to be rewritten to match what the - // for loop is yielding. - DenseSet moduloInitArgIndices; - - // Create a new list of init args - for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { - auto mappedV = rewriter.getRemappedValue(arg); - memref::ReinterpretCastOp reintCastOp; - UnrealizedConversionCastOp unrealizedCastOp; - - // If this init arg is supposed to be remapped, use the remapped - // value instead. In addition, if this init arg is a memref created - // by a reinterpret_cast or a tensor of index, there is a chance that - // it will be used in addptr. Create BlockData for each such init arg. - if (mappedV) { - // TODO: - // Passing a block argument pointer directly into a for loop not - // supported. - assert(!(dyn_cast(mappedV) && - isa(mappedV.getType())) && - "cannot take pointer block argument as init arg for for loop"); - if (auto op = mappedV.getDefiningOp()) { - reintCastOp = op; - newInitArgs.push_back(mappedV); - } else if (auto op = - mappedV.getDefiningOp()) { - assertLegalUnrealizedCast(op); - unrealizedCastOp = op; - auto inputs = unrealizedCastOp.getInputs(); - - SmallVector initArgData{ - ModuloChunkInitArg{inputs[0], i}, - ModuloChunkInitArg{inputs[1]}, - }; - - moduloInitArgIndices.insert(i); - moduloStates.push_back( - std::make_tuple(unrealizedCastOp, initArgData, BlockData{})); - - newInitArgs.push_back(inputs[0]); - } else { - newInitArgs.push_back(mappedV); - } - - } else { - newInitArgs.push_back(arg); - } - - auto indexTensor = - isa(arg.getType()) && - isa(dyn_cast(arg.getType()).getElementType()); - - if (!unrealizedCastOp && !reintCastOp && !indexTensor) - continue; - - BlockData data; - if (reintCastOp) { - parseReinterpretCast(reintCastOp, data, op.getLoc(), rewriter, - llvm::SmallDenseMap(0)); - } else if (unrealizedCastOp) { - parseUnrealizedCast(unrealizedCastOp, data, op.getLoc(), rewriter, - llvm::SmallDenseMap(0)); - std::get<2>(moduloStates.back()) = data; - } else { - parse(arg, data, op.getLoc(), rewriter, - llvm::SmallDenseMap(0)); - } - - // Record the BlockData for later processing - initArgIndexState.push_back(std::make_pair(i, data)); - } - - // Set insertion point to be before the for loop for new variables passed - // into the new loop. - auto origIp = rewriter.saveInsertionPoint(); - rewriter.setInsertionPoint(op); - - // For each of the BlockData recorded in the last step, insert new - // instructions to describe offset and stride for each dimension and append - // them to init args - for (auto [i, data] : initArgIndexState) { - // For each dimension, if the corresponding offset and stride is an - // integer attribute, create a constant value and append them at the - // end of init arg list. 
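-    // Illustrative example (editor's addition; hypothetical values): for an
-    // init arg whose BlockData holds offsets [0 (attr), %off] and strides
-    // [256 (attr), 1 (attr)], the loops below materialize the attributes as
-    // arith.constant index values, so newInitArgs is extended with
-    //   %c0, %off, %c256, %c1
-    // and the loop body can then update them on every iteration.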
- for (auto [j, s] : llvm::enumerate(data.getOffsetsRef())) { - auto sIntAttr = makeIntAttr(s); - if (sIntAttr) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); - newInitArgs.push_back(constOp.getResult()); - data.getOffsetsRef()[j] = constOp.getResult(); - } else { - newInitArgs.push_back(s.get()); - } - } - - for (auto [j, s] : llvm::enumerate(data.getStridesRef())) { - auto sIntAttr = makeIntAttr(s); - if (sIntAttr) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); - newInitArgs.push_back(constOp.getResult()); - data.getStridesRef()[j] = constOp.getResult(); - } else { - newInitArgs.push_back(s.get()); - } - } - - // Note that we want the knownPtrs to be indexed by block arg, but we - // only have index for now. Also, the blockdata we record is the init - // arg, but want to to use newly created block arg. These block args - // are not created yet. We will translate this mapping later. - knownPtrsTmp.push_back(std::make_pair(i, data)); - levelToBlockArgIndex[level].insert(i); - - // If the original init arg is a memref produced by reinterpret_cast, - // create a new memref using new strides and offsets created above. - // This produces a canonicalized memref, which will match what the - // for loop generates if it modifies the memref. E.g., original - // reinterpret_cast can produce a memref with const stride: - // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + - // s0 + d1 - // * s1)>> - // The new reinterpret_cast will always have dynamic stride and - // offset: - // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 - // + s0 + d1 * s2)>> - // - // For init args that are the first chunk of a modulo pair, there is - // no need for the type to be rewritten because the strides and - // offsets are already dynamic. - if (!moduloInitArgIndices.contains(i) && - newInitArgs[i].getDefiningOp()) { - SmallVector resultShape; - for (auto s : data.getSizesRef()) { - auto sIntAttr = makeIntAttr(s); - assert(sIntAttr && "expected constant size"); - resultShape.push_back(sIntAttr.value()); - } - auto castOp = data.createCastOp(resultShape, op.getLoc(), rewriter); - - LLVM_DEBUG({ - llvm::dbgs() << "new reinterpret_cast with dynamic sizes " - "and offsets:"; - castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); - - newInitArgs[i] = castOp.getResult(); - } - } - - // Pass in the second chunk of each modulo pair - for (auto &[unrealizedCastOp, chunkData, data] : moduloStates) { - chunkData[1].initArgIndex = newInitArgs.size(); - newInitArgs.push_back(chunkData[1].reinterpretCast); - } - - rewriter.restoreInsertionPoint(origIp); - - // Create a new scf::ForOp that uses updated init args and same loop body - auto newOp = rewriter.create( - op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), - newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - IRMapping mapping; - mapping.map(op.getInductionVar(), iv); - mapping.map(op.getInitArgs(), newInitArgs); - mapping.map(op.getRegionIterArgs(), args); - - for (auto &bodyOp : op.getRegion().getOps()) { - b.clone(bodyOp, mapping); - } - - // Load op is lowered independent of the pointer, if we have a split - // pointer due to modulo, we need to "logically combine" these two - // memrefs into a single one using unrealized_cast_op. This way, when - // lowering the load, the pattern can detect if additional copies are - // inserted. 
When we are in a loop, it is more complicated because we - // have to insert a new unrealized_cast_op that combines the two memrefs - // in the init arg list. In addition, because init args hold no offset - // and size information, we have to manually insert two additional - // reinterpret_cast ops as input to this unrealized_cast_op so that the - // load have enough information to generate the corresponding copy. - OpBuilder::InsertionGuard g(b); - b.setInsertionPointToStart(b.getBlock()); - - Value zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); - - for (auto &[unrealizedCastOp, chunkData, data] : moduloStates) { - SmallVector newReinterpretCasts; - for (auto &chunk : chunkData) { - newReinterpretCasts.push_back(args[chunk.initArgIndex]); - } - - auto combinedCast = b.create( - loc, unrealizedCastOp.getResult(0).getType(), newReinterpretCasts, - unrealizedCastOp->getAttrs()); - - args[chunkData[0].initArgIndex].replaceUsesWithIf( - combinedCast.getResult(0), [](OpOperand &operand) { - assert(!isa(operand.getOwner()) && - "Storing to split pointers not supported"); - return isa(operand.getOwner()); - }); - } - }); - - // Convert the book-keeping data structure to use the correct key and value. - // Key is converted from init arg index to newly created block arg, and - // Value's BlockData fields are converted from init arg to newly created block - // arg - int cnt = op.getRegionIterArgs().size(); - for (auto [i, data] : knownPtrsTmp) { - for (auto it = data.getOffsetsRef().begin(); - it != data.getOffsetsRef().end(); it++) { - *it = newOp.getRegionIterArgs()[cnt]; - cnt++; - } - - for (auto it = data.getStridesRef().begin(); - it != data.getStridesRef().end(); it++) { - *it = newOp.getRegionIterArgs()[cnt]; - cnt++; - } - - auto key = newOp.getRegionIterArgs()[i]; - known.insert(std::make_pair(key, data)); - } - assert(static_cast(cnt + moduloStates.size()) == - newOp.getRegionIterArgs().size() && - "expect to remap all new block args"); - - // Replace only the results that correspond to the original scf.for - auto resultsToReplaceWith = ResultRange( - newOp.result_begin(), newOp.result_begin() + op.getNumResults()); - rewriter.replaceOp(op, resultsToReplaceWith); - - // Update the loop body. Manually invoke the rewrite logic on addptr and yield - // in the loop body, so we can take advantage of the states we built up - for (auto &bodyOp : newOp.getRegion().getOps()) { - if (auto addptrOp = dyn_cast(bodyOp)) { - // FIXME: Constructed adaptor here does not hold the transformed op info. - auto adaptor = triton::AddPtrOp::Adaptor(addptrOp); - rewriteAddPtr(addptrOp, adaptor, rewriter, known); - } else if (auto advanceOp = dyn_cast(bodyOp)) { - rewriteAdvanceOp(advanceOp, rewriter, known); - } else if (auto forOp = dyn_cast(bodyOp)) { - // TODO: - // Nested for loops are not supported at the moment - assert(0 && "nested loops currently not supported"); - // rewriteForOp(forOp, rewriter, levelToBlockArgIndex, level+1, - // knownPtrs); levelToBlockArgIndex.erase(level+1); - } - } - - if (op.getNumRegionIterArgs()) { - auto yieldOp = cast(newOp.getBody()->getTerminator()); - rewriteYieldOp(yieldOp, rewriter, levelToBlockArgIndex, level, known); - } - - LLVM_DEBUG({ - llvm::dbgs() << "new for\n"; - newOp.getOperation()->print(llvm::dbgs(), - OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); -} - -/// @brief Rewrite the triton::AddPtrOp to handle unstructured memory access. -/// @param op The triton::AddPtrOp to be rewritten. 
-/// @param adaptor The adaptor of the triton::AddPtrOp, used to get operands. -/// @param rewriter The pattern rewriter used to modify the IR. -/// @param data The BlockData containing information about the memory access. -void BlockDataParser::rewriteAddPtrToUnstrucMemAcc( - triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, - ConversionPatternRewriter &rewriter, BlockData &data) { - auto loc = op.getLoc(); - auto &offsets = data.getOffsetsRef(); - auto &blockSizes = data.getSizesRef(); - auto &strides = data.getStridesRef(); - Value ptrOffset = adaptor.getOffset(); - Value zeroIdx = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value oneIdx = - rewriter.create(loc, rewriter.getIndexAttr(1)); - auto addptrRes = op.getResult(); - assert(addptrRes.hasOneUse() && "Invalid: tt.addptr has multiple users"); - auto loadOp = *(addptrRes.user_begin()); - - // Prepare empty tensor for loop based scalar load - // FIXME: We use cast here because addptr must return tensor>. - // True? - auto resTy = cast(addptrRes.getType()); - auto resEPtrTy = resTy.getElementType(); - auto resETy = cast(resEPtrTy).getPointeeType(); - Value loaded = rewriter.create(loc, blockSizes, resETy); - SmallVector initArgs; - initArgs.push_back(loaded); - - SmallVector forLBs; - SmallVector forUBs; - SmallVector forSteps; - for (auto &s : offsets) { - forLBs.push_back(zeroIdx); - } - for (auto &s : blockSizes) { - forUBs.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, s)); - } - for (auto &s : strides) { - forSteps.push_back(oneIdx); - } - SmallVector ivs; - OpBuilder builder(op); - auto loop = createNestedLoops( - builder, loc, 0, blockSizes.size(), forLBs, forUBs, forSteps, ivs, - initArgs, - [&](OpBuilder &bB, Location bLoc, SmallVector &allIVs, - ValueRange iterArgs) { - OpBuilder::InsertionGuard g(bB); - bB.setInsertionPointToStart(bB.getBlock()); - - Value scalarOffsetRaw = - bB.create(bLoc, ptrOffset, allIVs); - Value scalarOffset = bB.create( - bLoc, bB.getIndexType(), scalarOffsetRaw); - // Replace offset & size. Only single element. 
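-          // Illustrative sketch (editor's addition; assumes an f32 element
-          // type): for a gather such as
-          //   tmp = tl.load(base_ptr + idx)   with idx : tensor<64xi32>
-          // each iteration of the nest loads a single element, roughly:
-          //   %off  = tensor.extract %idx[%iv] ; arith.index_cast to index
-          //   %view = memref.reinterpret_cast %base to memref<1xf32>
-          //   %v    = memref.load %view[%c0] ; tensor.insert %v into %acc
-          // with the tt.load moved into this block right below.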
- data.getOffsetsRef().clear(); - data.getOffsetsRef().push_back(scalarOffset); - data.getSizesRef().clear(); - data.getSizesRef().push_back(bB.getIndexAttr(1)); - data.getStridesRef().clear(); - data.getStridesRef().push_back(bB.getIndexAttr(1)); - memref::ReinterpretCastOp castOp = data.createCastOp({1}, bLoc, bB); - rewriter.replaceOp(op, castOp); - // Move tt.load using this tt.addptr into this block - loadOp->moveAfter(castOp); - loadOp->setAttr("IndirectLoad", UnitAttr::get(op.getContext())); - bB.create(bLoc, iterArgs); - }); -} - -} // namespace triton -} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt deleted file mode 100644 index 75f5ad897..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -add_triton_library(TritonToLinalg - TritonToLinalgPass.cpp - LoadStoreConverter.cpp - FunctionConverter.cpp - ArgMinMaxConverter.cpp - TritonOpConverter.cpp - BlockPtrAnalysis.cpp - MaskAnalysis.cpp - UseAnalysis.cpp - - DEPENDS - TritonToLinalgConversionPassIncGen - - LINK_LIBS PUBLIC - MLIRArithDialect - MLIRDialectUtils - MLIRIR - MLIRMathDialect - MLIRPass - MLIRTensorDialect - MLIRTransforms - MLIRSupport - TritonIR - TritonTransforms - TritonAnalysis - MLIRTritonNPUUtils - MLIRSCFTransforms - MLIRLinalgTransforms -) diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp deleted file mode 100644 index af58b6dbe..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp +++ /dev/null @@ -1,41 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
-// -//===----------------------------------------------------------------------===// - -#include "TritonToLinalg/FunctionConverter.h" - -namespace FunctionConverter { -using namespace mlir; -using namespace triton; - -LogicalResult GetProgramIDConverter::matchAndRewrite( - triton::GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto axis = (uint32_t)op.getAxis(); - assert(axis < GetProgramIDConverter::LAUNCH_GRID_RANK && - "Invalid axis for GetProgramIdOp"); - auto func = op->getParentOfType(); - auto numArgs = func.getNumArguments(); - auto id = func.getArgument(numArgs - GetProgramIDConverter::LAUNCH_GRID_RANK + - axis); - rewriter.replaceOp(op, id); - return success(); -} - -LogicalResult GetNumProgramsConverter::matchAndRewrite( - triton::GetNumProgramsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto axis = (uint32_t)op.getAxis(); - assert(axis < GetNumProgramsConverter::LAUNCH_GRID_RANK && - "Invalid axis for GetNumProgramsOp"); - auto func = op->getParentOfType(); - auto numArgs = func.getNumArguments(); - auto id = func.getArgument( - numArgs - GetNumProgramsConverter::LAUNCH_GRID_RANK * 2 + axis); - rewriter.replaceOp(op, id); - return success(); -} -} // namespace FunctionConverter diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp deleted file mode 100644 index 2d49a6ab8..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp +++ /dev/null @@ -1,752 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#include "TritonToLinalg/LoadStoreConverter.h" -#include "TritonToLinalg/BlockPtrAnalysis.h" -#include "TritonToLinalg/MaskAnalysis.h" -#include "Utils/InterleaveOptimization.h" -#include "Utils/Utils.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpDefinition.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MathExtras.h" - -#include "llvm/Support/Debug.h" - -#include -#include -#include - -#define DEBUG_TYPE "triton-load-store-converter" - -namespace LoadStoreConverter { -using namespace mlir; -using namespace triton; - -LogicalResult -AddPtrConverter::matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - llvm::SmallDenseMap known; - BlockDataParser::rewriteAddPtr(op, adaptor, rewriter, known); - return success(); -} - -LogicalResult LoadConverter::toTensorAndReplace( - triton::LoadOp &op, RankedTensorType &tensorType, memref::AllocOp 
&allocOp, - const Location &loc, ConversionPatternRewriter &rewriter) const { - Value loadedTensor = rewriter.create( - loc, tensorType, allocOp, true, true); - rewriter.replaceOp(op, loadedTensor); - return success(); -} - -/// @brief Check whether the triton::LoadOp has been modified to the specified -/// state by the AddPtrConverter. -/// @param op The triton::LoadOp operation to be checked. -/// @return Return success if the operation conforms to the specified state; -/// otherwise, return failure. -LogicalResult -LoadConverter::checkModifiedByAddPtrConverter(triton::LoadOp &op) const { - if (!isa(op->getParentOp())) { - return failure(); - } - if (!op->hasAttr("IndirectLoad")) { - return failure(); - } - auto ptrOp = op.getPtr().getDefiningOp(); - auto ptrBlock = ptrOp->getBlock(); - auto opBlock = op->getBlock(); - if (ptrBlock == opBlock) { - return failure(); - } - - return success(); -} - -/// @brief Continue to modify the triton::LoadOp from the state modified by the -/// AddPtrConverter. -/// @param op The triton::LoadOp operation to be processed. -/// @param adaptor The adaptor for the operation, used to obtain operands. -/// @param rewriter The pattern rewriter used to rewrite the operation. -/// @return Return success if the operation is successful; otherwise, return -/// failure. -LogicalResult LoadConverter::continueModifyFromAddPtrConverter( - triton::LoadOp &op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto forOp = op->getParentOfType(); - Operation *firstOp = &forOp.getBody()->front(); - auto extractOp = cast(firstOp); - auto ivs = extractOp.getIndices(); - // Single iterArg which is inserted by AddPtrConverter. - auto iterArg = forOp.getRegionIterArg(0); - auto ptr = adaptor.getPtr(); - - rewriter.setInsertionPointAfter(op); - Value castVal = ptr.getDefiningOp(); - Value idxZero = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value loadVal = - rewriter.create(loc, castVal, ValueRange{idxZero}); - Value insertedVal = - rewriter.create(loc, loadVal, iterArg, ValueRange{ivs}); - // a yield op is already created by AddPtrConverter. - // so we need to replace it with a new yield op. - Operation *terminator = forOp.getBody()->getTerminator(); - scf::YieldOp oldYieldOp = cast(terminator); - auto yieldOp = rewriter.create(loc, ValueRange{insertedVal}); - rewriter.replaceOp(oldYieldOp, yieldOp); - // Now the scf.for is complete, we can replace tt.load with it. - auto rank = cast(op.getResult().getType()).getShape().size(); - Operation *rootForOp = op; - while (rank != 0) { - rank--; - rootForOp = rootForOp->getParentOfType(); - } - rewriter.replaceOp(op, rootForOp); - LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(rootForOp) << "\n"; }); - return success(); -} - -LoadConverter::LoadConverter(MLIRContext *context) - : OpConversionPattern(context) {} - -LogicalResult -LoadConverter::matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - - // Check if tt.load is modified by AddPtrConverter to a specified state. 
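-  // Editor's note, summarizing checkModifiedByAddPtrConverter above: the
-  // load qualifies only if it sits inside the scf.for created by
-  // AddPtrConverter, carries the "IndirectLoad" unit attribute, and its
-  // pointer is defined in a different block than the load itself.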
-  if (checkModifiedByAddPtrConverter(op).succeeded()) {
-    return continueModifyFromAddPtrConverter(op, adaptor, rewriter);
-  }
-
-  auto ptr = adaptor.getPtr();
-  auto mask = op.getMask();
-  auto other = op.getOther();
-  auto loc = op.getLoc();
-
-  // handling scalar
-  if (!isa<ShapedType>(op.getResult().getType())) {
-    auto scalarMemref =
-        BlockDataParser::getScalarMemRef(op.getPtr(), ptr, loc, rewriter);
-    auto resTy = op.getResult().getType();
-    auto idxZero =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
-    auto loadOp = rewriter.create<memref::LoadOp>(loc, resTy, scalarMemref,
-                                                  idxZero.getResult());
-    rewriter.replaceOp(op, loadOp.getResult());
-    return success();
-  }
-
-  // handling no mask
-  auto memRefType = dyn_cast<MemRefType>(ptr.getType());
-  if (!memRefType) {
-    return rewriter.notifyMatchFailure(
-        op, "LoadOp expects a memref, not a memref of pointers");
-  }
-  auto memRefShape = memRefType.getShape();
-  auto memRefElementType = memRefType.getElementType();
-
-  auto allocOp = rewriter.create<memref::AllocOp>(
-      loc, MemRefType::get(memRefShape, memRefElementType));
-
-  auto tensorType = RankedTensorType::get(memRefShape, memRefElementType);
-  // boundary check
-  auto boundaryCheck = op.getBoundaryCheck();
-  if (!boundaryCheck.empty()) {
-    auto boundarySizes = mlir::ConverterUtils::getBoundarySizes(
-        boundaryCheck, op.getPtr(), ptr, loc, rewriter);
-    // handle the padding
-    auto padding = op.getPadding();
-    if (padding.has_value()) {
-      TypedAttr padAttr = rewriter.getZeroAttr(memRefElementType);
-      // triton already ensures only NAN and ZERO are passed in
-      if (padding.value() == triton::PaddingOption::PAD_NAN) {
-        // FIXME: Why does NaN require elemTy to be non-int and non-index?
-        assert(!memRefElementType.isIntOrIndex());
-        auto apNaN = llvm::APFloat::getNaN(
-            cast<FloatAttr>(padAttr).getValue().getSemantics());
-        padAttr = rewriter.getFloatAttr(memRefElementType, apNaN);
-      }
-      auto padVal = rewriter.create<arith::ConstantOp>(loc, padAttr);
-
-      auto shape = memRefType.getShape();
-      auto accBase =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false))
-              .getResult();
-      for (size_t i = 0; i < boundarySizes.size(); i++) {
-        auto dim = boundaryCheck[i];
-        auto shapei = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIndexAttr(shape[dim]));
-        Value bndSizei = dyn_cast<Value>(boundarySizes[i]);
-        if (!bndSizei) {
-          bndSizei = rewriter.create<arith::ConstantOp>(
-              loc, cast<IntegerAttr>(boundarySizes[i].get<Attribute>()));
-        }
-        auto cmpOp = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::slt, bndSizei, shapei);
-        accBase = rewriter.create<arith::OrIOp>(loc, accBase, cmpOp.getResult())
-                      .getResult();
-      }
-      rewriter.create<scf::IfOp>(
-          loc, accBase, [&](OpBuilder &builder, Location loc) {
-            builder.create<linalg::FillOp>(loc, ValueRange{padVal},
-                                           ValueRange{allocOp});
-            builder.create<scf::YieldOp>(loc);
-          });
-    }
-
-    auto srcSubView =
-        mlir::ConverterUtils::makeSubViewOp(ptr, boundarySizes, loc, rewriter);
-    auto dstSubview = mlir::ConverterUtils::makeSubViewOp(
-        allocOp, boundarySizes, loc, rewriter);
-    rewriter.create<memref::CopyOp>(loc, srcSubView, dstSubview);
-
-    return this->toTensorAndReplace(op, tensorType, allocOp, loc, rewriter);
-  }
-
-  if (!mask) {
-    assert(!other && "cannot pass 'other' when 'mask' is not set");
-    if (auto unrealizedCastOp =
-            ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
-      // TODO: handling associated with "module" is not supported yet
-      // hint: this can be handled in Linearize
-    } else {
-      // If the last dimension stride equals 2, try the deinterleave
-      // optimization.
-      auto [ptrStrides, ptrOffsets] = getStridesAndOffset(memRefType);
-      if (ptrStrides.back() == 2 && (memRefShape.back() % 2 == 0) &&
-          mlir::triton::DeinterleaveStatusOptimization(op, adaptor, rewriter)
-              .succeeded()) {
-        return success();
-      }
-      rewriter.create<memref::CopyOp>(loc, ptr, allocOp);
-    }
-
-    return this->toTensorAndReplace(op, tensorType, allocOp, loc, rewriter);
-  }
-
-  MaskState mstate;
-  auto isContMask = mstate.parse(mask, loc, rewriter);
-  if (isContMask.failed()) {
-    return rewriter.notifyMatchFailure(
-        op, "cannot lower discontinuous masked loads");
-  }
-
-  if (other) {
-    auto scalarOther =
-        mlir::ConverterUtils::getScalarValue(other, loc, rewriter);
-    assert(
-        scalarOther &&
-        "other value used in masked load produced by unsupported instruction!");
-    auto shape = memRefType.getShape();
-    auto accBase =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false))
-            .getResult();
-    for (size_t i = 0; i < memRefType.getShape().size(); i++) {
-      auto shapei = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(shape[i]));
-      Value dimi = dyn_cast<Value>(mstate.dims[i]);
-      if (!dimi) {
-        dimi = rewriter.create<arith::ConstantOp>(
-            loc, cast<IntegerAttr>(mstate.dims[i].get<Attribute>()));
-      }
-      auto cmpOp = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, dimi, shapei);
-      accBase = rewriter.create<arith::OrIOp>(loc, accBase, cmpOp.getResult())
-                    .getResult();
-    }
-
-    rewriter.create<scf::IfOp>(
-        loc, accBase, [&](OpBuilder &builder, Location loc) {
-          builder.create<linalg::FillOp>(loc, ValueRange{scalarOther},
-                                         ValueRange{allocOp});
-          builder.create<scf::YieldOp>(loc);
-        });
-  }
-
-  // To enable the deinterleave optimization with a masked load, the mask
-  // state along the last dimension must not be split: `dims.back()` must
-  // equal the constant size of the last dimension of the original type, and
-  // `offsets.back()` must be 0.
-  //
-  // The rationale is that a range comparison along the last dimension would
-  // otherwise generate an unsupported discontinuous mask.
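-  // Illustrative example (editor's addition; hypothetical types): for a view
-  // of type memref<4x8xf32, strided<[?, 2], offset: ?>>, a mask with
-  // offsets = [0, 0] and dims = [2, 8] keeps the last dimension whole, so the
-  // optimization below may apply; dims = [2, 5] splits the last dimension and
-  // is rejected.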
-  if (mstate.getRank() == memRefType.getRank() &&
-      isConstantIntValue(mstate.offsets.back(), 0) &&
-      isConstantIntValue(mstate.dims.back(), memRefType.getShape().back())) {
-    auto [ptrStrides, ptrOffsets] = getStridesAndOffset(memRefType);
-    if (ptrStrides.back() == 2 && (memRefType.getShape().back() % 2 == 0) &&
-        DeinterleaveStatusWithMaskOptimization(op, adaptor, rewriter, mstate,
-                                               allocOp)
-            .succeeded()) {
-      return success();
-    }
-  }
-
-  if (auto unrealizedCastOp =
-          ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
-    // TODO: handling associated with "module" is not supported yet
-    // hint: this can be handled in Linearize
-  } else {
-    memref::SubViewOp srcSubView = mstate.getSubview(ptr, loc, rewriter);
-    memref::SubViewOp dstSubView = mstate.getSubview(allocOp, loc, rewriter);
-    rewriter.create<memref::CopyOp>(loc, srcSubView, dstSubView);
-  }
-  return this->toTensorAndReplace(op, tensorType, allocOp, loc, rewriter);
-}
-
-AtomicRMWConverter::AtomicRMWConverter(MLIRContext *context)
-    : OpConversionPattern<triton::AtomicRMWOp>(context) {}
-
-// Lower tt.atomic_rmw to linalg.generic.
-// If the atomic op's return value is used by another op (it is the old value
-// stored at the ptr), we will use tt.load to get it.
-//
-// example:
-// input:
-// %return_value = tt.atomic_rmw fadd, acq_rel, gpu,
-//     %output_memref, %input_tensor, %mask :
-//     (tensor<256x!tt.ptr<f32>>, tensor<256xf32>, tensor<256xi1>)
-//     -> tensor<256xf32>
-//
-// output:
-// memref.copy %output_memref, %ub_buf : memref<256xf32> to memref<256xf32>
-// %17 = bufferization.to_tensor %alloc_3 restrict writable : memref<256xf32>
-// linalg.generic
-//     {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
-//     ins(%output_memref, %masked_input_memref : memref<256xf32>, memref<256xf32>)
-//     outs(%subview_2 : memref<256xf32>)
-//     attrs = {GenericAtomicRMW = "fadd", MemSemantic = "acq_rel",
-//              MemSyncScope = "gpu"} {
-//   ^bb0(%in: f32, %in_9: f32, %out: f32):
-//     %25 = arith.addf %in, %in_9 : f32
-//     linalg.yield %25 : f32
-// }
-LogicalResult
-AtomicRMWConverter::matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
-  // If the result of AtomicRMWOp is not used, we don't need to load the old
-  // data stored at the ptr
-  auto ptr = adaptor.getPtr();
-  auto val = op.getVal();
-  auto loc = op.getLoc();
-
-  auto resType = dyn_cast<RankedTensorType>(op.getResult().getType());
-  if (!resType) {
-    return rewriter.notifyMatchFailure(
-        op, "atomicRMWConverter: scalar will be handled by "
-            "ScalarAtomicRMWCanonicalizer");
-  }
-
-  auto rmwOp = op.getAtomicRmwOp();
-  if (rmwOp == triton::RMWOp::UMAX || rmwOp == triton::RMWOp::UMIN) {
-    return rewriter.notifyMatchFailure(
-        op, "AtomicRMWConverter: unsupported atomic kind for now");
-  }
-
-  // 1. Simple case where no mask is used.
-  auto type = dyn_cast<MemRefType>(ptr.getType());
-  if (!type) {
-    // Seen when implicit broadcasting is done late in a chain of
-    // operations. The workaround is to broadcast the pointers early in the
-    // address calculation. A proper fix is complicated, but at least we can
-    // provide a better error message.
-    return rewriter.notifyMatchFailure(
-        op, "AtomicRMWOp expects a memref, not a memref of pointers");
-  }
-
-  auto dstMemref = ptr;
-  // Linalg structured ops no longer support mixed tensor/buffer semantics
-  // in recent LLVM (triton's LLVM dependency has picked this change up), so
-  // we need to convert the tensor to a buffer early.
-  auto dstType = dstMemref.getType();
-  Value inputMemref =
-      rewriter.create<bufferization::ToMemrefOp>(loc, dstType, val);
-
-  // 2. Handle the mask for the atomic op
-  MaskState mstate;
-  auto mask = op.getMask();
-
-  // When the DSL does not pass a mask to this op, as in
-  // `tl.atomic_add(out_ptr0 + xindex, tmp2)`, a constant mask is created
-  // for the op by default. MaskAnalysis does not support that, so we need
-  // to handle this situation here.
-  //
-  // This logic comes from semantic.py:
-  //
-  // if not mask:
-  //     mask_ir = builder.get_int1(True)
-  //     mask_ty = tl.int1
-  //     if ptr.type.is_block():
-  //         mask_ir = \
-  //             builder.create_splat(mask_ir, ptr.type.get_block_shapes())
-  //         mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes())
-  //     mask = tl.tensor(mask_ir, mask_ty)
-  //
-  // ...
-  //
-  // return ptr, val, mask
-  //
-  auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
-  if (!constantMask) {
-    auto isContMask = mstate.parse(mask, loc, rewriter);
-
-    if (isContMask.failed()) {
-      return rewriter.notifyMatchFailure(
-          op, "Cannot lower discontinuous masked atomic ops");
-    }
-    dstMemref = mstate.getSubview(ptr, loc, rewriter);
-    inputMemref = mstate.getSubview(inputMemref, loc, rewriter);
-  } else {
-    if (!isConstantMaskTrue(mask)) {
-      rewriter.eraseOp(op);
-      return success();
-    }
-  }
-
-  // 3. If needed, handle the return value of the atomic op.
-  //
-  // tt.atomic_rmw does two things:
-  // 1. load the old data at the ptr
-  // 2. atomically store the data from UB to the ptr, performing its
-  //    assigned operation at the same time.
-  // So we lower this op to load + atomic store.
-  //
-  // The first part is unnecessary when the returned value of the atomic op
-  // is not used; it is then deleted since it is meaningless.
-  // Here we determine up front whether the result will be used, and based
-  // on that decide whether the load needs to be created.
-  //
-  // The handling logic is copied from the load lowering.
-  // TODO: decouple the load logic and move it into Utils.
-  if (!op.getResult().use_empty()) {
-    auto tensorType =
-        RankedTensorType::get(type.getShape(), type.getElementType());
-    auto alloc = rewriter.create<memref::AllocOp>(
-        loc, MemRefType::get(type.getShape(), type.getElementType()));
-
-    // For the return value, we don't need to care about the mask for now.
-    // This op doesn't support 'other', so we'd better not fill it.
-    rewriter.create<memref::CopyOp>(loc, ptr, alloc);
-    Value tensor = rewriter.create<bufferization::ToTensorOp>(
-        loc, tensorType, alloc, true /* restrict */, true /* writable */);
-    rewriter.replaceOp(op, tensor);
-  }
-
-  // Create the element-wise maps.
-  int64_t rank = type.getRank();
-  SmallVector<AffineExpr> inputDims;
-  auto context = rewriter.getContext();
-
-  for (int i = 0; i < rank; i++) {
-    inputDims.push_back(getAffineDimExpr(i, context));
-  }
-
-  SmallVector<AffineMap> indexingMaps;
-  // As the mask has been erased at this point, the number of inputs must be
-  // 2, and the input memref is also the output memref. Thus, there are a
-  // total of three inputs and outputs,
-  // so we have three maps to create here.
-  for (int i = 0; i < 3; i++) {
-    indexingMaps.push_back(AffineMap::get(rank, 0, inputDims, context));
-  }
-
-  auto linalgOp = rewriter.create<linalg::GenericOp>(
-      loc, /* operands */ ValueRange{dstMemref, inputMemref},
-      ValueRange{dstMemref}, indexingMaps,
-      mlir::ConverterUtils::getNParallelLoopsAttrs(rank),
-      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
-        Value opResult = createAtomicBinaryOps(nestedBuilder, nestedLoc, op,
-                                               type.getElementType(),
-                                               blockArgs[0], blockArgs[1]);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, opResult);
-      });
-
-  // "library_call" attributes indicating the actual semantics of this op.
-  // TODO: If the hardware supports MemSemantic/MemSyncScope, pass them
-  // down; otherwise they need to be deleted.
-  const StringRef genericAtomicRMW = "GenericAtomicRMW";
-  const StringRef memSemantic = "MemSemantic";
-  const StringRef memSyncScope = "MemSyncScope";
-  linalgOp->setAttr(genericAtomicRMW,
-                    rewriter.getStringAttr(stringifyEnum(op.getAtomicRmwOp())));
-  linalgOp->setAttr(memSemantic,
-                    rewriter.getStringAttr(stringifyEnum(op.getSem())));
-  linalgOp->setAttr(memSyncScope,
-                    rewriter.getStringAttr(stringifyEnum(op.getScope())));
-
-  // Mark atomic_and/or/xor specially; they need software simulation because
-  // of backend restrictions.
-  if (softwareAtomicKinds.contains(op.getAtomicRmwOp()))
-    linalgOp->setAttr("Software", rewriter.getUnitAttr());
-
-  // If the result hasn't been replaced by a load, we need to erase the op
-  // here.
-  if (op.getResult().use_empty()) {
-    rewriter.eraseOp(op);
-  }
-  return success();
-}
-
-LogicalResult
-ScalarStoreCanonicalizer::matchAndRewrite(triton::StoreOp op,
-                                          PatternRewriter &rewriter) const {
-
-  if (!op.getValue().getType().isIntOrIndexOrFloat()) {
-    return rewriter.notifyMatchFailure(
-        op, "ScalarStoreCanonicalizer only handles the scalar store case!");
-  }
-
-  auto ptr = op.getPtr();
-  auto ptrTy = RankedTensorType::get({(int64_t)1}, ptr.getType());
-  auto ptrSplat = rewriter.create<triton::SplatOp>(op.getLoc(), ptrTy, ptr);
-  auto valTy = RankedTensorType::get({(int64_t)1}, op.getValue().getType());
-  auto valSplat =
-      rewriter.create<triton::SplatOp>(op.getLoc(), valTy, op.getValue());
-
-  auto newStoreOp = rewriter.create<triton::StoreOp>(
-      op.getLoc(), ptrSplat, valSplat, op.getCache(), op.getEvict());
-  rewriter.replaceOp(op, newStoreOp);
-  return success();
-}
-
-LogicalResult
-ScalarAtomicRMWCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op,
-                                              PatternRewriter &rewriter) const {
-
-  if (!op.getVal().getType().isIntOrIndexOrFloat()) {
-    return rewriter.notifyMatchFailure(
-        op, "ScalarAtomicRMWCanonicalizer only handles the scalar atomic "
-            "RMW case!");
-  }
-
-  auto ptr = op.getPtr();
-  auto ptrTy = RankedTensorType::get({(int64_t)1}, ptr.getType());
-  auto ptrSplat = rewriter.create<triton::SplatOp>(op.getLoc(), ptrTy, ptr);
-  auto valTy = RankedTensorType::get({(int64_t)1}, op.getVal().getType());
-  auto valSplat =
-      rewriter.create<triton::SplatOp>(op.getLoc(), valTy, op.getVal());
-  auto maskTy = RankedTensorType::get({(int64_t)1}, op.getMask().getType());
-  auto maskSplat =
-      rewriter.create<triton::SplatOp>(op.getLoc(), maskTy, op.getMask());
-
-  auto newAtomicOp = rewriter.create<triton::AtomicRMWOp>(
-      op.getLoc(), valTy, op.getAtomicRmwOp(), ptrSplat, valSplat, maskSplat,
-      op.getSem(), op.getScope());
-  rewriter.replaceOp(op, newAtomicOp);
-  return success();
-}
-
-// An atomic max op with float input is divided (by the frontend) into two
-// atomic ops with integer input: one handles the part of the tensor greater
-// than zero (max), the other deals with the part less than zero (umin).
-// This makes MaskAnalysis fail.
failure -// So here we need to revert the procedures in semantics.py -// The triton IR is like -// -// %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x256xf32> -// %1 = tt.bitcast %value : tensor<1x256xf32> -> tensor<1x256xi32> -// %2 = tt.bitcast %ptr : tensor<1x256x!tt.ptr> -> -// tensor<1x256x!tt.ptr> %3 = arith.cmpf oge, %1, %cst_0 %4 = arith.cmpf -// olt, %1, %cst_0 %5 = arith.andi %8, %3 %6 = tt.atomic_rmw max, acq_rel, gpu, -// %2, %1, %5 : -// (tensor<1x256x!tt.ptr>, tensor<1x256xi32>, tensor<1x256xi1>) -> -// tensor<1x256xi32> -// %7 = arith.andi %8, %4 -// %8 = tt.atomic_rmw umin, acq_rel, gpu, %2, %1, %7 : -// (tensor<1x256x!tt.ptr>, tensor<1x256xi32>, tensor<1x256xi1>) -> -// tensor<1x256xi32> -// -// it's hard to handle and meaningless complicated for our device -// so we revert it to -// %0 = tt.atomic_rmw max, acq_rel, gpu, %23, %21, %8 : -// (tensor<1x256x!tt.ptr>, tensor<1x256xf32>, tensor<1x256xi1>) -> -// tensor<1x256xf32> -LogicalResult -AtomicMaxMinCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op, - PatternRewriter &rewriter) const { - // Revert the op to its original form - auto ptrBitcastOp = op.getPtr().getDefiningOp(); - auto valueBitcastOp = op.getVal().getDefiningOp(); - if (!ptrBitcastOp || !valueBitcastOp) { - return failure(); - } - - // We only need to handle the op when the element type is float - auto elementType = - dyn_cast(valueBitcastOp.getSrc().getType()).getElementType(); - if (!isa(elementType)) { - return failure(); - } - - auto rmwOp = op.getAtomicRmwOp(); - // here we know that atomic UMAX/UMIN - // is created by special logic of triton right now - // so we can simply delete it - if (rmwOp == triton::RMWOp::UMAX || rmwOp == triton::RMWOp::UMIN) { - // if the return value of op is used, we can't simply erase it - if (op.getResult().use_empty()) { - rewriter.eraseOp(op); - return success(); - } - return failure(); - } - - if (rmwOp != triton::RMWOp::MAX && rmwOp != triton::RMWOp::MIN) { - return failure(); - } - - // 1. Though semantic interpreter will generate full true tensor as original - // mask if atomicrmwOp don't have it, above float devision process will also - // generate positive and negative comparison mask, which will cause to fold - // true mask. - // 2. 
-  // 2. If the atomic RMW op does have an original mask, there is an AndIOp
-  // combining the original mask with the positive/negative comparison mask.
-  //
-  // Here we want to extract the original mask.
-  Value originalMask = op.getMask();
-  if (auto andOp = originalMask.getDefiningOp<arith::AndIOp>())
-    // LHS is the convention in the semantic interpreter
-    originalMask = andOp.getLhs();
-  else if (auto cmpOp = originalMask.getDefiningOp<arith::CmpFOp>()) {
-    if (cmpOp.getPredicate() != mlir::arith::CmpFPredicate::OGE ||
-        !matchPattern(cmpOp.getRhs(),
-                      /*positive float zero matcher*/ m_PosZeroFloat()))
-      // Recheck what the frontend interpreter generates when no manual mask
-      // is given
-      return op->emitError("Illegal mask for atomicrmwOp of float type");
-    // Restore the original true mask
-    originalMask = rewriter.create<arith::ConstantOp>(
-        op->getLoc(),
-        /*typed attr*/ DenseElementsAttr::get(
-            cast<ShapedType>(originalMask.getType()), true));
-  } else
-    return op->emitError("Illegal mask for atomicrmwOp of float type");
-
-  auto originAtomicOp = rewriter.create<triton::AtomicRMWOp>(
-      op.getLoc(), valueBitcastOp.getSrc().getType(), op.getAtomicRmwOp(),
-      ptrBitcastOp.getSrc(), valueBitcastOp.getSrc(), originalMask, op.getSem(),
-      op.getScope());
-
-  // If the return value of the op is used, we need to handle its usage.
-  // In semantic.py, when the result of an atomic max/min with float input is
-  // used, select + bitcast are emitted to recover the float value, so here
-  // we need to revert that too.
-  //
-  // For example:
-  // %0 = tt.atomic_rmw max, acq_rel, gpu, %gm, %input, %mask1 :
-  //     (tensor<32x!tt.ptr<i32>>...
-  // %1 = tt.atomic_rmw umin, acq_rel, gpu, %gm, %input, %mask2 :
-  //     (tensor<32x!tt.ptr<i32>>...
-  // %2 = arith.select %dividedMask, %0, %1 : tensor<32xi1>, tensor<32xi32>
-  // %3 = tt.bitcast %2 : tensor<32xi32> -> tensor<32xf32>
-  // tt.store %outputMemref, %3 : tensor<32x!tt.ptr<f32>>
-  //
-  // will be reverted to:
-  // %0 = tt.atomic_rmw max, acq_rel, gpu, %gm, %input, %mask :
-  //     (tensor<32x!tt.ptr<f32>>...
-  // tt.store %outputMemref, %0 : tensor<32x!tt.ptr<f32>>
-  //
-  if (!op.getResult().use_empty()) {
-    for (OpOperand &use : op->getUses()) {
-      auto selectOp = dyn_cast<arith::SelectOp>(use.getOwner());
-      if (!selectOp)
-        continue;
-
-      for (OpOperand &selectUse : selectOp->getUses()) {
-        if (auto bitcastOp =
-                dyn_cast<triton::BitcastOp>(selectUse.getOwner())) {
-          bitcastOp.getResult().replaceAllUsesWith(originAtomicOp);
-        }
-      }
-    }
-    rewriter.replaceOp(op, originAtomicOp);
-  } else {
-    rewriter.eraseOp(op);
-  }
-
-  return success();
-}
-
-StoreConverter::StoreConverter(MLIRContext *context)
-    : OpConversionPattern<triton::StoreOp>(context) {}
-
-LogicalResult
-StoreConverter::matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-
-  // triton store op basics
-  auto mask = op.getMask();
-  auto loc = op.getLoc();
-  auto ptr = adaptor.getPtr();
-  auto val = adaptor.getValue();
-
-  // 1. boundary size check
-  // auto boundaryCheck = op.getBoundaryCheck();
-  // if (!boundaryCheck.empty()) {
-  //   SmallVector<OpFoldResult> sizes = getBoundarySizes(
-  //       boundaryCheck, op.getPtr(), ptr, loc, rewriter);
-
-  //   auto srcSlice = getExtractSlice(val, sizes, loc, rewriter);
-  //   auto dstSubview = getSubview(ptr, sizes, loc, rewriter);
-  //   auto storeOp =
-  //       rewriter.create<bufferization::MaterializeInDestinationOp>(
-  //           loc, srcSlice, dstSubview);
-  //   storeOp.setWritable(true);
-  //   rewriter.eraseOp(op);
-  //   return success();
-  // }
-
-  // 2. Simple store with no mask
-  if (!mask) {
-    auto storeOp = rewriter.create<bufferization::MaterializeInDestinationOp>(
-        loc, val, ptr);
-    storeOp.setWritable(true);
-    rewriter.eraseOp(op);
-    return success();
-  }
-
-  // 3. Continuous masked stores.
- // Analyze the mask operand to determine at runtime the size of the data we - // are moving. - MaskState mstate; - auto isContMask = mstate.parse(mask, loc, rewriter); - - if (isContMask.failed()) { - return failure(); - } - LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); - auto srcSlice = mstate.getExtractSlice(val, loc, rewriter); - auto dstSubview = mstate.getSubview(ptr, loc, rewriter); - auto storeOp = rewriter.create( - loc, srcSlice, dstSubview); - storeOp.setWritable(true); - rewriter.eraseOp(op); - return success(); -} -} // namespace LoadStoreConverter diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp deleted file mode 100644 index 946b781f6..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp +++ /dev/null @@ -1,543 +0,0 @@ -#include "TritonToLinalg/MaskAnalysis.h" -// #include "triton-shared/Analysis/opFoldResultutils.h" -#include "Utils/Utils.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Operation.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" -#include -#include - -#define DEBUG_TYPE "mask-analysis" - -namespace mlir { - -namespace triton { - -LogicalResult MaskState::parse(Value operand, const Location &loc, - OpBuilder &builder) { - if (isa(operand.getType())) { - return parseIntScalar(operand, loc, builder); - } - auto definingOp = operand.getDefiningOp(); - LLVM_DEBUG({ - llvm::dbgs() << "[MaskState]==> parse op\n" - << *definingOp << "\n[MaskState]<==\n"; - }); - return TypeSwitch(definingOp) - .Case( - [&](auto op) { return this->parseConstant(op, loc, builder); }) - .Case( - [&](auto op) { return this->parseAdd(op, loc, builder); }) - .Case( - [&](auto op) { return this->parseAnd(op, loc, builder); }) - .Case( - [&](auto op) { return this->parseCmp(op, loc, builder); }) - .Case( - [&](auto op) { return this->parseMakeRange(op, loc, builder); }) - .Case( - [&](auto op) { return this->parseBroadcast(op, loc, builder); }) - .Case( - [&](auto op) { return this->parseSplat(op, loc, builder); }) - .Case( - [&](auto op) { return this->parseExpandDims(op, loc, builder); }) - .Case( - [&](auto op) { return this->parse(op.getIn(), loc, builder); }) - .Case( - [&](auto op) { return this->parseDiv(op, loc, builder); }) - .Default([&](Operation *op) { return failure(); }); -} - -// extractSlice -tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, - const Location &loc, - OpBuilder &builder) const { - auto sourceRType = cast(source.getType()); - SmallVector strides(getRank(), builder.getIndexAttr(1)); - - auto dstRType = tensor::ExtractSliceOp::inferResultType(sourceRType, offsets, - dims, strides); - return builder.create(loc, dstRType, source, offsets, - dims, strides); -} - -tensor::InsertSliceOp MaskState::getInsertSlice(Value source, Value dest, - const Location &loc, - OpBuilder &builder) const { - auto sourceType = cast(source.getType()); - SmallVector strides(getRank(), builder.getIndexAttr(1)); - return builder.create(loc, source, dest, offsets, dims, - strides); -} - -memref::SubViewOp MaskState::getSubview(Value source, const Location &loc, - OpBuilder &builder) const { - auto sourceType = cast(source.getType()); - SmallVector strides(getRank(), builder.getIndexAttr(1)); - auto dstType = - memref::SubViewOp::inferResultType(sourceType, 
offsets, dims, strides); - return builder.create(loc, cast(dstType), - source, offsets, dims, strides); -} - -static memref::SubViewOp createSubview(Value src, const Location &loc, - OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - auto srcType = cast(src.getType()); - auto dstType = - memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - return builder.create(loc, cast(dstType), src, - offsets, sizes, strides); -} - -std::pair -MaskState::getSideBySideSubviews(Value block1, Value block2, - const Location &loc, - OpBuilder &builder) const { - OpFoldResult subviewRowFull = dims[0]; - OpFoldResult subviewColFull = dims[1]; - OpFoldResult col1 = builder.create(loc, block1, 1).getResult(); - OpFoldResult subviewCol1 = - minOpFoldResult(col1, subviewColFull, loc, builder); - OpFoldResult subviewCol2 = - subOpFoldResult(subviewColFull, subviewCol1, loc, builder); - SmallVector strides(getRank(), builder.getIndexAttr(1)); - auto sbv1 = createSubview(block1, loc, builder, offsets, - {subviewRowFull, subviewCol1}, strides); - auto sbv2 = createSubview(block2, loc, builder, offsets, - {subviewRowFull, subviewCol2}, strides); - return {sbv1, sbv2}; -} - -std::pair -MaskState::getStackedSubviews(Value block1, Value block2, const Location &loc, - OpBuilder &builder) const { - OpFoldResult subviewRowFull = dims[0]; - OpFoldResult subviewColFull = dims[1]; - OpFoldResult row1 = builder.create(loc, block1, 0).getResult(); - OpFoldResult subviewRow1 = - minOpFoldResult(row1, subviewRowFull, loc, builder); - OpFoldResult subviewRow2 = - subOpFoldResult(subviewRowFull, subviewRow1, loc, builder); - SmallVector strides(getRank(), builder.getIndexAttr(1)); - auto sbv1 = createSubview(block1, loc, builder, offsets, - {subviewRow1, subviewColFull}, strides); - auto sbv2 = createSubview(block2, loc, builder, offsets, - {subviewRow2, subviewColFull}, strides); - return {sbv1, sbv2}; -} - -// addstatescalar -LogicalResult MaskState::addStateScalar(const MaskState &state, - const OpFoldResult scalar, - const Location &loc, - OpBuilder &builder) { - start = addOpFoldResult(state.start, scalar, loc, builder); - end = addOpFoldResult(state.end, scalar, loc, builder); - dims = state.dims; - offsets = state.offsets; - return success(); -} - -LogicalResult MaskState::addStates(const MaskState &lhsState, - const MaskState &rhsState, - const Location &loc, OpBuilder &builder) { - if (lhsState.scalar && rhsState.scalar) { - InFlightDiagnostic diag = - emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; - return failure(); - } - if (!lhsState.scalar && !rhsState.scalar) { - InFlightDiagnostic diag = - emitError(loc) - << "Unsupported scenario where neither lhs nor rhs is a scalar"; - return failure(); - } - - if (lhsState.scalar) { - return addStateScalar(rhsState, lhsState.scalar, loc, builder); - } else { - return addStateScalar(lhsState, rhsState.scalar, loc, builder); - } -} - -LogicalResult MaskState::divStateScalar(const MaskState &state, - const OpFoldResult scalar, - const Location &loc, - OpBuilder &builder) { - start = divOpFoldResult(state.start, scalar, loc, builder); - end = divOpFoldResult(state.end, scalar, loc, builder); - dims = state.dims; - offsets = state.offsets; - return success(); -} - -LogicalResult MaskState::divStates(const MaskState &lhsState, - const MaskState &rhsState, - const Location &loc, OpBuilder &builder) { - if (!lhsState.scalar && rhsState.scalar) { - if (isZeroIndex(rhsState.scalar)) { - InFlightDiagnostic diag = - 
emitError(loc) - << "Unsupported scenario where rhs is zero constant in divide!"; - return failure(); - } - - return divStateScalar(lhsState, rhsState.scalar, loc, builder); - } - - InFlightDiagnostic diag = emitError(loc) - << "Supported scenario where only rhs is a scalar"; - return failure(); -} - -LogicalResult MaskState::minStates(const MaskState &lhsState, - const MaskState &rhsState, - const Location &loc, OpBuilder &builder) { - if (lhsState.getRank() != rhsState.getRank()) { - InFlightDiagnostic diag = - emitError(loc) - << "Unexpected case where lhs and rhs have different ranks"; - return failure(); - } - - for (uint32_t i = 0; i < lhsState.getRank(); i++) { - auto lhsOffset = lhsState.offsets[i]; - auto rhsOffset = rhsState.offsets[i]; - auto newOffset = maxOpFoldResult(lhsOffset, rhsOffset, loc, builder); - auto lhsDim = lhsState.dims[i]; - auto rhsDim = rhsState.dims[i]; - auto lhsEnd = addOpFoldResult(lhsOffset, lhsDim, loc, builder); - auto rhsEnd = addOpFoldResult(rhsOffset, rhsDim, loc, builder); - auto newEnd = minOpFoldResult(lhsEnd, rhsEnd, loc, builder); - auto newDim = subOpFoldResult(newEnd, newOffset, loc, builder); - - offsets.push_back(newOffset); - dims.push_back(newDim); - } - return success(); -} - -// Helper func for MaskState::parse() -LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, - const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - if (isa(constOp.getValue())) { - auto attr = cast(constOp.getValue()); - auto elementType = attr.getElementType(); - assert(attr.isSplat() && isa(elementType) && - "All elements must share a single integer constant value"); - auto values = attr.getValues(); - auto value = values[0].getValue(); - auto constAttr = builder.getIndexAttr(value.getSExtValue()); - auto op = arith::ConstantOp::materialize(builder, constAttr, - builder.getIndexType(), loc); - this->scalar = op.getValue(); - } else { - auto value = cast(constOp.getValue()).getInt(); - this->scalar = builder.getIndexAttr(value); - } - return success(); -} - -// parseIntScalar -LogicalResult MaskState::parseIntScalar(Value scalar, const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - Value castOp; - - if (scalar.getType().isInteger(1)) { - castOp = builder.create(loc, builder.getIndexType(), - scalar); - } else { - castOp = - builder.create(loc, builder.getIndexType(), scalar); - } - this->scalar = castOp; - return success(); -} - -LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - MaskState lhsState; - if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) { - return failure(); - } - - MaskState rhsState; - if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) { - return failure(); - } - return this->addStates(lhsState, rhsState, loc, builder); -} - -LogicalResult MaskState::parseDiv(arith::DivSIOp divOp, const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - MaskState lhsState; - if (failed(lhsState.parse(divOp.getLhs(), loc, builder))) { - return failure(); - } - - MaskState rhsState; - if (failed(rhsState.parse(divOp.getRhs(), loc, builder))) { - return failure(); - } - return this->divStates(lhsState, rhsState, loc, builder); -} - -LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - MaskState lhsState; - if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || - !lhsState.isMask()) { - return failure(); - } - - MaskState 
rhsState; - if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || - !rhsState.isMask()) { - return failure(); - } - return this->minStates(lhsState, rhsState, loc, builder); -} - -LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - if (cmpOp.getPredicate() != arith::CmpIPredicate::slt && - cmpOp.getPredicate() != arith::CmpIPredicate::sge && - cmpOp.getPredicate() != arith::CmpIPredicate::eq) { - LLVM_DEBUG({ llvm::dbgs() << "Unsupported cmpi predicate\n"; }); - return failure(); - } - MaskState lhsState; - if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) { - return failure(); - } - - MaskState rhsState; - if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) { - return failure(); - } - - if (!(!lhsState.scalar && rhsState.scalar)) { - cmpOp->emitRemark("[MaskState] Unsupported cmpi scenario"); - return failure(); - } - - int32_t cmpDim = -1; - for (int32_t i = 0; i < lhsState.getRank(); i++) { - auto dimIntAttr = makeIntAttr(lhsState.dims[i]); - if (!dimIntAttr || dimIntAttr.value() != 1) { - if (cmpDim != -1) { - InFlightDiagnostic diag = emitError(loc) - << "Unsupported cmpi with more than one " - "dimension with size larger than 1"; - return failure(); - } - cmpDim = i; - } - } - - assert(cmpDim != -1 && - "Unexpected case where no dimension has size larger than 1"); - - this->offsets = lhsState.offsets; - this->dims = lhsState.dims; - switch (cmpOp.getPredicate()) { - case arith::CmpIPredicate::slt: { - auto realBound = - maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); - auto newEnd = minOpFoldResult(lhsState.end, realBound, loc, builder); - auto newDim = subOpFoldResult(newEnd, lhsState.start, loc, builder); - - this->dims[cmpDim] = newDim; - break; - } - case arith::CmpIPredicate::sge: { - auto realBound = - maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); - auto newStart = minOpFoldResult(lhsState.end, realBound, loc, builder); - auto newOffset = subOpFoldResult(newStart, lhsState.start, loc, builder); - auto newDim = subOpFoldResult(lhsState.end, newStart, loc, builder); - - this->offsets[cmpDim] = newOffset; - this->dims[cmpDim] = newDim; - break; - } - case arith::CmpIPredicate::eq: { - auto newOffset = - subOpFoldResult(rhsState.scalar, lhsState.start, loc, builder); - auto newDim = builder.getIndexAttr(1); - - this->offsets[cmpDim] = newOffset; - this->dims[cmpDim] = newDim; - break; - } - default: - return failure(); - } - return success(); -} - -LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, - const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - auto shape = cast(rangeOp.getType()).getShape(); - auto start = rangeOp.getStart(); - auto end = rangeOp.getEnd(); - auto stride = (end - start + shape[0] - 1) / shape[0]; - - if (stride != 1) { - InFlightDiagnostic diag = - emitError(loc) - << "stride must be 1 for make_range whose result is used " - "as load or store masks"; - return failure(); - } - - this->start = builder.getIndexAttr(start); - this->end = builder.getIndexAttr(end); - this->dims.push_back(builder.getIndexAttr(shape[0])); - this->offsets.push_back(builder.getIndexAttr(0)); - return success(); -} - -LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, - const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - auto src = broadcastOp.getSrc(); - auto dst = broadcastOp.getResult(); - assert(isa(src.getType()) && - "input to tt.broadcast should be a tensor"); - - 
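-  // Illustrative example (editor's addition; hypothetical types):
-  // broadcasting a mask from tensor<1x8xi1> to tensor<4x8xi1> only widens
-  // dims[0] from 1 to 4; offsets stay untouched, which is what the loop
-  // below implements.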
auto srcShape = cast(src.getType()).getShape(); - auto dstShape = cast(dst.getType()).getShape(); - assert(srcShape.size() == dstShape.size() && - "rank of source and destination should match"); - - if (failed(parse(src, loc, builder))) { - return failure(); - } - for (size_t i = 0; i < srcShape.size(); i++) { - if (srcShape[i] == dstShape[i]) - continue; - else if (srcShape[i] < dstShape[i]) { - this->dims[i] = builder.getIndexAttr(dstShape[i]); - } else { - llvm_unreachable("unexpected dimensions used in broadcast"); - } - } - return success(); -} - -LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, - const Location &loc, OpBuilder &builder) { - assert(this->isEmpty()); - - auto src = splatOp.getSrc(); - auto dst = splatOp.getResult(); - auto dstShape = cast(dst.getType()).getShape(); - - if (!isa(src.getType())) { - InFlightDiagnostic diag = - emitError(loc) - << "splat source must be an integer scalar for load/store masks"; - return failure(); - } - - if (failed(this->parse(src, loc, builder))) - return failure(); - - auto splatAsMask = [&](Operation *userOp) -> bool { - return TypeSwitch(userOp) - .Case([&](arith::AndIOp andOp) { return true; }) - .Case([&](arith::SelectOp selectOp) { - return selectOp.getCondition() == dst; - }) - .Case( - [&](triton::LoadOp loadOp) { return loadOp.getMask() == dst; }) - .Case( - [&](triton::StoreOp storeOp) { return storeOp.getMask() == dst; }) - .Default([&](Operation *op) { return false; }); - }; - - if (src.getType().isInteger(1) && !splatOp->use_empty() && - llvm::all_of(splatOp->getUsers(), splatAsMask)) { - for (auto s : dstShape) { - auto currentDim = - mulOpFoldResult(builder.getIndexAttr(s), this->scalar, loc, builder); - this->dims.push_back(currentDim); - this->offsets.push_back(builder.getIndexAttr(0)); - } - - this->scalar = nullptr; - return success(); - } - - for (auto s : dstShape) { - this->dims.push_back(builder.getIndexAttr(s)); - this->offsets.push_back(builder.getIndexAttr(0)); - } - return success(); -} - -LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, - const Location &loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) { - return failure(); - } - - auto dstShape = - cast(expandDimsOp.getResult().getType()).getShape(); - auto axis = expandDimsOp.getAxis(); - assert(dstShape[axis] == 1 && - "Expect changed dimention to be 1 in expand_dims"); - this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); - this->offsets.insert(this->offsets.begin() + axis, builder.getIndexAttr(0)); - - return success(); -} - -void MaskState::eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter) { - auto moduleOp = rawOp->getParentOfType(); - SmallVector worklist; - moduleOp->walk([&](Operation *op) { - if (isOpTriviallyDead(op)) - worklist.push_back(op); - }); - while (!worklist.empty()) { - Operation *op = worklist.pop_back_val(); - if (!isOpTriviallyDead(op)) - continue; - for (Value value : op->getOperands()) { - if (auto defOp = value.getDefiningOp()) - worklist.push_back(defOp); - } - LLVM_DEBUG({ - llvm::dbgs() << "[MaskState]==> inserted op: \n" - << *op << "\n[MaskState]<== is removed\n"; - }); - rewriter.eraseOp(op); - } -} - -} // namespace triton - -} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp deleted file mode 100644 index ae05f213d..000000000 --- 
a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp +++ /dev/null @@ -1,1149 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation, Meta Platforms. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#include "TritonToLinalg/TritonOpConverter.h" -#include "TritonToLinalg/BlockPtrAnalysis.h" -#include "TritonToLinalg/MaskAnalysis.h" -#include "Utils/Utils.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/LogicalResult.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/ValueRange.h" - -namespace TTOpConverters { -using namespace mlir; -using namespace triton; - -LogicalResult -AssertConverter::matchAndRewrite(triton::AssertOp op, - PatternRewriter &rewriter) const { - // TODO: update assert converter to support llvm20 - LLVM_DEBUG(llvm::dbgs() - << "we do not support assertion in kernel in llvm-20 yet \n"); - rewriter.eraseOp(op); - return success(); -} - -LogicalResult -BitcastConverter::matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto arithBitcast = rewriter.create( - op.getLoc(), op.getType(), op.getOperand()); - rewriter.replaceOp(op, arithBitcast.getResult()); - return success(); -} - -LogicalResult -TransposeConverter::matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto src = adaptor.getSrc(); - auto srcRank = cast(src.getType()).getRank(); - auto res = ConverterUtils::getTransposedValue(src, op.getLoc(), rewriter, - op.getOrder()); - rewriter.replaceOp(op, res); - return success(); -} - -LogicalResult -YieldConverter::matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); - return success(); -} - -LogicalResult -LoopConverter::matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - llvm::SmallDenseMap known; - BlockDataParser::IndexMapSet - levelToBlockArgIndex; // level -> set of block arg index to be replaced - - BlockDataParser::rewriteForOp(op, rewriter, levelToBlockArgIndex, 0, known); - return success(); -} - -LogicalResult -AdvanceConverter::matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - llvm::SmallDenseMap known; - BlockDataParser::rewriteAdvanceOp(op, rewriter, known); - return success(); -} - -void MakeTensorPtrConverter::populateVectorAsIndex( - SmallVector &vec, Operation::operand_range ops, - ConversionPatternRewriter &rewriter, Location loc) const { - for (auto opnd : ops) { - if (isa(opnd.getType())) { - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), opnd); - vec.push_back(castOp.getResult()); - } else { - assert(isa(opnd.getType())); - vec.push_back(opnd); - } - } -} - -OpFoldResult MakeTensorPtrConverter::accumulatePotentialOffsetOnBase( - triton::MakeTensorPtrOp op, Value base, OpFoldResult offset, - ConversionPatternRewriter &rewriter) const { - if (auto baseRecast = 
base.getDefiningOp()) {
-    assert(isa(op.getBase().getDefiningOp()) &&
-           "base of MakeTensorPtrOp only comes from native ptr or AddPtrOp");
-
-    return addOpFoldResult(offset, baseRecast.getConstifiedMixedOffset(),
-                           op.getLoc(), rewriter);
-  }
-
-  return offset;
-}
-
-// Designed for load/store boundary_check.
-memref::ReinterpretCastOp
-MakeTensorPtrConverter::createRedundantOp(triton::MakeTensorPtrOp op,
-                                          ConversionPatternRewriter &rewriter,
-                                          BlockData &data) const {
-  auto loc = op.getLoc();
-  // To do boundary_check in tt.load, we need to keep the parent tensor's
-  // shape info in the IR.
-  // Use the parent tensor's shape to create a cast.
-  auto resultSizes = data.getSizes();
-  data.getSizesRef().clear();
-  populateVectorAsIndex(data.getSizesRef(), op.getShape(), rewriter, loc);
-  SmallVector staticShapes;
-  SmallVector dynamicShapes;
-  dispatchIndexOpFoldResults(data.getSizesRef(), dynamicShapes, staticShapes);
-  auto castOp = data.createCastOp(staticShapes, loc, rewriter);
-  // Restore the sizes.
-  data.getSizesRef().clear();
-  for (auto &s : resultSizes) {
-    data.getSizesRef().push_back(s);
-  }
-  return castOp;
-}
-
-// TODO:
-// 1. Refactor MakeTensorPtrConverter and AdvanceConverter with
-//    memref::ReinterpretCastOp and memref::SubViewOp: use the recast to
-//    describe the full tensor shape, and the subview to represent the
-//    current block tensor.
-// 2. Support boundary_check & padding_option for load/store. The current
-//    method with the redundant recast is only enabled for load and drops
-//    padding_option.
-LogicalResult MakeTensorPtrConverter::matchAndRewrite(
-    triton::MakeTensorPtrOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  BlockData data;
-
-  auto orderSize = op.getOrder().size();
-  if (orderSize > 1) {
-    // Declaration of llvm::ArrayRef::slice(n, m):
-    // - Chop off the first N elements of the array, and keep M elements
-    //   in the array.
-    // Note that 'm' is the chunk length, not an end index.
-    for (auto [first, second] :
-         llvm::zip(op.getOrder().slice(0, orderSize - 1),
-                   op.getOrder().slice(1, orderSize - 1))) {
-      if (first != second + 1) {
-        op->emitError("Currently only support default order on block pointers");
-        return failure();
-      }
-    }
-  }
-
-  // Handle the case where the base is defined by tt.bitcast.
-  llvm::SmallDenseMap known;
-  BlockDataParser::parse(op.getBase(), data, loc, rewriter, known);
-  if (data.hasResElemTy()) {
-    auto memrefType = dyn_cast(data.getSourceRef().getType())
-                          .cloneWith(std::nullopt, data.getResElemTyRef());
-    UnrealizedConversionCastOp castOp =
-        rewriter.create(loc, memrefType,
-                        data.getSourceRef());
-    data.setSource(castOp.getOutputs()[0]);
-  } else {
-    data.setSource(rewriter.getRemappedValue(op.getBase()));
-  }
-
-  populateVectorAsIndex(data.getOffsetsRef(), op.getOffsets(), rewriter, loc);
-  populateVectorAsIndex(data.getStridesRef(), op.getStrides(), rewriter, loc);
-
-  SmallVector newOffsets;
-  for (auto [offset, stride] :
-       llvm::zip(data.getOffsetsRef(), data.getStridesRef()))
-    newOffsets.push_back(mulOpFoldResult(offset, stride, loc, rewriter));
-
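The user-facing source of a tt.make_tensor_ptr is tl.make_block_ptr. Below is a sketch of the pattern this converter (together with AdvanceConverter and LoopConverter) handles, assuming the standard Triton block-pointer API; the parent shape=(M, N) is exactly the information the redundant reinterpret_cast above records so that tt.load can perform its boundary_check.

```python
# Sketch only: assumes the standard triton block-pointer API.
import triton
import triton.language as tl

@triton.jit
def block_row_sum_kernel(in_ptr, out_ptr, M, N, stride_m, stride_n,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    blk = tl.make_block_ptr(base=in_ptr, shape=(M, N),    # tt.make_tensor_ptr
                            strides=(stride_m, stride_n),
                            offsets=(pid * BLOCK_M, 0),
                            block_shape=(BLOCK_M, BLOCK_N),
                            order=(1, 0))                 # the default order
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(N, BLOCK_N)):
        # boundary_check relies on the parent shape recorded above
        tile = tl.load(blk, boundary_check=(0, 1), padding_option="zero")
        acc += tile.to(tl.float32)
        blk = tl.advance(blk, (0, BLOCK_N))               # tt.advance
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(out_ptr + rows, tl.sum(acc, axis=1), mask=rows < M)
```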
-  // 1. The current base ptr may come from `triton::AddPtrOp`, which has
-  //    already been converted to a `memref::ReinterpretCastOp` with 1D
-  //    shape([1,]) by `AddPtrConverter`.
-  // 2. `triton::MakeTensorPtrOp` is also converted to a
-  //    `memref::ReinterpretCastOp` here, which creates a use-def chain of two
-  //    recasts; by the memref recast op specification, the offset/size/stride
-  //    info of the first recast would be dropped.
-  //
-  // Putting the two together: the base of MakeTensorPtrOp is treated as the
-  // original base, so the offset of the first recast, if present, must be
-  // preserved. Extract that offset here and add it to the highest dimension.
-  newOffsets.front() = accumulatePotentialOffsetOnBase(
-      op, adaptor.getBase(), newOffsets.front(), rewriter);
-
-  data.getOffsetsRef().clear();
-
-  for (auto offset : newOffsets) {
-    data.getOffsetsRef().push_back(offset);
-  }
-
-  ArrayRef resultShape;
-  auto pointerType = cast(op.getResult().getType());
-  if (auto shapedType = dyn_cast(pointerType.getPointeeType())) {
-    resultShape = shapedType.getShape();
-    for (auto dim_size : resultShape) {
-      data.getSizesRef().push_back(
-          IntegerAttr::get(IntegerType::get(op.getContext(), 64), dim_size));
-    }
-  } else {
-    // Scalar pointer: produce a one-dimensional memref.
-    SmallVector scalarShape(1, 1);
-    resultShape = scalarShape;
-    assert(data.getRank() == 1);
-  }
-
-  // Special handling for Davinci: create a redundant reinterpret_cast op to
-  // record the parent shape info.
-  auto redundantOp = createRedundantOp(op, rewriter, data);
-  redundantOp->setAttr("tensor_ptr_attr", rewriter.getStringAttr("shape"));
-
-  // Create the reinterpret_cast op for the target block.
-  data.setSource(redundantOp.getResult());
-  auto castOp = data.createCastOp(resultShape, loc, rewriter);
-  rewriter.replaceOp(op, castOp.getResult());
-  return success();
-}
-
-LogicalResult PreciseDivConverter::matchAndRewrite(
-    triton::PreciseDivFOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value opa = op.getX();
-  Value opb = op.getY();
-  auto loc = op.getLoc();
-
-  auto resType = dyn_cast(op.getResult().getType());
-  auto divOp = rewriter.create(loc, resType, opa, opb);
-
-  rewriter.replaceOp(op, divOp);
-  return success();
-}
-
-/*
- * Rewrite arith.select with a contiguous mask to
- * tensor.extract_slice/insert_slice.
- */
-
-LogicalResult
-SelectConverter::matchAndRewrite(arith::SelectOp op,
-                                 PatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-
-  // 0. Shortcut for scalars.
-  auto type = dyn_cast(op.getResult().getType());
-  if (!type) {
-    // Do nothing for a non-tensor select.
-    return failure();
-  }
-  auto mask = op.getCondition();
-  if (!isa(mask.getType())) {
-    // Do nothing for a scalar mask.
-    return failure();
-  }
-
-  // 1. Check for continuous masked loads.
-  // Analyze the mask operand to determine at runtime the size of the data we
-  // are moving.
-  MaskState mstate;
-  auto isContMask = mstate.parse(mask, loc, rewriter);
-
-  if (isContMask.failed()) {
-    mstate.eraseInsertedOps(op, rewriter);
-    return rewriter.notifyMatchFailure(
-        op, "Cannot lower continuous masked selects");
-  }
-
-  // 2. Slice out the masked part of the true tensor.
-  auto trueTensor = op.getTrueValue();
-  auto trueSlice = mstate.getExtractSlice(trueTensor, loc, rewriter);
-
-  // 3. Insert the sliced true tensor into the false tensor.
-  auto falseTensor = op.getFalseValue();
-  auto result = mstate.getInsertSlice(trueSlice, falseTensor, loc, rewriter);
-
-  rewriter.replaceOp(op, result);
-  return success();
-}
-
-/*
- * Move tt.bitcast to an earlier location if tt.bitcast is not applied
- * directly to a function argument.
- */
-LogicalResult
-BitcastCanonicalizer::matchAndRewrite(triton::BitcastOp bitcastOp,
-                                      PatternRewriter &rewriter) const {
-  Value castSrc = bitcastOp.getSrc();
-  Value castRes = bitcastOp.getResult();
-  Type castSrcTy = castSrc.getType();
-  Type castSrcPtrTy = isa(castSrcTy)
-                          ? 
cast(castSrcTy).getElementType() - : castSrcTy; - if (!isa(castSrcPtrTy)) - return failure(); - - auto origBitwidth = getPointeeBitWidth(castSrc.getType()); - auto castBitwidth = getPointeeBitWidth(castRes.getType()); - - if (origBitwidth == 1) - origBitwidth = 8; - if (castBitwidth == 1) - castBitwidth = 8; - if (origBitwidth != castBitwidth) { - bitcastOp.emitError() << "Casting pointers with unmatched bitwidth!\n"; - return failure(); - } - - Operation *beforeCastOp = castSrc.getDefiningOp(); - if (beforeCastOp == nullptr) { - return failure(); - } - - auto newRes = - TypeSwitch>(beforeCastOp) - // before: addptr - bitcast - load/store - // after: bitcast - addptr - load/store - .Case([&](triton::AddPtrOp addptrOp) { - auto newCastOp = rewriter.create( - bitcastOp.getLoc(), castRes.getType(), addptrOp.getPtr()); - return rewriter.create( - bitcastOp.getLoc(), castRes.getType(), newCastOp.getResult(), - addptrOp.getOffset()); - }) - .Case([&](triton::SplatOp splatOp) { - Type newCastSrcTy = - cast(castRes.getType()).getElementType(); - - Value splatSrc = splatOp.getSrc(); - Type splatSrcTy = splatSrc.getType(); - if (auto splatSrcTensorTy = dyn_cast(splatSrcTy)) - newCastSrcTy = - splatSrcTensorTy.cloneWith(std::nullopt, newCastSrcTy); - auto newCastOp = rewriter.create( - bitcastOp.getLoc(), newCastSrcTy, splatSrc); - return rewriter.create( - bitcastOp.getLoc(), castRes.getType(), newCastOp); - }) - // before: bitcast - bitcast - // after(fusion optimization): bitcast - .Case([&](triton::BitcastOp prevCastOp) { - return rewriter.create( - bitcastOp.getLoc(), castRes.getType(), prevCastOp.getSrc()); - }) - .Default([&](Operation *op) { - return rewriter.notifyMatchFailure(bitcastOp, - "Unknown bitcast pattern"); - }); - if (succeeded(newRes)) { - rewriter.replaceOp(bitcastOp, newRes.value()); - if (beforeCastOp->use_empty()) { - rewriter.eraseOp(beforeCastOp); - } - return success(); - } - return failure(); -} - -LogicalResult DenseConstantConverter::matchAndRewrite( - arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto denseAttr = cast(op.getValue()); - auto loc = op.getLoc(); - auto constSplatOp = arith::ConstantOp::materialize( - rewriter, denseAttr.getSplatValue(), - denseAttr.getElementType(), loc); - auto emptyOp = rewriter.create( - loc, cast(op.getResult().getType()).getShape(), - denseAttr.getElementType()); - - rewriter.replaceOpWithNewOp(op, ValueRange{constSplatOp}, - ValueRange{emptyOp}); - - return success(); -} - -LogicalResult -MakeRangeConverter::matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto type = cast(op.getResult().getType()); - auto shape = type.getShape(); - auto elementType = type.getElementType(); - auto context = op.getContext(); - - assert(type.getShape().size() == 1 && - isa(type.getElementType()) && - type.getElementType().getIntOrFloatBitWidth() == 32 && - "make range can only return 1D int32 tensor"); - - SmallVector indexingMaps{AffineMap::get( - /* dimCount */ 1, /* symbolCount */ 0, - {mlir::getAffineDimExpr(0, context)}, context)}; - - auto init = rewriter.create(loc, shape, elementType); - - auto nestedBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange blockArgs) { - Value index = nestedBuilder.create(loc, 0); - Value res = nestedBuilder.create( - loc, type.getElementType(), index); - nestedBuilder.create(loc, res); - }; - - auto linalgOp = rewriter.create( - loc, op->getResultTypes(), /* operands */ 
ValueRange{}, ValueRange{init},
-      indexingMaps, ConverterUtils::getNParallelLoopsAttrs(1), nestedBody);
-
-  rewriter.replaceOp(op, linalgOp->getResults());
-  return success();
-}
-
-LogicalResult
-SplatConverter::matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  auto init = rewriter.create(loc, op.getType().getShape(),
-                              op.getType().getElementType());
-  rewriter.replaceOpWithNewOp(op, ValueRange{adaptor.getSrc()},
-                              ValueRange{init});
-  return success();
-}
-
-LogicalResult
-ReshapeConverter::matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor,
-                                  ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  auto src = op.getSrc();
-  auto dst = op.getResult();
-  Value shape = rewriter.create(
-      loc,
-      rewriter.getI64TensorAttr(cast(dst.getType()).getShape()));
-  auto reshapeOp =
-      rewriter.create(loc, dst.getType(), src, shape);
-  rewriter.replaceOp(op, reshapeOp.getResult());
-  return success();
-}
-
-LogicalResult ExpandDimsConverter::matchAndRewrite(
-    triton::ExpandDimsOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  auto src = op.getSrc();
-  auto resShape = cast(op.getResult().getType()).getShape();
-  auto axis = op.getAxis();
-
-  SmallVector reassociation;
-
-  auto src_last_dim = resShape.size() - 2;
-  auto map_func = [&](unsigned i) -> ReassociationIndices {
-    if (i < axis) {
-      return i == src_last_dim ? ReassociationIndices{i, i + 1}
-                               : ReassociationIndices{i};
-    }
-    return i == axis ? ReassociationIndices{i, i + 1}
-                     : ReassociationIndices{i + 1};
-  };
-
-  reassociation = llvm::to_vector(
-      llvm::map_range(llvm::seq(0, src_last_dim + 1), map_func));
-
-  auto expandShapeOp = rewriter.create(
-      op.getLoc(), op.getResult().getType(), src, reassociation);
-  rewriter.replaceOp(op, expandShapeOp.getResult());
-  return success();
-}
-
-LogicalResult
-ClampFConverter::matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor,
-                                 ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  auto input = adaptor.getX();
-  auto min_para = adaptor.getMin();
-  auto max_para = adaptor.getMax();
-  auto propagateNan_para = adaptor.getPropagateNan();
-
-  if (auto input_type = dyn_cast(input.getType())) {
-    if (isa(min_para.getType())) {
-      auto minEmptyTensor = rewriter.create(
-          loc, input_type.getShape(), input_type.getElementType());
-      min_para = rewriter
-                     .create(loc, ValueRange{min_para},
-                             ValueRange{minEmptyTensor})
-                     .result();
-    }
-    if (isa(max_para.getType())) {
-      auto maxEmptyTensor = rewriter.create(
-          loc, input_type.getShape(), input_type.getElementType());
-      max_para = rewriter
-                     .create(loc, ValueRange{max_para},
-                             ValueRange{maxEmptyTensor})
-                     .result();
-    }
-  }
-
-  if (propagateNan_para == PropagateNan::NONE) {
-    auto minOp = rewriter.create(loc, input, max_para);
-    auto maxOp = rewriter.create(loc, min_para, minOp);
-    rewriter.replaceOp(op, ValueRange{maxOp});
-  } else if (propagateNan_para == PropagateNan::ALL) {
-    auto minOp = rewriter.create(loc, input, max_para);
-    auto maxOp = rewriter.create(loc, min_para, minOp);
-    rewriter.replaceOp(op, ValueRange{maxOp});
-  } else {
-    return failure();
-  }
-
-  return success();
-}
-
-// Convert tt.broadcast to linalg.broadcast.
-//
-// before
-//   %out = tt.broadcast %in : tensor<1x4x8xf32> -> tensor<128x4x8xf32>
-//
-// after
-//   %collapsed = tensor.collapse_shape %in [[0, 1], [2]] :
-//       tensor<1x4x8xf32> into tensor<4x8xf32>
-//   %out = linalg.broadcast ins(%collapsed : 
tensor<4x8xf32>) -// outs(%empty : tensor<128x4x8xf32>) dimensions = [0] -LogicalResult -BroadcastConverter::matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(op->getNumResults() == 1 && "BroadcastOp assumes single result"); - - RankedTensorType sourceType = - cast(adaptor.getSrc().getType()); - RankedTensorType resultType = cast(op.getType()); - auto elementType = resultType.getElementType(); - size_t resultRank = resultType.getRank(); - auto loc = op.getLoc(); - - auto initEmpty = - rewriter.create(loc, resultType.getShape(), elementType); - - SmallVector broadcastDims = - ConverterUtils::getBroadcastDims(sourceType, resultType); - SmallVector unbroadcastDims = - ConverterUtils::getUnbroadcastDims(sourceType, resultType); - - SmallVector collapseReassociationIndices; - auto collapseReassociationIndicesOptional = - getReassociationIndicesForCollapse(sourceType.getShape(), - unbroadcastDims); - if (!collapseReassociationIndicesOptional.has_value()) { - return rewriter.notifyMatchFailure( - op, "Failure with getReassociationIndicesForCollapse call"); - } - collapseReassociationIndices = collapseReassociationIndicesOptional.value(); - - RankedTensorType collapseResultType = - RankedTensorType::get(unbroadcastDims, sourceType.getElementType()); - - auto collpasedOp = rewriter.create( - loc, collapseResultType, adaptor.getSrc(), collapseReassociationIndices); - - auto broadcastOp = rewriter.create( - loc, collpasedOp, initEmpty, - rewriter.getDenseI64ArrayAttr(broadcastDims)); - - rewriter.replaceOp(op, broadcastOp.getResults()); - return success(); -} - -// Reduce Converter -llvm::SmallVector -ReduceConverter::getRedOps(triton::ReduceOp redOp) const { - auto reduceBlock = redOp.getBody(); - return llvm::map_to_vector(reduceBlock->without_terminator(), - [](Operation &op) { return &op; }); -} - -bool ReduceConverter::isReductionOpSupported(Operation *redOp) const { - return isa(redOp); -} - -arith::ConstantOp -ReduceConverter::getRedBaseConstOp(ConversionPatternRewriter &rewriter, - Operation *redOp, Type constantType) const { - const int64_t bitWidth = constantType.getIntOrFloatBitWidth(); - - auto attr = llvm::TypeSwitch(redOp) - .Case([&](arith::AddFOp) { - return rewriter.getFloatAttr(constantType, 0.f); - }) - .Case([&](arith::AddIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Case([&](auto) { - return rewriter.getFloatAttr( - constantType, -std::numeric_limits::infinity()); - }) - .Case([&](auto) { - return rewriter.getFloatAttr( - constantType, std::numeric_limits::infinity()); - }) - .Case([&](arith::MinSIOp) { - return rewriter.getIntegerAttr(constantType, - llvm::maxIntN(bitWidth)); - }) - .Case([&](arith::MinUIOp) { - return rewriter.getIntegerAttr(constantType, - llvm::maxUIntN(bitWidth)); - }) - .Case([&](arith::MaxSIOp) { - return rewriter.getIntegerAttr(constantType, - llvm::minIntN(bitWidth)); - }) - .Case([&](arith::MaxUIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Case([&](arith::OrIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Case([&](arith::AndIOp) { - return rewriter.getIntegerAttr(constantType, 1); - }) - .Case([&](arith::XOrIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Default([](Operation *op) { - op->dump(); - llvm_unreachable("Reduction op not supported yet"); - return nullptr; - }); - - return rewriter.create(redOp->getLoc(), constantType, - attr); -} - -bool ReduceConverter::requiresF32Conversion(const Type elemType, - 
Operation *redOp) const { - return isa(elemType) && - elemType.getIntOrFloatBitWidth() < - Float32Type::get(elemType.getContext()).getWidth() && - isa(redOp); -} - -Value ReduceConverter::getRedElement( - Value lhs, Value rhs, const Location loc, Operation *redOp, OpBuilder &b, - const bool convertLhsToF32Precision) const { - return llvm::TypeSwitch(redOp) - .Case([&](arith::AddFOp) { - if (convertLhsToF32Precision) { - lhs = b.create(loc, Float32Type::get(b.getContext()), - lhs); - } - return b.create(loc, lhs, rhs); - }) - .Case( - [&](auto redOp) { return b.create(loc, lhs, rhs); }) - .Default([](Operation *op) { - op->dump(); - llvm_unreachable("Reduction op not yet supported"); - return nullptr; - }); -} - -LogicalResult ReduceConverter::convertToLinalgReduce( - triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto source = adaptor.getOperands().front(); - auto sourceType = cast(source.getType()); - auto elemType = sourceType.getElementType(); - auto resType = op.getResult().front().getType(); - auto loc = op.getLoc(); - auto reductionOps = getRedOps(op); - - // Reduction of arbitrary operations isn't supported because using the first - // element across the reduction dimension requires us to iterate over a - // subview that skips over each first element. - if (reductionOps.size() != 1 || - !isReductionOpSupported(reductionOps.front())) { - return rewriter.notifyMatchFailure( - op, "Only support lowering reduction with body " - "containing 1 max(i/f) or addf."); - } - - auto rop = reductionOps.front(); - auto axis = op.getAxis(); - auto isVectorReduce = sourceType.getRank() == 1; - - auto constantType = elemType; - - auto accBaseConstOp = getRedBaseConstOp(rewriter, rop, constantType); - Value initTensor; - - if (isVectorReduce) { - auto holder = rewriter.create( - loc, RankedTensorType::get({}, constantType), ValueRange{}); - initTensor = rewriter - .create(loc, accBaseConstOp.getResult(), - holder.getResult()) - .getResult(0); - } else { - Value init = rewriter.create( - loc, cast(resType).getShape(), constantType); - initTensor = - rewriter.create(loc, accBaseConstOp.getResult(), init) - .getResult(0); - } - - Value finalResult = - rewriter - .create( - loc, ValueRange{source}, ValueRange{initTensor}, - SmallVector{axis}, - [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { - assert(inputs.size() == 2); - Value result = getRedElement(inputs[0], inputs[1], loc, rop, - opBuilder, false); - opBuilder.create(loc, result); - }) - .getResult(0); - - if (sourceType.getRank() == 1) { - finalResult = - rewriter.create(loc, constantType, finalResult); - } - - rewriter.replaceOp(op, finalResult); - return success(); -} - -LogicalResult ReduceConverter::convertToLinalgReduceExtended( - ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto elemTypes = op.getElementTypes(); - - auto valueResultType = dyn_cast(op.getType(0)); - const auto isScalarReduce = valueResultType == nullptr; - - SmallVector outputs; - for (auto i = 0; i < op.getResult().size() && i < elemTypes.size(); i++) { - auto result = dyn_cast(op.getType(i)); - SmallVector resultShape{ - isScalarReduce ? 
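A sketch of the reductions this converter accepts is shown below (standard Triton API assumed; names illustrative). Per getRedBaseConstOp, the linalg.reduce init tensor is seeded with the reduction identity (0 for add, -inf for float max, +inf for float min, the extreme integer values for signed/unsigned min/max); per requiresF32Conversion and getRedElement, an f16 addf reduction first extends its operand to f32.

```python
# Sketch only: assumes the standard triton / triton.language API.
import triton
import triton.language as tl

@triton.jit
def row_stats_kernel(x_ptr, sum_ptr, max_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)   # e.g. an f16 buffer
    s = tl.sum(x, axis=0)       # init 0; f16 input accumulates via f32
    m = tl.max(x, axis=0)       # init -inf
    # A rank-1 reduce takes the isVectorReduce path: a 0-d init tensor plus
    # a trailing tensor.extract to recover the scalar result.
    tl.store(sum_ptr, s)
    tl.store(max_ptr, m)
```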
SmallVector{} - : SmallVector(result.getShape())}; - outputs.push_back( - rewriter.create(loc, resultShape, elemTypes[i])); - } - - auto linalgOp = rewriter.create( - loc, adaptor.getOperands(), outputs, - SmallVector{adaptor.getAxis()}, - [&](OpBuilder &b, Location loc, ValueRange inputs) { - auto tritonReduceBlock = op.getBody(); - IRMapping mapping; - mapping.map(tritonReduceBlock->getArguments(), inputs); - - for (auto &op : tritonReduceBlock->without_terminator()) { - b.clone(op, mapping); - } - - auto tritonYield = tritonReduceBlock->getTerminator(); - auto results = - llvm::map_to_vector(tritonYield->getOperands(), - [&](Value val) { return mapping.lookup(val); }); - b.create(loc, results); - }); - - if (failed(addReduceWithIndexAttrIfNeeded(rewriter, linalgOp))) { - return rewriter.notifyMatchFailure(op, "meaningless reduce operation"); - } - - if (isScalarReduce) { - SmallVector reduceResults; - for (auto i = 0; i < linalgOp.getResults().size() && i < elemTypes.size(); - i++) { - reduceResults.push_back(rewriter.create( - loc, elemTypes[i], linalgOp.getResults()[i], ValueRange{})); - } - rewriter.replaceOp(op, reduceResults); - } else { - rewriter.replaceOp(op, linalgOp); - } - return success(); -} - -LogicalResult -ReduceConverter::matchAndRewrite(triton::ReduceOp op, - typename triton::ReduceOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto sourceType = - cast(adaptor.getOperands().front().getType()); - assert(sourceType.hasRank() && "Expected input is " - "ranked"); - - int64_t axis = op.getAxis(); - assert(axis >= 0 && axis < sourceType.getRank() && - "Expected reduction " - "axis is within " - "operand's rank"); - - auto reductionOps = getRedOps(op); - if (reductionOps.size() == 1) { - return convertToLinalgReduce(op, adaptor, rewriter); - } - return convertToLinalgReduceExtended(op, adaptor, rewriter); -} - -LogicalResult ExternElementwiseClOpConverter::matchAndRewrite( - triton::ExternElementwiseOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - if (!op.getPure()) { - op->emitWarning() << "impure elementwise op!"; - return failure(); - } - if (op.getSymbol().contains("__hmf_")) { - // 1. get or create the declaration of external elementwise function - Type dstTy = op.getResult().getType(); - bool isDstScalar = !isa(dstTy); - Type dstElemTy = - isDstScalar ? dstTy : cast(dstTy).getElementType(); - SmallVector srcElemTys; - SmallVector srcs; - for (auto src : op.getSrcs()) { - if (!isa(src.getType())) { - src = rewriter.create( - op.getLoc(), RankedTensorType::get({(int64_t)1}, src.getType()), - src); - } - srcs.push_back(src); - srcElemTys.push_back( - cast(src.getType()).getElementType()); - } - FunctionType elemFuncType = - FunctionType::get(rewriter.getContext(), srcElemTys, {dstElemTy}); - auto mod = SymbolTable::getNearestSymbolTable(op); - auto extFunc = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(mod, op.getSymbol())); - if (!extFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&mod->getRegion(0).front()); - extFunc = rewriter.create(rewriter.getUnknownLoc(), - op.getSymbol(), elemFuncType); - extFunc.setPrivate(); - extFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), - UnitAttr::get(rewriter.getContext())); - } - assert(isa( - SymbolTable::lookupSymbolIn(mod, op.getSymbol()))); - // 2. 
prepare the output tensor - Value output; - if (isDstScalar) { - dstTy = RankedTensorType::get({(int64_t)1}, dstElemTy); - } - bool found = false; - for (Value v : srcs) { - if (v.getType() == dstTy) { - found = true; - output = v; - break; - } - } - if (!found) { - output = rewriter.create( - op.getLoc(), cast(dstTy).getShape(), dstElemTy); - } - // 3. create the linalg.map op - auto mapOp = rewriter.create( - loc, - /*inputs=*/srcs, - /*init=*/output, - /*bodyBuilder=*/ - [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { - auto elemOp = builder.create(loc, - /*name=*/op.getSymbol(), - /*resultType=*/dstElemTy, - /*operands=*/regionArgs); - builder.create(loc, elemOp->getResults()); - }); - if (isDstScalar) { - // need to convert tensor back to scalar - auto indexType = rewriter.getIndexType(); - Value zeroConstant = rewriter.create( - loc, indexType, rewriter.getIntegerAttr(indexType, 0)); - auto extractOp = rewriter.create( - loc, mapOp.getResults()[0], zeroConstant); - rewriter.replaceOp(op, extractOp); - } else { - rewriter.replaceOp(op, mapOp); - } - return success(); - } - return failure(); -} - -LogicalResult UnrealizedCastConverter::matchAndRewrite( - UnrealizedConversionCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - rewriter.eraseOp(op); - return success(); -} - -LogicalResult -JoinConverter::matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value opa = op.getLhs(); - Value opb = op.getRhs(); - auto loc = op.getLoc(); - - auto resType = dyn_cast(op.getResult().getType()); - Value emptyOp = rewriter.create(loc, resType.getShape(), - resType.getElementType()); - - auto shape = dyn_cast(opa.getType()).getShape(); - auto sizes = llvm::map_to_vector(shape, [&](int64_t t) { - return OpFoldResult(rewriter.getI64IntegerAttr(t)); - }); - sizes.push_back(rewriter.getI64IntegerAttr(1)); - - int64_t rank = resType.getRank(); - - // Set last dimension stride to 2 in layout - // As last dimension size is always 1, last dimension stride here could be - // either 1 or 2, while stride `2` could carry interleave trait and it's - // convenient for next lower. - SmallVector strides(rank, rewriter.getIndexAttr(1)); - strides.back() = rewriter.getIndexAttr(2); - - SmallVector offsets(rank, rewriter.getIndexAttr(0)); - - auto insert0 = rewriter.create( - loc, opa, emptyOp, offsets, sizes, strides); - - offsets.back() = rewriter.getIndexAttr(1); - auto insert1 = rewriter.create( - loc, opb, insert0, offsets, sizes, strides); - rewriter.replaceOp(op, insert1); - return success(); -} - -LogicalResult -CatConverter::matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value opa = op.getLhs(); - Value opb = op.getRhs(); - auto loc = op.getLoc(); - - auto resType = dyn_cast(op.getResult().getType()); - auto emptyOp = rewriter.create(loc, resType.getShape(), - resType.getElementType()); - - auto rank = resType.getRank(); - SmallVector offsets(rank, rewriter.getIndexAttr(0)); - SmallVector strides(rank, rewriter.getIndexAttr(1)); - - auto inputType = dyn_cast(opa.getType()); - - SmallVector sizes = - llvm::map_to_vector(inputType.getShape(), [&](int64_t t) { - return OpFoldResult(rewriter.getI64IntegerAttr(t)); - }); - - auto insert0 = rewriter.create( - loc, opa, emptyOp, offsets, sizes, strides); - - offsets[0] = - rewriter.getIndexAttr(inputType.getRank() ? 
inputType.getShape()[0] : 1);
-  auto insert1 = rewriter.create(
-      loc, opb, insert0, offsets, sizes, strides);
-
-  rewriter.replaceOp(op, insert1);
-  return success();
-}
-
-/// @brief Convert tt.gather to func.call. BiShengIR captures the func
-/// with assumed semantics.
-/// @param op The `triton::GatherOp` operation to be rewritten.
-/// @param adaptor An adaptor for the operation's operands.
-/// @param rewriter A pattern rewriter used to modify the IR.
-/// @return A `LogicalResult` indicating whether the rewrite was successful.
-LogicalResult
-GatherConverter::matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor,
-                                 ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  Value src = adaptor.getSrc();
-  Value idx = adaptor.getIndices();
-  Value res = op.getResult();
-  auto gatherAxis = op.getAxis();
-
-  auto moduleOp = op->getParentOfType();
-  rewriter.setInsertionPoint(moduleOp.getBody(),
-                             std::prev(moduleOp.getBody()->end()));
-
-  llvm::SmallString<128> funcName = gatherFuncNameBase;
-  int uniqueId = 0;
-  while (SymbolTable::lookupSymbolIn(moduleOp, funcName)) {
-    funcName += "_" + std::to_string(uniqueId++);
-  }
-
-  auto resTy = res.getType();
-  auto libFnType = rewriter.getFunctionType(
-      {src.getType(), idx.getType(), rewriter.getI32Type()}, {resTy});
-  auto funcOp = rewriter.create(loc, funcName.str(), libFnType);
-  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);
-
-  rewriter.setInsertionPoint(op);
-  Value axis = rewriter.create(loc, gatherAxis, 32);
-  auto callOp = rewriter.create(loc, funcOp.getSymNameAttr(),
-                                TypeRange({resTy}),
-                                ValueRange({src, idx, axis}));
-
-  rewriter.replaceOp(op, callOp);
-
-  return success();
-}
-
-LogicalResult
-SplitConverter::matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  Value input = op.getSrc();
-  auto loc = op.getLoc();
-  auto inputType = cast(input.getType());
-
-  int64_t rank = inputType.getRank();
-  SmallVector offsets(rank, rewriter.getIndexAttr(0));
-  // As in JoinConverter, adjust the last-dimension stride here.
-  SmallVector strides(rank, rewriter.getIndexAttr(1));
-  strides.back() = rewriter.getIndexAttr(2);
-
-  auto outType = dyn_cast(op.getOutLHS().getType());
-  auto sizes = llvm::map_to_vector(outType.getShape(), [&](int64_t t) {
-    return OpFoldResult(rewriter.getIndexAttr(t));
-  });
-  sizes.push_back(rewriter.getIndexAttr(1));
-
-  auto slice0 = rewriter.create(
-      loc, outType, input, offsets, sizes, strides);
-
-  offsets.back() = rewriter.getIndexAttr(1);
-  auto slice1 = rewriter.create(
-      loc, outType, input, offsets, sizes, strides);
-
-  SmallVector slices = {slice0.getResult(), slice1.getResult()};
-  rewriter.replaceOp(op, ValueRange(slices));
-  return success();
-}
-
-/*
- * Computes the element-wise most significant N bits of the 2N-bit product
- * of x and y:
- *   %x:2 = arith.mului_extended %y, %z : tensor<4x?xi32>
- */
-LogicalResult TritonMulhiuiConverter::matchAndRewrite(
-    triton::MulhiUIOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  Value opl = op.getX();
-  Value opr = op.getY();
-  Value res = op.getResult();
-  auto newMulOp = rewriter.create(
-      loc, res.getType(), res.getType(), opl, opr);
-  // Triton only needs the high half of the product.
-  rewriter.replaceOp(op, ValueRange{newMulOp.getHigh()});
-  return success();
-}
-
-LogicalResult TritonPreciseSqrtConverter::matchAndRewrite(
-    triton::PreciseSqrtOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
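tt.join and tt.split are the user-visible ends of the stride-2 trick used by JoinConverter above and SplitConverter here: a last dimension of size 1 with stride 2 encodes the interleaved layout for later lowering. A sketch, assuming the standard Triton API:

```python
# Sketch only: assumes tl.join / tl.split from the standard triton API.
import triton
import triton.language as tl

@triton.jit
def interleave_kernel(a_ptr, b_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs)
    b = tl.load(b_ptr + offs)
    ab = tl.join(a, b)                  # tt.join -> (BLOCK, 2): stride-2 inserts
    out = tl.reshape(ab, (2 * BLOCK,))  # a0, b0, a1, b1, ...
    tl.store(out_ptr + tl.arange(0, 2 * BLOCK), out)
    # tl.split(ab) is the inverse and maps to the stride-2 extract_slices.
```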
rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); - return success(); -} - -LogicalResult DevicePrintConverter::matchAndRewrite( - triton::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto moduleOp = op->getParentOfType(); - rewriter.setInsertionPoint(moduleOp.getBody(), - std::prev(moduleOp.getBody()->end())); - SmallVector inputTypes; - for (auto arg : op.getArgs()) { - inputTypes.push_back(arg.getType()); - } - auto libFnType = rewriter.getFunctionType(inputTypes, {}); - auto funcOp = - rewriter.create(op.getLoc(), printFuncNameBase, libFnType); - SymbolTable symTab(moduleOp); - auto maybePrintFuncNameAttr = symTab.renameToUnique(funcOp, {&symTab}); - if (failed(maybePrintFuncNameAttr)) { - return op->emitError( - "failed to create a unique func name for device_print"); - } - SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); - auto prefixAttr = op.getPrefixAttr(); - funcOp->setAttr(prefixAttrName, prefixAttr); - auto hexAttr = op.getHexAttr(); - funcOp->setAttr(hexAttrName, hexAttr); - - rewriter.setInsertionPoint(op); - rewriter.create(op.getLoc(), funcOp, op.getArgs()); - - rewriter.eraseOp(op); - return success(); -} - -LogicalResult -MatmulConverter::matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto opa = adaptor.getA(); - auto opb = adaptor.getB(); - auto opc = adaptor.getC(); - auto dstType = cast(op.getType()); - auto inputPrec = op.getInputPrecision(); - - if (dstType.getRank() == 2) { - auto matmulOp = rewriter.replaceOpWithNewOp( - op, ValueRange{opa, opb}, ValueRange{opc}); - matmulOp->setAttr( - "input_precison", - rewriter.getStringAttr(stringifyInputPrecision(inputPrec))); - } else if (dstType.getRank() == 3) { - auto matmulOp = rewriter.replaceOpWithNewOp( - op, ValueRange{opa, opb}, ValueRange{opc}); - matmulOp->setAttr( - "input_precison", - rewriter.getStringAttr(stringifyInputPrecision(inputPrec))); - } else { - llvm_unreachable("Datatype of DotOp operands could only be 2D or 3D"); - } - return success(); -} -} // namespace TTOpConverters diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp deleted file mode 100644 index 9f7959074..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp +++ /dev/null @@ -1,544 +0,0 @@ -#include "TritonToLinalg/TritonToLinalgPass.h" -#include "TritonToLinalg/ArgMinMaxConverter.h" -#include "TritonToLinalg/FunctionConverter.h" -#include "TritonToLinalg/LoadStoreConverter.h" -#include "TritonToLinalg/TritonOpConverter.h" -#include "TritonToLinalg/UseAnalysis.h" -#include "Utils/InterleaveOptimization.h" -#include "Utils/Utils.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" - -#include -#include - -#define 
DEBUG_TYPE "triton-to-linalg"
-
-using namespace mlir;
-using namespace triton;
-
-TritonTypeConverter::TritonTypeConverter() {
-  addConversion([](Type type) { return type; });
-
-  addConversion([](triton::PointerType ptrType) {
-    return MemRefType::get({ShapedType::kDynamic}, ptrType.getPointeeType());
-  });
-
-  addConversion([](TensorType tensorType) -> Type {
-    auto elemType = tensorType.getElementType();
-    if (auto ptrType = dyn_cast(elemType)) {
-      elemType = ptrType.getPointeeType();
-    }
-    return MemRefType::get(tensorType.getShape(), elemType);
-  });
-}
-
-void TritonToLinalgPass::addProgramInfo(triton::FuncOp func,
-                                        bool globalKernel) {
-  OpBuilder b(func);
-
-  auto origFuncType = func.getFunctionType();
-  auto origInputTypes = origFuncType.getInputs();
-  SmallVector newInputTypes(origInputTypes);
-  newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type());
-
-  auto newFuncType =
-      b.getFunctionType(newInputTypes, origFuncType.getResults());
-
-  func.setFunctionType(newFuncType);
-
-  // If present, extend the argument attributes to cover the new arguments.
-  if (func.getAllArgAttrs()) {
-    SmallVector newArgAttrs;
-    func.getAllArgAttrs(newArgAttrs);
-    newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr());
-    func.setAllArgAttrs(newArgAttrs);
-  }
-
-  // Add the corresponding arguments to the function body.
-  for (unsigned i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) {
-    func.getBody().front().addArgument(b.getI32Type(), func.getLoc());
-  }
-
-  if (globalKernel) {
-    func->setAttr(globalKernelAttr, b.getStringAttr(""));
-  } else {
-    func->setAttr(globalKernelAttr, b.getStringAttr("local"));
-  }
-}
-
-void TritonToLinalgPass::convertTTFunc(triton::FuncOp func,
-                                       const bool existDot) {
-  OpBuilder builder(func);
-
-  auto name = func.getName();
-  auto type = func.getFunctionType();
-
-  SmallVector argAttrs, resAttrs;
-  func.getAllArgAttrs(argAttrs);
-  func.getAllResultAttrs(resAttrs);
-
-  // Special handling for bit-casted tt.ptr arguments.
-  SmallVector inputTypes{type.getInputs()};
-  SmallVector retTypes{type.getResults()};
-  if (func.getSymVisibility() == "public" && !func.isDeclaration()) {
-    for (size_t i = 0; i < func.getNumArguments(); ++i) {
-      auto arg = func.getArgument(i);
-      if (!isa(arg.getType())) {
-        continue;
-      }
-      // FIXME: Why can't arg.getUsers() return users inside scf.for?
-      llvm::SmallVector arg_users;
-      func.walk([&](Operation *op) {
-        if (op->use_empty()) {
-          return WalkResult::advance();
-        }
-        for (auto operand : op->getOperands()) {
-          if (operand == arg) {
-            arg_users.push_back(op);
-          }
-        }
-        return WalkResult::advance();
-      });
-
-      bool arg_use_empty = arg_users.size() == 0;
-      if (!arg_use_empty) {
-        LLVM_DEBUG({
-          auto &os = llvm::dbgs();
-          os << arg << " has users:\n";
-          int cnt = 0;
-          for (auto it : arg_users) {
-            os << "users[" << cnt++ << "] = " << *it;
-          }
-        });
-        if (llvm::all_of(arg_users, [](Operation *userOp) {
-              return isa(userOp);
-            })) {
-          auto castOp = cast(*arg_users.begin());
-          if (castOp.getInputs().size() == 1 &&
-              castOp.getOutputs().size() == 1) {
-            arg.setType(castOp.getOutputs()[0].getType());
-            inputTypes[i] = arg.getType();
-          }
-        }
-      } else {
-        // Handle an unused bool-pointer argument specially, so that its type
-        // reflects the realistic memory layout and doesn't mislead the
-        // backend compiler.
-        BaseMemRefType argType = dyn_cast(arg.getType());
-        if (argType.getElementTypeBitWidth() == 1) {
-          // The realistic memory layout of a bool pointer is 8 bits wide.
-          auto memType = argType.cloneWith(std::nullopt, builder.getI8Type());
-          arg.setType(memType);
-          inputTypes[i] = arg.getType();
-        }
-      }
-    }
-  }
-  auto castType = FunctionType::get(func.getContext(), inputTypes, retTypes);
-
-  auto funcFunc = builder.create(func.getLoc(), name, castType);
-  funcFunc.setAllArgAttrs(argAttrs);
-  funcFunc.setAllResultAttrs(resAttrs);
-  auto kernelAttr = func->getAttr(globalKernelAttr);
-  if (kernelAttr) {
-    funcFunc->setAttr(globalKernelAttr, kernelAttr);
-  }
-  std::string kernelMixMode = "aiv";
-  if (existDot) {
-    // "mix" also works for a pure cube kernel, since it uses the same
-    // MAGIC_ELF keyword.
-    kernelMixMode = "mix";
-  }
-  // Set mix_mode in the func attrs so that the backend can learn the
-  // mix_mode by parsing the func attrs. The backend needs to know the
-  // mix_mode because the host wrapper needs to set the devbin.magic.
-  // Check npu_utils.cpp.
-  funcFunc->setAttr(kernelMixModeName, builder.getStringAttr(kernelMixMode));
-
-  auto &funcFuncBody = funcFunc.getBody();
-  auto &funcBody = func.getBody();
-
-  IRMapping map;
-  funcBody.cloneInto(&funcFuncBody, map);
-
-  for (Block &block : funcFuncBody.getBlocks()) {
-    auto term = block.getTerminator();
-    builder.setInsertionPoint(term);
-    builder.create(func.getLoc(), term->getOperands());
-    term->erase();
-  }
-  func.erase();
-}
-
-void TritonToLinalgPass::addDynamicLegal(
-    ConversionTarget &target, TritonTypeConverter &tritonTypeConverter) {
-  target.addLegalDialect<
-      func::FuncDialect, arith::ArithDialect, math::MathDialect,
-      linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect,
-      cf::ControlFlowDialect, tensor::TensorDialect,
-      bufferization::BufferizationDialect, memref::MemRefDialect>();
-
-  // Add legal ops conditionally.
-  target.addLegalOp();
-
-  // Decide which ops need conversion based on the conditions below.
-  target.addDynamicallyLegalOp(
-      [](mlir::Operation *op) {
-        if (op->use_empty()) {
-          return false;
-        } else {
-          return true;
-        }
-      });
-
-  target.addDynamicallyLegalOp([&](triton::FuncOp op) {
-    return tritonTypeConverter.isSignatureLegal(op.getFunctionType());
-  });
-
-  target.addDynamicallyLegalOp([](arith::ConstantOp op) {
-    auto res = op.getResult();
-    if (!isa(res.getType())) {
-      return true;
-    }
-
-    if (auto denseAttr = dyn_cast(op.getValue())) {
-      if (!denseAttr.isSplat() ||
-          !isa(denseAttr.getElementType())) {
-        return true;
-      }
-      if (res.hasOneUse() && isa(*res.user_begin())) {
-        return true;
-      }
-      return false;
-    }
-    return true;
-  });
-
-  target.addDynamicallyLegalOp([](Operation *op) {
-    return llvm::all_of(op->getOperandTypes(), [](Type t) {
-      if (isa(t)) {
-        return false;
-      }
-      if (auto shapedType = dyn_cast(t)) {
-        return shapedType.getElementType().isIntOrFloat();
-      }
-      assert(t.isIntOrIndexOrFloat());
-      return true;
-    });
-  });
-
-  target.addDynamicallyLegalDialect(
-      [this](Operation *op) {
-        if (op->hasAttr("MetaUse")) {
-          return false;
-        }
-
-        if (isa(op)) {
-          return true;
-        }
-
-        bool operateOnTensors =
-            llvm::all_of(op->getOperandTypes(),
-                         [](Type type) { return isa(type); });
-
-        return this->namedOps || !operateOnTensors;
-      });
-}
-
-void TritonToLinalgPass::populateTritonToLinalgCanonicalizationPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add(patterns.getContext());
-  patterns.add,
-               LoadStoreConverter::LoadStoreCanonicalizer>(
-      patterns.getContext());
-  patterns.add(patterns.getContext());
-  patterns.add(patterns.getContext());
-
patterns.add( - patterns.getContext()); - patterns.add( - patterns.getContext()); - patterns.add( - patterns.getContext()); - patterns.add, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - // TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer, - TTOpConverters::ScalarMathCanonicalizer - // By test, the following ops do not need canonicalization. 
- // TTOpConverters::ScalarMathCanonicalizer - // TTOpConverters::ScalarMathCanonicalizer - // TTOpConverters::ScalarMathCanonicalizer - >(patterns.getContext()); -} - -void TritonToLinalgPass::populateTritonToLinalgConversionPatterns( - TypeConverter &typeConverter, RewritePatternSet &patterns, - unsigned int launchGridRank) { - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add( - patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - // reduce converters - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - - patterns.add(patterns.getContext()); - patterns.add( - patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add( - patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - - if (!this->namedOps) { - linalg::populateElementwiseToLinalgConversionPatterns(patterns); - } -} - -void TritonToLinalgPass::getDependentDialects(DialectRegistry ®istry) const { - registry.insert(); -} - -void TritonToLinalgPass::runOnOperation() { - auto moduleOp = getOperation(); - - // Check if the kernel contains tl.dot. Without tl.dot, - // the kernel would be pure AIV kernel. 
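For reference, a kernel like the sketch below (standard Triton API assumed; tile sizes illustrative and subject to tl.dot's minimum-size constraints) is what trips the existDot walk that follows: the 2D tl.dot lowers to linalg.matmul via MatmulConverter (a 3D dot would become linalg.batch_matmul), and the presence of tt.dot switches the kernel's mix_mode from the pure-vector "aiv" to "mix".

```python
# Sketch only: assumes the standard triton / triton.language API.
import triton
import triton.language as tl

@triton.jit
def matmul_tile_kernel(a_ptr, b_ptr, c_ptr,
                       M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)[:, None]
    offs_n = tl.arange(0, N)[None, :]
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m * K + offs_k[None, :])  # (M, K) tile
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n)  # (K, N) tile
    acc = tl.dot(a, b)             # tt.dot -> linalg.matmul
    tl.store(c_ptr + offs_m * N + offs_n, acc)
```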
-  bool existDot = false;
-  moduleOp.walk([&](triton::DotOp dotOp) {
-    existDot = true;
-    return WalkResult::interrupt();
-  });
-
-  RewritePatternSet canonicalizerPatterns(&getContext());
-  // 1. Canonicalize load/store ops (e.g. ScalarStoreCanonicalizer).
-  this->populateTritonToLinalgCanonicalizationPatterns(canonicalizerPatterns);
-  if (failed(applyPatternsAndFoldGreedily(moduleOp,
-                                          std::move(canonicalizerPatterns)))) {
-    moduleOp->emitError("failed to apply Canonicalizer Patterns");
-    signalPassFailure();
-  }
-
-  // 2. Run use analysis.
-  moduleOp.walk([this](triton::FuncOp op) {
-    if (failed(runUseAnalysis(op))) {
-      signalPassFailure();
-    }
-  });
-
-  RewritePatternSet patterns(&getContext());
-  ConversionTarget target(getContext());
-  TritonTypeConverter tritonTypeConverter{};
-
-  // 3. Mark the legal dialects and ops.
-  this->addDynamicLegal(target, tritonTypeConverter);
-
-  // 5. Register converters for the illegal ops.
-  this->populateTritonToLinalgConversionPatterns(tritonTypeConverter, patterns,
-                                                 LAUNCH_GRID_RANK);
-
-  // 6. Walk the functions in the kernel and add the program-id and
-  //    number-of-programs arguments.
-  for (auto func : getOperation().getOps()) {
-    addProgramInfo(func, globalKernel);
-  }
-
-  // 7. Apply the op conversion.
-  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
-    moduleOp->emitError("failed to apply Conversion Patterns");
-    signalPassFailure();
-  }
-
-  // 8. Convert the function header and footer.
-  moduleOp.walk(
-      [&](triton::FuncOp func) { this->convertTTFunc(func, existDot); });
-
-  // 9. Remove dead code and simplify.
-  PassManager pm(&getContext(), moduleOp.getOperationName());
-  pm.addPass(createCSEPass());
-  pm.addPass(createCanonicalizerPass());
-  if (failed(runPipeline(pm, getOperation()))) {
-    signalPassFailure();
-  }
-
-  // Try the interleave optimization.
-  llvm::DenseMap> interleaveCandidate;
-  llvm::DenseMap>
-      interleaveCandidateWithMask;
-  moduleOp.walk([&](bufferization::MaterializeInDestinationOp materializeOp) {
-    if (auto reinterpretCastOp =
-            materializeOp.getDest()
-                .getDefiningOp()) {
-      if (llvm::isa(reinterpretCastOp.getSource()) &&
-          reinterpretCastOp.getStaticStrides().back() == 2) {
-        interleaveCandidate[llvm::cast(
-                                reinterpretCastOp.getSource())]
-            .push_back(materializeOp);
-      }
-    }
-
-    // The difference is that the converted op chain of a masked store
-    // contains a `memref::SubViewOp`.
-    if (auto subviewOp =
-            materializeOp.getDest().getDefiningOp()) {
-      if (!llvm::isa(
-              materializeOp.getSource().getDefiningOp()))
-        return WalkResult::advance();
-
-      if (auto reinterpretCastOp =
-              subviewOp.getSource()
-                  .getDefiningOp()) {
-        if (llvm::isa(reinterpretCastOp.getSource()) &&
-            reinterpretCastOp.getStaticStrides().back() == 2) {
-          interleaveCandidateWithMask[llvm::cast(
-                                          reinterpretCastOp.getSource())]
-              .push_back(materializeOp);
-        }
-      }
-    }
-
-    return WalkResult::advance();
-  });
-
-  for (auto [blockArg, materializeVec] : interleaveCandidate) {
-    // Only enable the optimization when there are exactly two materializeOps
-    // with the same block-argument destination.
-    if (materializeVec.size() != 2)
-      continue;
-    auto result = InterleaveStatusOptimization(materializeVec);
-  }
-
-  for (auto [blockArg, materializeVec] : interleaveCandidateWithMask) {
-    if (materializeVec.size() != 2)
-      continue;
-    auto result = InterleaveStatusWithMaskOptimization(materializeVec);
-  }
-
-  // Force an extra argument at the beginning of the function arguments,
-  // which represents a stub arg for the workspace.
Default type is memref - for (auto func : getOperation().getOps()) { - if (!func->hasAttr("global_kernel")) - continue; - - auto context = func.getContext(); - constexpr int64_t workspaceArgIdx = 0; - MemRefType workspaceArgType = - MemRefType::get(SmallVector(1, ShapedType::kDynamic), - IntegerType::get(context, 8)); - NamedAttribute workspaceArgAttr(StringAttr::get(context, "workspace"), - UnitAttr::get(context)); - - func.insertArgument(/*argIndex*/ workspaceArgIdx, - /*argType*/ workspaceArgType, - /*dicAttr*/ nullptr, func->getLoc()); - func->setAttr("WorkspaceArgIdx", - IntegerAttr::get(IntegerType::get(&getContext(), 64), 0)); - } - - // Fix the Location info - moduleOp.walk([&](Operation *op) { - auto loc = op->getLoc(); - if (isa(loc)) { - llvm::SmallPtrSet stopOps; - traverseForwardUpdateUserChainIf( - op, - /*conditionFn*/ - [](Operation *curOp) { return false; }, - /*stopFn*/ - [](Operation *curOp) { return !isa(curOp->getLoc()); }, - /*actionFn*/ - nullptr, stopOps); - if (stopOps.empty()) { - op->emitWarning() << *op << " and its users all have no location!"; - } else { - Operation *goodOp = *stopOps.begin(); - op->setLoc(goodOp->getLoc()); - } - } - return WalkResult::advance(); - }); -} - -std::unique_ptr> triton::createTritonToLinalgPass() { - return std::make_unique(); -} diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp deleted file mode 100644 index 4b096316d..000000000 --- a/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp +++ /dev/null @@ -1,362 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
-// -//===----------------------------------------------------------------------===// - -#include "TritonToLinalg/UseAnalysis.h" -#include "Utils/Utils.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" - -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" - -using namespace mlir; -using namespace triton; -using namespace dataflow; - -#define DEBUG_TYPE "triton-use-analysis" - -std::string stringifyUseType(UseType useTy) { - std::string ret; - if (useTy == UseType::MetaUse) { - ret = "MetaUse"; - } else if (useTy == UseType::DataUse) { - ret = "DataUse"; - } else if (useTy == UseType::MixUse) { - ret = "MixUse"; - } else if (useTy == UseType::Undefined) { - ret = "Undefined"; - } - return ret; -} - -#if LLVM_VERSION_MAJOR >= 20 -LogicalResult -triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) { -#else -void triton::UseAnalysis::visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) { -#endif - - if (op->getResults().size() == 1) { - auto resultType = dyn_cast(op->getResult(0).getType()); - if (resultType && isa(resultType.getElementType())) { - for (auto opnd : operands) { - propagateUse(opnd, UseType::MetaUse); - } - } - } - - TypeSwitch(op) - .Case([&](auto load) { - propagateUse(operands[0], UseType::MetaUse); - auto mask = load.getMask(); - auto other = load.getOther(); - if (mask) { - assert(mask != other && "mask and other cannot be the same"); - propagateUse(operands[1], UseType::MetaUse); - } - if (other) { - propagateUse(operands[2], UseType::MetaUse); - } - }) - .Case([&](auto store) { - propagateUse(operands[0], UseType::MetaUse); - propagateUse(operands[1], UseType::DataUse); - auto value = store.getValue(); - auto mask = store.getMask(); - if (mask) { - assert(mask != value && "mask and data cannot be the same"); - propagateUse(operands[2], UseType::MetaUse); - } - }) - // Consider triton::AtomicRMWOp as store operation - .Case([&](auto atomicOp) { - propagateUse(operands[0], UseType::MetaUse); - propagateUse(operands[1], UseType::DataUse); - auto value = atomicOp.getVal(); - auto mask = atomicOp.getMask(); - if (mask) { - assert(mask != value && "mask and data cannot be the same"); - propagateUse(operands[2], UseType::MetaUse); - } - }) - .Case([&](auto dot) { - propagateResults(operands[0], results); - propagateResults(operands[1], results); - - auto opc = dot.getC(); - triton::SplatOp splat; - if (opc) { - splat = opc.template getDefiningOp(); - } - - if (opc && splat && splat.getSrc().getDefiningOp()) { - propagateUse(operands[2], UseType::MetaUse); - } else { - propagateUse(operands[2], UseType::DataUse); - } - }) - .Default([&](Operation *op) { - // this condition account for tt.addptr - for (auto operand : operands) { - propagateResults(operand, results); - } - }); -#if LLVM_VERSION_MAJOR >= 20 - return success(); -#endif -} - -LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) { - MLIRContext *context = funcOp.getContext(); - SymbolTableCollection symbolTable; - - DataFlowSolver solver; - solver.load(); - solver.load(); - solver.load(symbolTable); - if (failed(solver.initializeAndRun(funcOp))) { - return failure(); - } - auto &os = llvm::dbgs(); - // Walk the func op, convert tags on operands to tags on operations - funcOp.walk([&](Operation *op) { - LLVM_DEBUG({ os << "[UseAnalysis] op is " << *op << "\n"; }); - UseType useType = UseType::Undefined; - for (auto result 
: op->getResults()) { - LLVM_DEBUG({ os << "[UseAnalysis] ===> result is " << result << "\n"; }); - auto use = solver.lookupState(result); - assert(use && "Lattice value not found"); - auto thisUseType = use->type; - LLVM_DEBUG({ - os << "[UseAnalysis] ==========> useType is " - << stringifyUseType(thisUseType) << "\n"; - }); - if (thisUseType == UseType::Undefined) { - continue; - } - if (useType == UseType::Undefined) { - useType = thisUseType; - } - if (thisUseType == UseType::MixUse || thisUseType != useType) { - useType = UseType::MixUse; - break; - } - } - - if (useType == UseType::Undefined) { - LLVM_DEBUG({ op->setAttr("Undefined", UnitAttr::get(context)); }); - return; - } else if (useType == UseType::MetaUse) { - if (!isa(op)) { - assert(op->getNumResults() == 1 && - "Ops used for meta computation are expected to have one result"); - } - for (auto it = 0; it < op->getNumResults(); ++it) { - // Only set the tag if the operation uses tensors - if (isa(op->getResult(it).getType()) || - (isa(op) && - isa(op->getResult(it).getType()))) { - // Setting tag for erasing op later - op->setAttr("MetaUse", UnitAttr::get(context)); - } - } - return; - } else if (useType == UseType::DataUse) { - LLVM_DEBUG({ op->setAttr("DataUse", UnitAttr::get(context)); }); - return; - } - - assert(useType == UseType::MixUse); - - // If the operation only produces scalars, no need to clone it - bool shapedResult = true; - for (auto result : op->getResults()) - shapedResult &= isa(result.getType()); - if (!shapedResult) { - LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); - return; - } - - llvm::SetVector metaUsers; - for (auto result : op->getResults()) { - for (auto user : result.getUsers()) { - TypeSwitch(user) - .Case([&](auto load) { - auto ptr = load.getPtr(); - auto mask = load.getMask(); - auto other = load.getOther(); - if (result == ptr || result == mask || result == other) { - metaUsers.insert(user); - } - }) - .Case([&](auto store) { - auto ptr = store.getPtr(); - auto mask = store.getMask(); - if (result == ptr || result == mask) { - metaUsers.insert(user); - } - }) - .Case([&](auto atomicOp) { - auto ptr = atomicOp.getPtr(); - auto mask = atomicOp.getMask(); - if (result == ptr || result == mask) - metaUsers.insert(user); - }) - .Case([&](auto dot) { - auto opc = dot.getC(); - triton::SplatOp splat; - if (opc) { - splat = opc.template getDefiningOp(); - } - - if (opc && splat && - splat.getSrc().getDefiningOp()) { - metaUsers.insert(user); - } - }) - .Default([&](Operation *op) { - bool allMeta = true; - for (auto res : op->getResults()) { - auto resUse = solver.lookupState(res); - if (resUse->type != UseType::MetaUse) { - allMeta = false; - break; - } - } - if (allMeta) { - metaUsers.insert(user); - } - }); - } - } - - // If the operation doesn't have direct meta users, no need to clone it - if (metaUsers.empty()) { - LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); - return; - } - - // Clone the operation; switch all meta users to use the clone - OpBuilder builder(op); - auto clone = builder.clone(*op); - LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); - - // Setting tag for erasing op later - clone->setAttr("MetaUse", UnitAttr::get(context)); - - for (auto [res_i, result] : llvm::enumerate(op->getResults())) { - for (auto user : metaUsers) { - for (auto &operand : user->getOpOperands()) { - if (operand.get() == result) { - operand.set(clone->getResult(res_i)); - } - } - } - } - }); - LLVM_DEBUG({ - os << "[UseAnalysis] Before post-process, funcOp 
is " << *funcOp << "\n"; - }); - // Post-process - funcOp.walk([&](Operation *op) { - // Handle indirect load case. - // For example, load(1st) -> computeOp -> load(2nd). - // The first load is IndirectLoadInterfaceOp. - // Do not inplace replace MetaUse by MixUse. Because the condition checking - // depends on that the op has the attr of MetaUse. - // Handle the indirect load interface op - // We first trace from the 1st load to the 2nd load with the ops between - // them marked as MixUse. Then we traceback from the 2nd load to mark defs - // MixUse. - if (opIsIndirectLoad(op) || opIsIndirectCalc(op)) { - LLVM_DEBUG({ - os << "[UseAnalysis] Found indirect load interface op: " << *op << "\n"; - }); - llvm::SmallPtrSet stopOps; - // Modify the users of this op's result. - traverseForwardUpdateUserChainIf( - op, - /*conditionFn*/ - [op](Operation *curOp) { return isMetaUse(curOp) && curOp != op; }, - /*stopFn*/ - [&](Operation *curOp) { - // triton::LoadOp without MetaUse means it is an indirect load - // instead of the load providing the offset. - // The pattern is as follows, - // load -> ops -> load - // We need to ensure the intermediate ops are marked MixUse - // so that they will be replaced instead of be erased without - // conversion. - return isa(curOp) && !curOp->hasAttr("MetaUse"); - }, - /*actionFn*/ - [](OpBuilder &b, Operation *op) { - op->setAttr("MixUse", UnitAttr::get(op->getContext())); - }, - stopOps); - LLVM_DEBUG({ - os << "[UseAnalysis] stopOps are \n"; - int i = 0; - for (auto it = stopOps.begin(); it != stopOps.end(); it++) { - os << i++ << ": " << *(*it) << "\n"; - } - }); - LLVM_DEBUG({ - os << "[UseAnalysis] After trace, funcOp is " << *funcOp << "\n"; - }); - for (auto it = stopOps.begin(); it != stopOps.end(); it++) { - auto stopOp = *it; - traverseBackwardUpdateOperandChainIf( - stopOp, - [stopOp](Operation *curOp) { - return isMetaUse(curOp) && curOp != stopOp; - }, - [](OpBuilder &b, Operation *op) { - op->setAttr("MixUse", UnitAttr::get(op->getContext())); - }); - } - LLVM_DEBUG({ - os << "[UseAnalysis] After traceback of stopOp, funcOp is " << *funcOp - << "\n"; - }); - // Modify this op. 
- op->setAttr("MixUse", UnitAttr::get(op->getContext())); - } - }); - // Remove MetaUse in case of MixUse existing in the op - funcOp.walk([&](Operation *op) { - if (isMetaUse(op) && isMixUse(op)) { - op->removeAttr("MetaUse"); - } - }); - LLVM_DEBUG({ - os << "[UseAnalysis] After post-process, funcOp is " << *funcOp << "\n"; - }); - return success(); -} - -MetaUseEraser::MetaUseEraser(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, context) {} - -LogicalResult MetaUseEraser::matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { - if (isa(op)) { - return rewriter.notifyMatchFailure(op, - "AddPtrOp will be handled separately"); - } - if (isMetaUse(op)) { - rewriter.eraseOp(op); - return success(); - } - return rewriter.notifyMatchFailure(op, "requires meta ops"); -} diff --git a/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt deleted file mode 100644 index b6aa5164b..000000000 --- a/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -add_triton_library(MLIRTritonNPUUtils - Utils.cpp - InterleaveOptimization.cpp - - LINK_LIBS PUBLIC - MLIRIR - TritonIR -) diff --git a/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp b/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp deleted file mode 100644 index ec3e4d3d5..000000000 --- a/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp +++ /dev/null @@ -1,662 +0,0 @@ -//===- InterleaveOptimization.cpp -------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "Utils/InterleaveOptimization.h" -#include "Utils/Utils.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/ViewLikeInterface.h" -#include "mlir/Support/LogicalResult.h" - -#include "mlir/IR/Operation.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include -#include - -namespace mlir { -namespace triton { -// For origin MemRefType of ReinterpretCastOp under interleave state, here wanna -// adjust its shape info by expanding last dimension double. -MemRefType expandInterleaveMemRefType(MemRefType originType) { - // Double the last dimension shape - SmallVector shape(originType.getShape()); - shape.back() = shape.back() * 2; - - // Adjuest layout attribute - StridedLayoutAttr originLayout = - llvm::dyn_cast(originType.getLayout()); - // If offset is static, just reset it to 0 - auto offset = originLayout.getOffset() == ShapedType::kDynamic - ? 
originLayout.getOffset() - : 0; - // Set last dimension stride to 1 - SmallVector stride(originLayout.getStrides()); - stride.back() = 1; - - return MemRefType::get( - shape, originType.getElementType(), - StridedLayoutAttr::get(originType.getContext(), offset, stride)); -} - -// ********************* -// ** NOTE ** -// ********************* -// How to determine new offset is a little tricky and specific -// Here just consider this state in triton language: -// -// dim_range = tl.arange(0, BLOCK // 2) -// last_dim_even_range = dim_range * 2 -// last_dim_odd_range = dim_range * 2 + 1 -// -// Here `multiply two` represents that last dimension stride is 2, and -// `add constant one` represents whether it's odd index part of -// deinterleave result. -// -// Therefore, how to distinguish interleave/deinterleave on even index or odd -// index is whether last dimension range explicitly `add constant one` without -// any other operation. In IR it's shown that whether defining op of -// `castOffset` is an arith::addOp, as this arith::addOp would contain above -// `add constant one` opeartion after LegacyAddPtrConverter. -// -// Well, index mode should be passed to interleave/deinterleave, in other words, -// `add constant one` should work on offset of next insert_slice/extract_slic. -// The new reinterpretcast just wanna describe whole tensor, so new castOffset -// is just from non-last diemsnion accumulation and remove `add constant one` -std::pair -recountReinterpretCastOffset(OpFoldResult originOffset, Builder &builder) { - // To trace value type offset - std::function traceOffset = [&](Operation *op) -> bool { - // Consider constant one in `add constant one` operation - if (llvm::isa(op)) - return false; - - if (llvm::isa(op)) { - auto addOp = llvm::cast(op); - if (auto constLHS = addOp.getLhs().getDefiningOp()) { - assert(dyn_cast(constLHS.getValueAttr()).getInt() == 1 && - "Arith::constant value of addi's operand must be 1 when " - "calculate deinterleave offset"); - return false; - } - if (auto constRHS = addOp.getRhs().getDefiningOp()) { - assert(dyn_cast(constRHS.getValueAttr()).getInt() == 1 && - "Arith::constant value of addi's operand must be 1 when " - "calculate deinterleave offset"); - return false; - } - } - return true; - }; - - IndexMode evenOrOdd = IndexMode::EVEN_MODE; - // Reuse origin offset if there's no 'add constant one' - OpFoldResult newOffset = originOffset; - if (llvm::isa(originOffset)) { - // If offset is constant int(IndexAttr), - // the int value could only be 0 or 1 - int64_t intOffset = getConstantIntValue(originOffset).value(); - assert((intOffset == 0 || intOffset == 1)); - if (intOffset == 1) { - evenOrOdd = IndexMode::ODD_MODE; - newOffset = builder.getIndexAttr(0); - } - } else if (llvm::isa(originOffset)) { - if (!traceOffset(originOffset.get().getDefiningOp())) { - evenOrOdd = IndexMode::ODD_MODE; - Operation *traceResult = findFirstMatchingOperandDef( - originOffset.get().getDefiningOp(), traceOffset); - assert(traceResult->getNumResults() == 1 && - "Offset defining operation must have one result"); - newOffset = traceResult->getResult(0); - } - } - - return {newOffset, evenOrOdd}; -} - -LogicalResult -DeinterleaveStatusOptimization(triton::LoadOp op, - triton::LoadOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) { - auto ptr = adaptor.getPtr(); - if (auto reinterpretCast = ptr.getDefiningOp()) { - auto loc = op.getLoc(); - - // 1. Get new source memref type - auto srcType = expandInterleaveMemRefType(reinterpretCast.getType()); - - // 2. 
Create new ReinterpretCastOp - auto originCastOffset = reinterpretCast.getConstifiedMixedOffset(); - auto castSize = reinterpretCast.getConstifiedMixedSizes(); - auto castStride = reinterpretCast.getConstifiedMixedStrides(); - // Actually, `castSize` is always constant value as `MemRefType` result - if (auto lastDimSize = makeIntAttr(castSize.back())) { - castSize.back() = rewriter.getIndexAttr(lastDimSize.value() * 2); - } else { - return failure(); - } - // Last element of castStride is also constant value as prerequisite - // is that last dimension stride of casted memref type is always 2. - castStride.back() = rewriter.getIndexAttr(1); - auto [castOffset, indexMode] = - recountReinterpretCastOffset(originCastOffset, rewriter); - auto newCastOp = rewriter.create( - loc, srcType, reinterpretCast.getViewSource(), castOffset, castSize, - castStride); - - // 3. Create new memref allocOp - auto newAllocOp = rewriter.create( - loc, MemRefType::get(srcType.getShape(), srcType.getElementType())); - - // 4. Implement memref copy and bufferization back to tensor - rewriter.create(loc, newCastOp.getResult(), newAllocOp); - Value newTensor = rewriter.create( - loc, - RankedTensorType::get(srcType.getShape(), srcType.getElementType()), - newAllocOp, true /* restrict */, true /* writable */); - - // 5. Implement tensor extract_slice to represent deinterleave - // Here use `castOffset` to determine whether even index deinterleave or - // odd index. - SmallVector extractOffsets(srcType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector extractStrides(srcType.getRank(), - rewriter.getIndexAttr(1)); - SmallVector extractSizes = llvm::to_vector( - llvm::map_range(srcType.getShape(), [&](int64_t dim) -> OpFoldResult { - return rewriter.getIndexAttr(dim); - })); - - // Adjust extract_slice shape - switch (indexMode) { - case IndexMode::EVEN_MODE: - extractOffsets.back() = rewriter.getIndexAttr(0); - break; - case IndexMode::ODD_MODE: - extractOffsets.back() = rewriter.getIndexAttr(1); - break; - } - extractStrides.back() = rewriter.getIndexAttr(2); - extractSizes.back() = rewriter.getIndexAttr(srcType.getShape().back() / 2); - - Value deinterleaveSlice = rewriter.create( - loc, newTensor, extractOffsets, extractSizes, extractStrides); - - rewriter.replaceOp(op, deinterleaveSlice); - return success(); - } - - return failure(); -} - -LogicalResult DeinterleaveStatusWithMaskOptimization( - triton::LoadOp op, triton::LoadOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter, MaskState &mstate, - memref::AllocOp originAllocOp) { - auto ptr = adaptor.getPtr(); - if (auto reinterpretCast = ptr.getDefiningOp()) { - auto loc = op.getLoc(); - - // 1. Get new source memref type - auto srcType = expandInterleaveMemRefType(reinterpretCast.getType()); - - // 2. Create new ReinterpretCastOp - auto originCastOffset = reinterpretCast.getConstifiedMixedOffset(); - auto castSize = reinterpretCast.getConstifiedMixedSizes(); - auto castStride = reinterpretCast.getConstifiedMixedStrides(); - - if (auto lastDimSize = makeIntAttr(castSize.back())) { - castSize.back() = rewriter.getIndexAttr(lastDimSize.value() * 2); - } else { - return failure(); - } - castStride.back() = rewriter.getIndexAttr(1); - auto [castOffset, indexMode] = - recountReinterpretCastOffset(originCastOffset, rewriter); - - auto newCastOp = rewriter.create( - loc, srcType, reinterpretCast.getViewSource(), castOffset, castSize, - castStride); - - // 3. 
Create new memref allocOp - // To reuse the existing linalg::fill, we need to change the insertion point - auto savedInsertPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointAfter(originAllocOp); - auto newAllocOp = rewriter.create( - loc, MemRefType::get(srcType.getShape(), srcType.getElementType())); - rewriter.restoreInsertionPoint(savedInsertPoint); - - // 4. Broadcast the other value by linalg.fill if necessary - auto other = op.getOther(); - // The deinterleave optimization only adjusts last-dimension info, - // and the origin mask state never involves the last dimension. Therefore, - // in the current `scf.if + linalg.fill` combination, the condition of the - // `if` can be kept and only the linalg.fill needs to be replaced. - if (other) { - assert(originAllocOp->hasOneUse() && - llvm::isa(*(originAllocOp->getUsers().begin()))); - auto originFillOp = - llvm::dyn_cast(*(originAllocOp->getUsers().begin())); - - assert(llvm::isa(originFillOp->getParentOp())); - auto ifOp = llvm::dyn_cast(originFillOp->getParentOp()); - - auto newFillOp = ifOp.getThenBodyBuilder().create( - originFillOp.getLoc(), originFillOp.getInputs(), - ValueRange{newAllocOp}); - rewriter.eraseOp(originFillOp); - } - - // 5. Implement new subview, memref copy and bufferization back to tensor - SmallVector subviewStrides(srcType.getRank(), - rewriter.getIndexAttr(1)); - SmallVector subviewOffsets = mstate.offsets; - SmallVector subviewSizes = mstate.dims; - // Just double the last dimension size - std::optional originSubviewLastDim = - getConstantIntValue(subviewSizes.back()); - assert(originSubviewLastDim.has_value()); - subviewSizes.back() = - rewriter.getIndexAttr(originSubviewLastDim.value() * 2); - - auto argSubviewType = memref::SubViewOp::inferResultType( - srcType, subviewOffsets, subviewSizes, subviewStrides); - // alloc subview type doesn't carry a layout attribute - auto allocSubviewType = memref::SubViewOp::inferResultType( - newAllocOp.getType(), subviewOffsets, subviewSizes, subviewStrides); - - memref::SubViewOp srcSubview = rewriter.create( - loc, llvm::cast(argSubviewType), newCastOp, subviewOffsets, - subviewSizes, subviewStrides); - memref::SubViewOp dstSubview = rewriter.create( - loc, llvm::cast(allocSubviewType), newAllocOp, - subviewOffsets, subviewSizes, subviewStrides); - rewriter.create(loc, srcSubview, dstSubview); - Value newTensor = rewriter.create( - loc, - RankedTensorType::get(srcType.getShape(), srcType.getElementType()), - newAllocOp, true /* restrict */, true /* writable */); - - // 6. Implement tensor extract_slice to represent deinterleave - // Here use `castOffset` to determine whether it is the even-index or - // odd-index deinterleave. 
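// A plain-C++ model (not part of this patch, illustration only) of what the
// stride-2 extract_slice built in step 6 computes on an interleaved last
// dimension of size 2*N: even mode reads from offset 0, odd mode from offset
// 1, both stepping by 2. `deinterleaveLastDim` is a hypothetical name.
#include <cstddef>
#include <vector>

std::vector<float> deinterleaveLastDim(const std::vector<float> &interleaved,
                                       bool oddMode) {
  std::vector<float> lane;
  lane.reserve(interleaved.size() / 2);
  // offset = 0 (EVEN_MODE) or 1 (ODD_MODE), stride = 2, size = N
  for (std::size_t i = oddMode ? 1 : 0; i < interleaved.size(); i += 2)
    lane.push_back(interleaved[i]);
  return lane;
}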
- SmallVector extractOffsets(srcType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector extractStrides(srcType.getRank(), - rewriter.getIndexAttr(1)); - SmallVector extractSizes = llvm::to_vector( - llvm::map_range(srcType.getShape(), [&](int64_t dim) -> OpFoldResult { - return rewriter.getIndexAttr(dim); - })); - - switch (indexMode) { - case IndexMode::EVEN_MODE: - extractOffsets.back() = rewriter.getIndexAttr(0); - break; - case IndexMode::ODD_MODE: - extractOffsets.back() = rewriter.getIndexAttr(1); - break; - } - extractStrides.back() = rewriter.getIndexAttr(2); - extractSizes.back() = rewriter.getIndexAttr(srcType.getShape().back() / 2); - - Value deinterleaveSlice = rewriter.create( - loc, newTensor, extractOffsets, extractSizes, extractStrides); - - rewriter.replaceOp(op, deinterleaveSlice); - return success(); - } - return failure(); -} - -LogicalResult -InterleaveStatusOptimization(SmallVector materializeVec) { - OpBuilder builder(materializeVec[1]); - auto loc = materializeVec[1]->getLoc(); - - auto firstReinterpretCastOp = - llvm::dyn_cast( - materializeVec[0]) - .getDest() - .getDefiningOp(); - auto secondReinterpretCastOp = - llvm::dyn_cast( - materializeVec[1]) - .getDest() - .getDefiningOp(); - - assert(firstReinterpretCastOp && secondReinterpretCastOp); - - // Judge whether two `ReinterpretCastOp` shape satisfy interleave state - // a. both size are equal - if (!isEqualConstantIntOrValueArray( - firstReinterpretCastOp.getConstifiedMixedSizes(), - secondReinterpretCastOp.getConstifiedMixedSizes())) { - return failure(); - } - // b. both strides are equal - if (!isEqualConstantIntOrValueArray( - firstReinterpretCastOp.getConstifiedMixedStrides(), - secondReinterpretCastOp.getConstifiedMixedStrides())) { - return failure(); - } - // c. both offsets should satisfy tricky rule - auto firstOriginCastOffset = - firstReinterpretCastOp.getConstifiedMixedOffset(); - auto secondOriginCastOffset = - secondReinterpretCastOp.getConstifiedMixedOffset(); - std::pair indexModeRecord; - OpFoldResult newCastOffset; - if (llvm::isa(firstOriginCastOffset) && - llvm::isa(secondOriginCastOffset)) { - auto [firstCastOffset, firstIndexMode] = - recountReinterpretCastOffset(firstOriginCastOffset, builder); - auto [secondCastOffset, secondIndexMode] = - recountReinterpretCastOffset(secondOriginCastOffset, builder); - - if (!(static_cast(firstIndexMode) ^ static_cast(secondIndexMode))) - return failure(); - newCastOffset = builder.getIndexAttr(0); - indexModeRecord = {firstIndexMode, secondIndexMode}; - - } else if (llvm::isa(firstOriginCastOffset) && - llvm::isa(secondOriginCastOffset)) { - auto [firstCastOffset, firstIndexMode] = - recountReinterpretCastOffset(firstOriginCastOffset, builder); - auto [secondCastOffset, secondIndexMode] = - recountReinterpretCastOffset(secondOriginCastOffset, builder); - - if (!(static_cast(firstIndexMode) ^ - static_cast(secondIndexMode)) || - (llvm::dyn_cast(firstCastOffset) != - llvm::dyn_cast(secondCastOffset))) - return failure(); - - if (firstIndexMode == IndexMode::EVEN_MODE) { - newCastOffset = llvm::dyn_cast(firstCastOffset); - } - if (secondIndexMode == IndexMode::EVEN_MODE) { - newCastOffset = llvm::dyn_cast(secondCastOffset); - } - indexModeRecord = {firstIndexMode, secondIndexMode}; - - } else { - return failure(); - } - - // Create new op - // 1. Get new destination memref type - auto dstType = expandInterleaveMemRefType(firstReinterpretCastOp.getType()); - - // 2. 
New tensor::EmptyOp - auto emptyTensor = builder.create(loc, dstType.getShape(), - dstType.getElementType()); - - // 3. New insert_slice from materialization source into new empty tensor - SmallVector insertOffsets(dstType.getRank(), - builder.getIndexAttr(0)); - SmallVector insertStrides(dstType.getRank(), - builder.getIndexAttr(1)); - SmallVector insertSizes = llvm::to_vector( - llvm::map_range(dstType.getShape(), [&](int64_t dim) -> OpFoldResult { - return builder.getIndexAttr(dim); - })); - insertStrides.back() = builder.getIndexAttr(2); - insertSizes.back() = builder.getIndexAttr(dstType.getShape().back() / 2); - if (indexModeRecord.first == IndexMode::ODD_MODE) { - insertOffsets.back() = builder.getIndexAttr(1); - } else { - insertOffsets.back() = builder.getIndexAttr(0); - } - auto insertFirst = builder.create( - loc, - llvm::dyn_cast( - materializeVec[0]) - .getSource(), - emptyTensor.getResult(), insertOffsets, insertSizes, insertStrides); - - if (indexModeRecord.second == IndexMode::ODD_MODE) { - insertOffsets.back() = builder.getIndexAttr(1); - } else { - insertOffsets.back() = builder.getIndexAttr(0); - } - auto insertSecond = builder.create( - loc, - llvm::dyn_cast( - materializeVec[1]) - .getSource(), - insertFirst.getResult(), insertOffsets, insertSizes, insertStrides); - - // 4. Reinterpret_cast block arg - auto newCastSize = firstReinterpretCastOp.getConstifiedMixedSizes(); - auto newCastStride = firstReinterpretCastOp.getConstifiedMixedStrides(); - newCastSize.back() = builder.getIndexAttr(dstType.getShape().back()); - newCastStride.back() = builder.getIndexAttr(1); - auto newCastOp = builder.create( - loc, dstType, firstReinterpretCastOp.getViewSource(), newCastOffset, - newCastSize, newCastStride); - - // 5. Create new bufferization::MaterializeInDestinationOp - auto newStoreOp = builder.create( - loc, insertSecond.getResult(), newCastOp.getResult()); - // Setting writable is necessary as dst is memref type - newStoreOp.setWritable(true); - - // 6. Erase origin materialization - materializeVec[0]->erase(); - materializeVec[1]->erase(); - - return success(); -} - -LogicalResult -InterleaveStatusWithMaskOptimization(SmallVector materializeVec) { - OpBuilder builder(materializeVec[1]); - - auto firstSubviewOpOfReCast = - llvm::dyn_cast( - materializeVec[0]) - .getDest() - .getDefiningOp(); - auto firstSrcExtractSlice = - llvm::dyn_cast( - materializeVec[0]) - .getSource() - .getDefiningOp(); - auto firstReinterpretCastOp = firstSubviewOpOfReCast.getSource() - .getDefiningOp(); - - auto secondSubviewOpOfReCast = - llvm::dyn_cast( - materializeVec[1]) - .getDest() - .getDefiningOp(); - auto secondSrcExtractSlice = - llvm::dyn_cast( - materializeVec[1]) - .getSource() - .getDefiningOp(); - auto secondReinterpretCastOp = - secondSubviewOpOfReCast.getSource() - .getDefiningOp(); - - // 1. Both source shapes of subview and extract_slice are equal - if (firstSubviewOpOfReCast.getSourceType().getShape() != - firstSrcExtractSlice.getSourceType().getShape()) - return failure(); - if (secondSubviewOpOfReCast.getSourceType().getShape() != - secondSrcExtractSlice.getSourceType().getShape()) - return failure(); - if (firstSubviewOpOfReCast.getSourceType().getShape() != - secondSubviewOpOfReCast.getSourceType().getShape()) - return failure(); - - // 2. 
both mask state are equal - std::function cmpFunc = - mlir::isEqualConstantIntOrValue; - if (!mlir::detail::sameOffsetsSizesAndStrides(firstSubviewOpOfReCast, - firstSrcExtractSlice, cmpFunc)) - return failure(); - if (!mlir::detail::sameOffsetsSizesAndStrides(secondSubviewOpOfReCast, - secondSrcExtractSlice, cmpFunc)) - return failure(); - if (!mlir::detail::sameOffsetsSizesAndStrides( - firstSubviewOpOfReCast, secondSubviewOpOfReCast, cmpFunc)) - return failure(); - - // 3. Still judge whether two `ReinterpretCastOp` shape satisfy request - // a. both size are equal - if (!isEqualConstantIntOrValueArray( - firstReinterpretCastOp.getConstifiedMixedSizes(), - secondReinterpretCastOp.getConstifiedMixedSizes())) - return failure(); - // b. both strides are equal - if (!isEqualConstantIntOrValueArray( - firstReinterpretCastOp.getConstifiedMixedStrides(), - secondReinterpretCastOp.getConstifiedMixedStrides())) - return failure(); - // c. both offsets should satisfy tricky rule - auto firstOriginCastOffset = - firstReinterpretCastOp.getConstifiedMixedOffset(); - auto secondOriginCastOffset = - secondReinterpretCastOp.getConstifiedMixedOffset(); - std::pair indexModeRecord; - OpFoldResult newCastOffset; - if (llvm::isa(firstOriginCastOffset) && - llvm::isa(secondOriginCastOffset)) { - auto [firstCastOffset, firstIndexMode] = - recountReinterpretCastOffset(firstOriginCastOffset, builder); - auto [secondCastOffset, secondIndexMode] = - recountReinterpretCastOffset(secondOriginCastOffset, builder); - - if (!(static_cast(firstIndexMode) ^ static_cast(secondIndexMode))) - return failure(); - newCastOffset = builder.getIndexAttr(0); - indexModeRecord = {firstIndexMode, secondIndexMode}; - - } else if (llvm::isa(firstOriginCastOffset) && - llvm::isa(secondOriginCastOffset)) { - auto [firstCastOffset, firstIndexMode] = - recountReinterpretCastOffset(firstOriginCastOffset, builder); - auto [secondCastOffset, secondIndexMode] = - recountReinterpretCastOffset(secondOriginCastOffset, builder); - - if (!(static_cast(firstIndexMode) ^ - static_cast(secondIndexMode)) || - (llvm::dyn_cast(firstCastOffset) != - llvm::dyn_cast(secondCastOffset))) - return failure(); - - if (firstIndexMode == IndexMode::EVEN_MODE) { - newCastOffset = llvm::dyn_cast(firstCastOffset); - } - if (secondIndexMode == IndexMode::EVEN_MODE) { - newCastOffset = llvm::dyn_cast(secondCastOffset); - } - indexModeRecord = {firstIndexMode, secondIndexMode}; - - } else { - return failure(); - } - auto loc = materializeVec[1]->getLoc(); - - // Create new op - // 1. Get new destination memref type - auto dstType = expandInterleaveMemRefType(firstReinterpretCastOp.getType()); - - // 2. New tensor::EmptyOp - auto emptyTensor = builder.create(loc, dstType.getShape(), - dstType.getElementType()); - - // 3. 
New insert_slice from extract_slice source into new empty tensor - SmallVector insertOffsets(dstType.getRank(), - builder.getIndexAttr(0)); - SmallVector insertStrides(dstType.getRank(), - builder.getIndexAttr(1)); - SmallVector insertSizes = llvm::to_vector( - llvm::map_range(dstType.getShape(), [&](int64_t dim) -> OpFoldResult { - return builder.getIndexAttr(dim); - })); - insertStrides.back() = builder.getIndexAttr(2); - insertSizes.back() = builder.getIndexAttr(dstType.getShape().back() / 2); - if (indexModeRecord.first == IndexMode::ODD_MODE) { - insertOffsets.back() = builder.getIndexAttr(1); - } else { - insertOffsets.back() = builder.getIndexAttr(0); - } - auto insertFirst = builder.create( - loc, firstSrcExtractSlice.getSource(), emptyTensor.getResult(), - insertOffsets, insertSizes, insertStrides); - - if (indexModeRecord.second == IndexMode::ODD_MODE) { - insertOffsets.back() = builder.getIndexAttr(1); - } else { - insertOffsets.back() = builder.getIndexAttr(0); - } - auto insertSecond = builder.create( - loc, secondSrcExtractSlice.getSource(), insertFirst.getResult(), - insertOffsets, insertSizes, insertStrides); - - // 4. To enable store with mask, create new extract_slice - SmallVector extractOffsets = - firstSrcExtractSlice.getMixedOffsets(); - SmallVector extractStrides = - firstSrcExtractSlice.getMixedStrides(); - SmallVector extractSizes = firstSrcExtractSlice.getMixedSizes(); - assert(llvm::isa(extractSizes.back())); - extractSizes.back() = builder.getIndexAttr( - getConstantIntValue(extractSizes.back()).value() * 2); - auto newSrcExtractSlice = builder.create( - loc, insertSecond.getResult(), extractOffsets, extractSizes, - extractStrides); - - // 5. Reinterpret_cast block arg - auto newCastSize = firstReinterpretCastOp.getConstifiedMixedSizes(); - auto newCastStride = firstReinterpretCastOp.getConstifiedMixedStrides(); - newCastSize.back() = builder.getIndexAttr(dstType.getShape().back()); - newCastStride.back() = builder.getIndexAttr(1); - auto newCastOp = builder.create( - loc, dstType, firstReinterpretCastOp.getViewSource(), newCastOffset, - newCastSize, newCastStride); - - // 6. Create new memref::SubViewOp of above new reinterpret_cast - // Here could reuse shape info of new extract_slice - auto dstSubviewType = memref::SubViewOp::inferResultType( - dstType, extractOffsets, extractSizes, extractStrides); - auto newSubviewOpOfReCast = builder.create( - loc, llvm::cast(dstSubviewType), newCastOp, extractOffsets, - extractSizes, extractStrides); - - // 7. Create new bufferization::MaterializeInDestinationOp - auto newStoreOp = builder.create( - loc, newSrcExtractSlice.getResult(), newSubviewOpOfReCast.getResult()); - // Setting writable is necessary as dst is memref type - newStoreOp.setWritable(true); - - // 8. Erase origin operation - materializeVec[0]->erase(); - materializeVec[1]->erase(); - firstSubviewOpOfReCast->erase(); - firstSrcExtractSlice->erase(); - secondSubviewOpOfReCast->erase(); - secondSrcExtractSlice->erase(); - - return success(); -} - -} // namespace triton -} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp b/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp deleted file mode 100644 index 805e69025..000000000 --- a/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp +++ /dev/null @@ -1,752 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
-// -//===----------------------------------------------------------------------===// - -#include "../../include/Utils/Utils.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include -#include - -#define DEBUG_TYPE "TritonNPU-Utils" - -namespace mlir { - -namespace ConverterUtils { - -Value getTransposedValue(Value source, const Location loc, - ConversionPatternRewriter &rewriter, - llvm::ArrayRef order) { - auto sourceType = cast(source.getType()); - auto sourceRank = sourceType.getRank(); - - SmallVector perm(order); - SmallVector originalShape(sourceType.getShape()); - SmallVector transposedShape(sourceRank); - for (size_t i = 0; i < sourceRank; i++) { - transposedShape[i] = originalShape[perm[i]]; - } - - Value transposeInit = rewriter.create( - loc, transposedShape, sourceType.getElementType()); - - Value transpose = - rewriter.create(loc, source, transposeInit, perm) - .getResults()[0]; - - return transpose; -} - -SmallVector getNParallelLoopsAttrs(unsigned n) { - return SmallVector(n, utils::IteratorType::parallel); -} - -Value getScalarValue(Value operand, Location loc, - ConversionPatternRewriter &rewriter) { - SmallVector ops; - auto reconstructScalarValue = [&](Value src) { - for (auto op = ops.rbegin(); op != ops.rend(); ++op) { - src = mlir::TypeSwitch(*op) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return rewriter.create(loc, resType, src); - }) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return rewriter.create(loc, resType, src); - }) - .Default([](Operation *op) { - llvm_unreachable("unsupported op in generating "); - return nullptr; - }); - } - return src; - }; - - while (true) { - if (!dyn_cast(operand.getType())) { - return reconstructScalarValue(operand); - } else if (auto op = operand.getDefiningOp()) { - if (auto attr = dyn_cast(op.getValue())) { - if (!attr.isSplat()) { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load " - "produced by unsupported instruction"; - return nullptr; - } - auto elemValue = attr.getSplatValue(); - auto constOp = arith::ConstantOp::materialize( - rewriter, elemValue, attr.getElementType(), op.getLoc()); - return reconstructScalarValue(constOp.getResult()); - } - } else if (auto op = operand.getDefiningOp()) { - operand = op.getSrc(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load produced " - "by unsupported instruction"; - return nullptr; - } - } - return nullptr; -} - -memref::SubViewOp makeSubViewOp(Value 
src, - llvm::SmallVectorImpl &sizes, - const Location &loc, - ConversionPatternRewriter &rewriter) { - auto srcType = dyn_cast(src.getType()); - SmallVector offsets(srcType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector strides(srcType.getRank(), - rewriter.getIndexAttr(1)); - auto dstType = - memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - return rewriter.create(loc, dyn_cast(dstType), - src, offsets, sizes, strides); -} - -void getShapeInfo(Value val, llvm::SmallVectorImpl &shapes, - ConversionPatternRewriter &rewriter) { - if (isa(val)) { - auto blockArg = dyn_cast(val); - auto blockOp = blockArg.getOwner()->getParentOp(); - if (isa(blockOp)) { - auto forOp = dyn_cast(blockOp); - auto operand = forOp.getTiedLoopInit(blockArg)->get(); - getShapeInfo(operand, shapes, rewriter); - } else { - emitError(val.getLoc()) - << "getShapeInfo() only support ReinterpretCastOp " - "and scf.for's block argument, but got : " - << val << "\n"; - } - return; - } - - if (isa(val.getType())) { - val = rewriter.getRemappedValue(val); - } - - if (!isa(val.getDefiningOp())) { - emitError(val.getLoc()) << "getShapeInfo() only support ReinterpretCastOp " - "and scf.for's block argument, but got : " - << val << "\n"; - return; - } - auto castOp = dyn_cast(val.getDefiningOp()); - auto tensorPtrAttr = castOp->getAttr("tensor_ptr_attr"); - if (tensorPtrAttr) { - shapes = castOp.getConstifiedMixedSizes(); - } else { - getShapeInfo(castOp.getSource(), shapes, rewriter); - } - return; -} - -SmallVector -getBoundarySizes(llvm::ArrayRef boundaryCheck, Value ptr, - Value adaptorPtr, const Location &loc, - ConversionPatternRewriter &rewriter) { - SmallVector parTensorShapes; - getShapeInfo(adaptorPtr, parTensorShapes, rewriter); - auto extractOp = - rewriter.create(loc, adaptorPtr); - - OpFoldResult baseOffset = extractOp.getConstifiedMixedOffset(); - SmallVector strides = extractOp.getConstifiedMixedStrides(); - - SmallVector boundarySizes = extractOp.getConstifiedMixedSizes(); - auto dims = boundarySizes.size(); - OpFoldResult currentStride = rewriter.getIndexAttr(1); - for (int i = dims - 1; i >= 0; i--) { - auto offset = divOpFoldResult(baseOffset, currentStride, loc, rewriter); - offset = remOpFoldResult(offset, parTensorShapes[i], loc, rewriter); - if (llvm::find(boundaryCheck, i) != boundaryCheck.end()) { - OpFoldResult subOfr = - subOpFoldResult(parTensorShapes[i], offset, loc, rewriter); - boundarySizes[i] = - minOpFoldResult(boundarySizes[i], subOfr, loc, rewriter); - } - currentStride = - mulOpFoldResult(currentStride, parTensorShapes[i], loc, rewriter); - } - return boundarySizes; -} - -SmallVector getBroadcastDims(RankedTensorType src, - RankedTensorType dst) { - SmallVector broadcastDims; - auto srcShape = src.getShape(); - auto dstShape = dst.getShape(); - - for (size_t i = 0; i < srcShape.size(); ++i) { - if (dstShape[i] != srcShape[i]) { - assert(srcShape[i] == 1 && - "Size of source broadcast dimension must be 1"); - broadcastDims.push_back(i); - } - } - assert(!broadcastDims.empty() && "Cannot identify broadcast dimension"); - return broadcastDims; -} - -// Dimensions of collapesd tensor is all unbroadcast dims -SmallVector getUnbroadcastDims(RankedTensorType src, - RankedTensorType dst) { - SmallVector unbroadcastDims; - auto srcShape = src.getShape(); - auto dstShape = dst.getShape(); - - for (size_t i = 0; i < srcShape.size(); ++i) { - if (dstShape[i] == srcShape[i]) { - unbroadcastDims.emplace_back(srcShape[i]); - } - } - return unbroadcastDims; -} - -} // namespace 
ConverterUtils - -namespace triton { - -mlir::Operation * -findFirstMatchingOperandDef(mlir::Operation *rootOp, - const std::function &condFn) { - LLVM_DEBUG(llvm::dbgs() << "[findFirstMatchingOperandDef] Current op: " - << *rootOp << "\n"); - mlir::Value lhs = nullptr; - mlir::Value rhs = nullptr; - if (auto op = dyn_cast(rootOp)) { - lhs = op.getPtr(); - rhs = op.getOffset(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getSrc(); - } else if (auto op = dyn_cast(rootOp)) { - } else { - rootOp->emitRemark("Backtracing encounters unsupported Operation"); - return nullptr; - } - // Backtrace operands - if (!lhs) { - return nullptr; - } - auto lhsDef = lhs.getDefiningOp(); - mlir::Operation *targetOp; - if (lhsDef) { - if (condFn(lhsDef)) { - targetOp = lhsDef; - } else { - targetOp = findFirstMatchingOperandDef(lhsDef, condFn); - } - if (targetOp) { - return targetOp; - } - } - if (!rhs) { - return nullptr; - } - auto rhsDef = rhs.getDefiningOp(); - if (rhsDef) { - if (condFn(rhsDef)) { - targetOp = rhsDef; - } else { - targetOp = findFirstMatchingOperandDef(rhsDef, condFn); - } - if (targetOp) { - return targetOp; - } - } - return nullptr; -} - -void traverseBackwardUpdateOperandChainIf( - Operation *op, std::function conditionFn, - std::function actionFn, - OpBuilder &builder) { - - if (!op) - return; - - if (conditionFn(op)) { - actionFn(builder, op); - } - - for (Value operand : op->getOperands()) { - // TODO: handle BlockArgument - if (Operation *defOp = operand.getDefiningOp()) { - traverseBackwardUpdateOperandChainIf(defOp, conditionFn, actionFn, - builder); - } - } -} - -// Note: rootOp will also be processed. -void traverseBackwardUpdateOperandChainIf( - Operation *rootOp, std::function conditionFn, - std::function actionFn) { - - OpBuilder builder(rootOp->getContext()); - - traverseBackwardUpdateOperandChainIf(rootOp, conditionFn, actionFn, builder); -} - -void traverseForwardUpdateUserChainIf( - Operation *op, std::function conditionFn, - std::function stopFn, - std::function actionFn, OpBuilder &builder, - llvm::SmallPtrSet &stopOps) { - - if (!op) { - return; - } - - if (stopFn(op)) { - stopOps.insert(op); - return; - } - - if (conditionFn(op)) { - actionFn(builder, op); - } - - for (auto res : op->getResults()) { - for (auto userOp : res.getUsers()) { - traverseForwardUpdateUserChainIf(userOp, conditionFn, stopFn, actionFn, - builder, stopOps); - } - } -} - -// Note: rootOp will also be processed. 
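// A short usage sketch (not part of this patch): this mirrors how
// runUseAnalysis drives the backward traversal above for indirect loads,
// tagging every MetaUse def reachable from `stopOp` as MixUse.
// `tagBackwardMixUse` is a hypothetical wrapper name.
void tagBackwardMixUse(Operation *stopOp) {
  traverseBackwardUpdateOperandChainIf(
      stopOp,
      // condition: only ops already tagged MetaUse, excluding the root
      [stopOp](Operation *curOp) { return isMetaUse(curOp) && curOp != stopOp; },
      // action: re-tag as MixUse so the op is replaced rather than erased
      [](OpBuilder &b, Operation *op) {
        op->setAttr("MixUse", UnitAttr::get(op->getContext()));
      });
}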
-void traverseForwardUpdateUserChainIf( - Operation *rootOp, std::function conditionFn, - std::function stopFn, - std::function actionFn, - llvm::SmallPtrSet &stopOps) { - - OpBuilder builder(rootOp->getContext()); - - traverseForwardUpdateUserChainIf(rootOp, conditionFn, stopFn, actionFn, - builder, stopOps); -} - -bool isMetaUse(Operation *op) { return op->hasAttr("MetaUse"); } - -bool isMixUse(Operation *op) { return op->hasAttr("MixUse"); } - -IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op) { - auto ty = IndirectLoadInterfaceOpType::Undefined; - if (isMetaUse(op)) { - if (isa(op)) { - ty = IndirectLoadInterfaceOpType::Load; - } else if (isa(op)) { - ty = IndirectLoadInterfaceOpType::Calc; - } - } - return ty; -} - -bool opIsIndirectLoad(Operation *op) { - auto opType = getIndirectLoadInterfaceOpType(op); - return opType == IndirectLoadInterfaceOpType::Load; -} - -bool opIsIndirectCalc(Operation *op) { - auto opType = getIndirectLoadInterfaceOpType(op); - return opType == IndirectLoadInterfaceOpType::Calc; -} - -scf::ForOp createNestedLoops( - OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims, - ValueRange LBs, ValueRange UBs, ValueRange steps, SmallVector &ivs, - ValueRange initArgs, - function_ref &, ValueRange)> - bodyBuilder) { - - if (currentDim >= totalDims) { - bodyBuilder(builder, loc, ivs, initArgs); - return nullptr; - } - - auto loop = builder.create( - loc, LBs[currentDim], UBs[currentDim], steps[currentDim], initArgs, - [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, - ValueRange iterArgs) { - ivs.push_back(iv); - auto innerLoop = createNestedLoops(nestedBuilder, nestedLoc, - currentDim + 1, totalDims, LBs, UBs, - steps, ivs, iterArgs, bodyBuilder); - if (innerLoop) { - nestedBuilder.create(loc, innerLoop.getResults()); - } - }); - - return loop; -} - -ModuleOp getModuleOpFromOperation(Operation *op) { - Operation *parent = op; - while (parent != nullptr && !isa(parent)) { - parent = parent->getParentOp(); // walk up to the parent op - } - return cast(parent); // throws an exception if none is found -} - -} // namespace triton - -std::optional makeIntAttr(const OpFoldResult &ofr) { - if (isa(ofr) && isa(ofr.get())) - return dyn_cast(ofr.get()).getInt(); - return std::nullopt; -} - -bool hasConstantZero(const OpFoldResult &ofr) { - auto intAttr = makeIntAttr(ofr); - if (intAttr.has_value()) - return !intAttr.value(); - - auto val = dyn_cast(ofr); - assert(val && "Provided ofr must be castable to Value"); - - auto ConstOp = val.getDefiningOp(); - if (!ConstOp) - return false; - - intAttr = makeIntAttr(ConstOp.getValue()); - return intAttr.has_value() && !intAttr.value(); -} - -Value opFoldResultToIndex(const OpFoldResult &ofr, const Location &loc, - OpBuilder &b) { - if (auto val = dyn_cast(ofr)) { - assert(val.getType().isIndex() && "Provided ofr should be of Index type"); - return val; - } - - auto intAttr = makeIntAttr(ofr); - if (intAttr.has_value()) { - return b.create(loc, b.getIndexAttr(intAttr.value())); - } - llvm_unreachable("Unexpected OpFoldResult state"); - return nullptr; -} - -SmallVector opFoldResultToIndex(ArrayRef ofrs, - const Location &loc, OpBuilder &b) { - return llvm::map_to_vector<4>(ofrs, [&](OpFoldResult ofr) -> Value { - return opFoldResultToIndex(ofr, loc, b); - }); -} - -Value createConstIntOp(const Location &loc, OpBuilder &b, int64_t value) { - return b.create(loc, b.getIndexAttr(value)).getResult(); -} - -// TODO: implement these functions below -OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const 
Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsInt = makeIntAttr(rhs); - - if (!lhsInt && rhsInt && rhsInt.value() == 0) { - return lhs; - } - if (!rhsInt && lhsInt && lhsInt.value() == 0) { - return rhs; - } - - if (lhsInt && rhsInt) { - return b.getIndexAttr(lhsInt.value() + rhsInt.value()); - } - - auto lhsValue = dyn_cast(lhs); - if (lhsInt) { - lhsValue = createConstIntOp(loc, b, lhsInt.value()); - } else { - assert(isa(lhsValue.getType())); - } - - auto rhsValue = dyn_cast(rhs); - if (rhsInt) { - rhsValue = createConstIntOp(loc, b, rhsInt.value()); - } else { - assert(isa(rhsValue.getType())); - } - - return b.create(loc, lhsValue, rhsValue).getResult(); -} - -OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsInt = makeIntAttr(rhs); - - if (!lhsInt && rhsInt && rhsInt.value() == 0) { - return lhs; - } - - if (lhsInt && rhsInt) { - return b.getIndexAttr(lhsInt.value() - rhsInt.value()); - } - - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) { - lhsValue = createConstIntOp(loc, b, lhsInt.value()); - } - if (rhsInt) { - rhsValue = createConstIntOp(loc, b, rhsInt.value()); - } - return b.create(loc, lhsValue, rhsValue).getResult(); -} - -OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsInt = makeIntAttr(rhs); - - if (lhsInt) { - if (lhsInt.value() == 0) { - return lhs; - } - if (lhsInt.value() == 1) { - return rhs; - } - } - if (rhsInt) { - if (rhsInt.value() == 0) { - return rhs; - } - if (rhsInt.value() == 1) { - return lhs; - } - } - - if (lhsInt && rhsInt) { - return b.getIndexAttr(lhsInt.value() * rhsInt.value()); - } - - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) { - lhsValue = createConstIntOp(loc, b, lhsInt.value()); - } - if (rhsInt) { - rhsValue = createConstIntOp(loc, b, rhsInt.value()); - } - return b.create(loc, lhsValue, rhsValue).getResult(); -} - -OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const Value &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsConstFlag = false; - - auto rhsConstInt = std::numeric_limits::max(); - auto rhsOp = rhs.getDefiningOp(); - if (rhsOp) { - rhsConstFlag = true; - rhsConstInt = dyn_cast(rhsOp.getValue()).getInt(); - } - - if (lhsInt && rhsConstFlag) { - return b.getIndexAttr(lhsInt.value() * rhsConstInt); - } - - if (lhsInt) { - if (lhsInt.value() == 0) { - return lhs; - } - if (lhsInt.value() == 1) { - return rhs; - } - } - if (rhsConstFlag) { - if (rhsConstInt == 0) { - return rhsOp.getResult(); - } - if (rhsConstInt == 1) { - return lhs; - } - } - - if (lhsInt && !rhsConstFlag) { - auto lhsValue = createConstIntOp(loc, b, lhsInt.value()); - return b.create(loc, lhsValue, rhs).getResult(); - } - assert(!lhsInt); - return b.create(loc, lhs.get(), rhs).getResult(); -} - -OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsInt = makeIntAttr(rhs); - if (lhsInt) { - if (lhsInt.value() == 0) { - return lhs; - } - } - if (rhsInt) { - if (rhsInt.value() == 0) { - emitError(loc) << "cannot div 0!"; - return OpFoldResult(); - } - if (rhsInt.value() == 1) { - return lhs; - } - } - - if (lhsInt && rhsInt) { - return b.getIndexAttr(lhsInt.value() / rhsInt.value()); - } - - auto lhsValue = 
dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) { - lhsValue = createConstIntOp(loc, b, lhsInt.value()); - } - if (rhsInt) { - rhsValue = createConstIntOp(loc, b, rhsInt.value()); - } - return b.create(loc, lhsValue, rhsValue).getResult(); -} - -OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsInt = makeIntAttr(rhs); - if (lhsInt && lhsInt.value() == 0) { - return lhs; - } - if (lhsInt && rhsInt) { - return b.getIndexAttr(lhsInt.value() % rhsInt.value()); - } - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) { - lhsValue = createConstIntOp(loc, b, lhsInt.value()); - } - if (rhsInt) { - rhsValue = createConstIntOp(loc, b, rhsInt.value()); - } - return b.create(loc, lhsValue, rhsValue).getResult(); -} - -OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsInt = makeIntAttr(rhs); - if (lhsInt && rhsInt) { - return b.getIndexAttr(std::min(lhsInt.value(), rhsInt.value())); - } - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) { - lhsValue = createConstIntOp(loc, b, lhsInt.value()); - } - if (rhsInt) { - rhsValue = createConstIntOp(loc, b, rhsInt.value()); - } - return b.create(loc, lhsValue, rhsValue).getResult(); -} - -OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = makeIntAttr(lhs); - auto rhsInt = makeIntAttr(rhs); - if (lhsInt && rhsInt) { - return b.getIndexAttr(std::max(lhsInt.value(), rhsInt.value())); - } - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) { - lhsValue = createConstIntOp(loc, b, lhsInt.value()); - } - if (rhsInt) { - rhsValue = createConstIntOp(loc, b, rhsInt.value()); - } - return b.create(loc, lhsValue, rhsValue).getResult(); -} - -LogicalResult -addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, - linalg::ReduceOp reduceOp) { - // To verify whether the operation of the reduceOp is ReduceWithIndex - // TODO: maybe a better way of judging? 
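// In other words: the check below inspects the first op of the reduce
// combiner; an arith.cmpf with OGT/OLT or an arith.cmpi with sgt/slt
// classifies the reduction as max_with_index / min_with_index, recorded in
// the "reduce_mode" string attribute, presumably consumed by later lowering.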
- auto ctx = reduceOp.getContext(); - Block &body = reduceOp.getCombiner().front(); - auto yieldOp = dyn_cast(body.getTerminator()); - - auto yieldValue = yieldOp.getValues(); - if (yieldValue.size() == 0) { - return failure(); - } - - auto opIter = reduceOp.getBody()->without_terminator().begin(); - auto cmpMaskOp = dyn_cast(*opIter); - const StringRef reduceRef = "reduce_mode"; - if (cmpMaskOp) { - if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OGT) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); - } else if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OLT) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); - } - } - - auto cmpMaskIOp = dyn_cast(*opIter); - if (cmpMaskIOp) { - if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::sgt) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); - } else if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::slt) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); - } - } - - return success(); -} - -} // namespace mlir diff --git a/third_party/ascend/triton-adapter/tools/CMakeLists.txt b/third_party/ascend/triton-adapter/tools/CMakeLists.txt deleted file mode 100644 index 628169551..000000000 --- a/third_party/ascend/triton-adapter/tools/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(triton-adapter-opt) diff --git a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt deleted file mode 100644 index 37fea14db..000000000 --- a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) - -add_llvm_executable(triton-adapter-opt triton-adapter-opt.cpp PARTIAL_SOURCES_INTENDED) - -# TODO: what's this? 
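# llvm_update_compile_flags is LLVM's CMake helper that applies LLVM's
# standard compile options (e.g. RTTI/exception settings) to the target,
# keeping it consistent with the LLVM/MLIR libraries linked below.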
-llvm_update_compile_flags(triton-adapter-opt) -target_link_libraries(triton-adapter-opt PRIVATE TritonToLinalg - TritonTransforms - ${dialect_libs} - ${conversion_libs} - TritonGPUTransforms - MLIROptLib - MLIRPass - MLIRTransforms -) - -mlir_check_all_link_libraries(triton-adapter-opt) diff --git a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp deleted file mode 100644 index ba9c185e2..000000000 --- a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "../../include/TritonToLinalg/Passes.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -int main(int argc, char **argv) { - mlir::DialectRegistry registry; - mlir::triton::registerTritonToLinalgPass(); - - registry.insert< - mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, - mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, - mlir::linalg::LinalgDialect, mlir::func::FuncDialect, - mlir::tensor::TensorDialect, mlir::memref::MemRefDialect, - mlir::bufferization::BufferizationDialect, mlir::gpu::GPUDialect>(); - - return mlir::asMainReturnCode( - mlir::MlirOptMain(argc, argv, "Triton-Adapter test driver\n", registry)); -} diff --git a/third_party/ascend/triton-adapter/triton_adapter.cc b/third_party/ascend/triton-adapter/triton_adapter.cc deleted file mode 100644 index 7fa5e82a5..000000000 --- a/third_party/ascend/triton-adapter/triton_adapter.cc +++ /dev/null @@ -1,6 +0,0 @@ -#include - -namespace py = pybind11; - -// compilation goes to triton-adapter-opt, do nothing here -void init_triton_triton_adapter(py::module &&m) {} diff --git a/third_party/ascend/triton_ascend.cpp b/third_party/ascend/triton_ascend.cpp deleted file mode 100644 index bb3cc70bd..000000000 --- a/third_party/ascend/triton_ascend.cpp +++ /dev/null @@ -1,11 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- */ -#define PY_SSIZE_T_CLEAN -#include -namespace py = pybind11; - -// register huawei passes to triton -void init_triton_ascend(py::module &&m) { - // currently no extra modules needed to plug-in libtriton.so -} diff --git a/third_party/ascend/triton_patch/include/CMakeLists.txt b/third_party/ascend/triton_patch/include/CMakeLists.txt deleted file mode 100644 index 109c292fe..000000000 --- a/third_party/ascend/triton_patch/include/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(triton) diff --git a/third_party/ascend/triton_patch/include/triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/CMakeLists.txt deleted file mode 100644 index 0ca0f41c5..000000000 --- a/third_party/ascend/triton_patch/include/triton/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Dialect) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt deleted file mode 100644 index 5e601271e..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt deleted file mode 100644 index f33061b2d..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt deleted file mode 100644 index 9984e2e01..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ /dev/null @@ -1,34 +0,0 @@ -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -# file(RELATIVE_PATH patch_rel_dir "${CMAKE_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}") -# string(REPLACE "triton_patch" "triton" triton_rel_dir "${patch_rel_dir}") -# set(triton_abs_dir "${CMAKE_SOURCE_DIR}/${triton_rel_dir}") -# message(STATUS "triton_abs_dir: ${triton_abs_dir}") -# message(${triton_abs_dir}) -set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") -message(${triton_abs_dir}) -set(LLVM_TARGET_DEFINITIONS TritonOps.td) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) -# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonDialect.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) -# add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypes.td) -mlir_tablegen(Types.h.inc -gen-typedef-decls) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonInterfaces.td) -mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) -mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypeInterfaces.td) -mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) - -add_public_tablegen_target(Patched_TritonTableGen) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td 
b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td deleted file mode 100644 index b59bc7c8f..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef TRITON_ATTR_DEFS -#define TRITON_ATTR_DEFS - -include "mlir/IR/EnumAttr.td" - -// Attributes for LoadOp and StoreOp -def TT_CacheModifierAttr : I32EnumAttr< - "CacheModifier", "", - [ - I32EnumAttrCase<"NONE", 1, "none">, - I32EnumAttrCase<"CA", 2, "ca">, - I32EnumAttrCase<"CG", 3, "cg">, - I32EnumAttrCase<"WB", 4, "wb">, - I32EnumAttrCase<"CS", 5, "cs">, - I32EnumAttrCase<"WT", 6, "wt">, - I32EnumAttrCase<"CV", 7, "cv">, - ]> { - let cppNamespace = "::mlir::triton"; -} - -def TT_MemSemanticAttr : I32EnumAttr< - "MemSemantic", "", - [ - I32EnumAttrCase<"RELAXED", 1, "relaxed">, - I32EnumAttrCase<"ACQUIRE", 2, "acquire">, - I32EnumAttrCase<"RELEASE", 3, "release">, - I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, - ]> { - let cppNamespace = "::mlir::triton"; -} - -def TT_EvictionPolicyAttr : I32EnumAttr< - "EvictionPolicy", "", - [ - I32EnumAttrCase<"NORMAL", 1, "evict_normal">, - I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, - I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> - ]> { - let cppNamespace = "::mlir::triton"; -} - -def TT_PaddingOptionAttr : I32EnumAttr< - "PaddingOption", "", - [ - I32EnumAttrCase<"PAD_ZERO", 1, "zero">, - // We can not set the string value to "NAN" because it is a keyword in C++ - I32EnumAttrCase<"PAD_NAN", 2, "nan"> - ]> { - let cppNamespace = "::mlir::triton"; -} - -// atomic -def TT_AtomicRMWAttr : I32EnumAttr< - "RMWOp", "", - [ - I32EnumAttrCase<"AND", 1, "and">, - I32EnumAttrCase<"OR", 2, "or">, - I32EnumAttrCase<"XOR", 3, "xor">, - I32EnumAttrCase<"ADD", 4, "add">, - I32EnumAttrCase<"FADD", 5, "fadd">, - I32EnumAttrCase<"MAX", 6, "max">, - I32EnumAttrCase<"MIN", 7, "min">, - I32EnumAttrCase<"UMAX", 8, "umax">, - I32EnumAttrCase<"UMIN", 9, "umin">, - I32EnumAttrCase<"XCHG", 10, "exch"> - ]> { - let cppNamespace = "::mlir::triton"; -} - -def TT_MemSyncScopeAttr : I32EnumAttr< - "MemSyncScope", "", - [ - I32EnumAttrCase<"GPU", 1, "gpu">, - I32EnumAttrCase<"CTA", 2, "cta">, - I32EnumAttrCase<"SYSTEM", 3, "sys">, - ]> { - let cppNamespace = "::mlir::triton"; -} - -// Program ID dimensions. -def TT_ProgramDim : I32EnumAttr< - "ProgramIDDim", "", - [ - I32EnumAttrCase<"X", 0, "x">, - I32EnumAttrCase<"Y", 1, "y">, - I32EnumAttrCase<"Z", 2, "z">, - ]> { - let cppNamespace = "::mlir::triton"; -} - -// Rounding mode. -def TT_RoundingModeAttr : I32EnumAttr< - "RoundingMode", "", - [ - I32EnumAttrCase<"RTZ", 0, "rtz">, - I32EnumAttrCase<"RTNE", 1, "rtne">, - ]> { - let cppNamespace = "::mlir::triton"; -} - -// PropagateNan. -def TT_PropagateNanAttr : I32EnumAttr< - "PropagateNan", "", - [ - I32EnumAttrCase<"NONE", 0, "none">, - I32EnumAttrCase<"ALL", 0xFFFF, "all">, - ]> { - let cppNamespace = "::mlir::triton"; -} - -// InputPrecision -def TT_InputPrecisionAttr : I32EnumAttr< - "InputPrecision", "", - [ - I32EnumAttrCase<"TF32", 0, "tf32">, - I32EnumAttrCase<"TF32x3", 1, "tf32x3">, - I32EnumAttrCase<"IEEE", 2, "ieee">, - I32EnumAttrCase<"HF32", 3, "hf32">, - ]>{ - let cppNamespace = "::mlir::triton"; -} - -// Type for F8F6F4 kind of floats. 
-def TT_F8F6F4TypeAttr : I32EnumAttr<
-  "F8F6F4Type", "",
-  [
-    I32EnumAttrCase<"E4M3", 0, "e4m3">,
-    I32EnumAttrCase<"E5M2", 1, "e5m2">,
-    I32EnumAttrCase<"E2M3", 2, "e2m3">,
-    I32EnumAttrCase<"E3M2", 3, "e3m2">,
-    I32EnumAttrCase<"E2M1", 4, "e2m1">
-
-  ]>{
-  let cppNamespace = "::mlir::triton";
-}
-
-#endif
diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td
deleted file mode 100644
index 9bca3da18..000000000
--- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td
+++ /dev/null
@@ -1,1286 +0,0 @@
-#ifndef TRITON_OPS
-#define TRITON_OPS
-
-include "triton/Dialect/Triton/IR/TritonDialect.td"
-include "triton/Dialect/Triton/IR/TritonTypes.td"
-include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
-include "triton/Dialect/Triton/IR/TritonInterfaces.td"
-include "mlir/IR/OpBase.td"
-include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
-include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
-include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
-include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
-include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
-include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
-include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
-include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
-include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
-include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
-include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
-include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
-
-
-//
-// Interfaces
-//
-def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
-
-//
-// Op Base
-//
-class TT_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait, VerifyTensorLayoutsTrait])> {
-}
-
-//
-// Cast Ops
-//
-// Use cast ops in arith:
-//   bitcast
-//   fptoui, fptosi, uitofp, sitofp,
-//   extf, truncf,
-//   extui, extsi, trunci
-def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
-                                         SameOperandsAndResultShape,
-                                         SameOperandsAndResultEncoding,
-                                         Pure,
-                                         /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
-    let summary = "Cast int64 to pointer";
-
-    let arguments = (ins TT_I64Like:$src);
-
-    let results = (outs TT_PtrLike:$result);
-
-    let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
-}
-
-def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
-                                         SameOperandsAndResultShape,
-                                         SameOperandsAndResultEncoding,
-                                         Pure,
-                                         /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
-    let summary = "Cast pointer to int64";
-
-    let arguments = (ins TT_PtrLike:$src);
-
-    let results = (outs TT_I64Like:$result);
-
-    let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
-}
-
-// arith.bitcast doesn't support pointers
-def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
-                                     SameOperandsAndResultShape,
-                                     SameOperandsAndResultEncoding,
-                                     Pure,
-                                     /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
-    let summary = "Cast between types of the same bitwidth";
-
-    let arguments = (ins TT_Type:$src);
-
-    let results = (outs TT_Type:$result);
-
-    let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
-
-    // TODO: Add verifier
-}
-
-def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
-                                     SameOperandsAndResultEncoding,
-                                     Pure,
-                                     /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
-    let summary = "Floating point casting for custom types";
-
-    let description = [{
-        Floating point casting for custom types (F8), and non-default rounding modes.
- - F8 <-> FP16, BF16, FP32, FP64 - }]; - - let arguments = ( - ins TT_FloatTensor:$src, - OptionalAttr:$rounding - ); - - let results = (outs TT_FloatTensor:$result); - - let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; - - let hasVerifier = 1; -} - -// -// Arithmetic Ops -// - -def TT_ClampFOp : TT_Op<"clampf", [Elementwise, - SameOperandsAndResultType, - Pure]> { - let summary = "Clamp operation for floating point types"; - - let description = [{ - Clamp operation for floating point types. - - The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. - }]; - - let arguments = ( - ins - TT_FloatLike:$x, - TT_FloatLike:$min, - TT_FloatLike:$max, - TT_PropagateNanAttr:$propagateNan - ); - - let results = (outs TT_FloatLike:$result); - - // List $propagateNan explicitly rather than relying on attr-dict to pick it - // up, because if it's inside attr-dict, its value will be printed as a - // number rather than as a meaningful string. - let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; -} - -// -// Math Ops -// - -def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, - SameOperandsAndResultType, - Pure]> { - let summary = "Precise sqrt for floating point types"; - - let description = [{ - Precise sqrt for floating point types. - }]; - - let arguments = (ins TT_FloatLike:$x); - - let results = (outs TT_FloatLike:$result); - - let assemblyFormat = "$x attr-dict `:` type($x)"; -} - -def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, - SameOperandsAndResultType, - Pure]> { - let summary = "Precise div for floating point types"; - - let description = [{ - Precise div for floating point types. - }]; - - let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); - - let results = (outs TT_FloatLike:$result); - - let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; -} - -def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, - SameOperandsAndResultType, - Pure]> { - let summary = "Most significant N bits of the 2N-bit product of two integers"; - - let description = [{ - Most significant N bits of the 2N-bit product of two integers. 
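For illustration, the math ops above print as follows under their declared assembly formats (a sketch; the SSA names, shapes, and enum choices are assumptions, not taken from this patch):

```mlir
%c = tt.clampf %x, %lo, %hi, propagateNan = none : tensor<32xf32>
%q = tt.precise_divf %a, %b : tensor<32xf32>
%h = tt.mulhiui %m, %n : tensor<32xi32>
```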
- }]; - - let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); - - let results = (outs TT_IntLike:$result); - - let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; -} - -// -// Pointer Arith Ops -// -def TT_AddPtrOp : TT_Op<"addptr", - [Pure, - Elementwise, - SameOperandsAndResultShape, - SameOperandsAndResultEncoding, - TypesMatchWith<"result type matches ptr type", - "result", "ptr", "$_self">]> { - let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); - - let results = (outs TT_PtrLike:$result); - - let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; -} - -def TT_AdvanceOp : TT_Op<"advance", - [Pure, - TypesMatchWith<"result type matches ptr type", - "result", "ptr", "$_self">]> { - let summary = "Advance a tensor pointer by offsets"; - - let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); - - let results = (outs TT_TensorPtr:$result); - - let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; - - let hasFolder = 1; -} - -// -// Load/Store Ops -// -def TT_LoadOp : TT_Op<"load", [ - SameLoadStoreOperandsAndResultShape, - SameLoadStoreOperandsAndResultEncoding, - AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, - TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", - "($_op.getOperands().size() <= 1) || std::equal_to<>()">, - TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", - "($_op.getOperands().size() <= 2) || std::equal_to<>()"> -]> { - let summary = "Load from a tensor of pointers or from a tensor pointer"; - - let arguments = ( - ins - AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, - Optional:$mask, - Optional:$other, - - DefaultValuedAttr{}">:$boundaryCheck, - OptionalAttr:$padding, - DefaultValuedAttr:$cache, - DefaultValuedAttr:$evict, - DefaultValuedAttr:$isVolatile - ); - - let results = (outs TT_Type:$result); - - let builders = [ - // A tensor of pointers or a pointer to a scalar - OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, - // A tensor pointer with boundary check and padding - OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, - "std::optional":$padding, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, - // A tensor of pointers or a pointer to a scalar with mask - OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, - // A tensor of pointers or a pointer to a scalar with mask and other - OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, - // A utility function to build the operation with all attributes - OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, - "ArrayRef":$boundaryCheck, - "std::optional":$padding, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile)> - ]; - - // Specify `cacheModifier` and `evictionPolicy` explicitly in the - // assemblyFormat instead of as part of attr-dict so that they get printed - // as strings rather than opaque integers. - // - // Note there's no comma between `other` and `cacheModifier` and between - // `cacheModifier` and `evictionPolicy`. This is due to an apparent - // limitation in the MLIR custom-format parser. 
In oilist, the initial - // keywords of each clause have to be unique, so they can't be `,`. - // - // Even if we gave up on order-independence and used vanilla optional - // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will - // not match the string ", bar = 0" because after the initial comma (first - // token of the first optional clause) we expect to see "foo". - let assemblyFormat = [{ - $ptr (`,` $mask^)? (`,` $other^)? - oilist( - `cacheModifier` `=` $cache | - `evictionPolicy` `=` $evict - ) - attr-dict `:` type($ptr) - }]; - - let hasCanonicalizer = 1; -} - -def TT_StoreOp : TT_Op<"store", [ - SameLoadStoreOperandsShape, - SameLoadStoreOperandsEncoding, - MemoryEffects<[MemWrite]>, - TypesMatchWith<"value type matches ptr type", "ptr", "value", - "getPointeeType($_self)">, - TypesMatchWith<"mask type matches ptr type", "ptr", "mask", - "getI1SameShape(getPointeeType($_self))", - "($_op.getOperands().size() <= 2) || std::equal_to<>()"> -]> { - let summary = "Store by a tensor of pointers or by a tensor pointer"; - - let arguments = ( - ins - AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, - TT_Type:$value, - Optional:$mask, - DefaultValuedAttr{}">:$boundaryCheck, - DefaultValuedAttr:$cache, - DefaultValuedAttr:$evict - ); - - let builders = [ - // A tensor of pointers or a pointer to a scalar - OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, - // A tensor of pointers or a pointer to a scalar with mask - OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict)>, - // A tensor pointer with boundary check - OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict)> - ]; - - // Specify cacheModifier and evictionPolicy explicitly, instead of leaving - // them in attr-dict, because this way their values get printed as strings, - // rather than as opaque integers. - // - // Note there are no commas between mask, cacheModifier, and evictionPolicy, - // due to limitations in MLIR's asm parser. - let assemblyFormat = [{ - $ptr `,` $value (`,` $mask^)? - oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) - attr-dict `:` type($ptr) - }]; - - let hasCanonicalizer = 1; -} - -// -// Atomic Ops -// -def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ - SameOperandsAndResultShape, - SameOperandsAndResultEncoding, - MemoryEffects<[MemRead]>, - MemoryEffects<[MemWrite]>, - TypesMatchWith<"ptr type matches value type", "val", "ptr", - "getPointerTypeSameShape($_self)">, - TypesMatchWith<"mask type matches value type", - "val", "mask", "getI1SameShape($_self)", - "($_op.getOperands().size() <= 2) || std::equal_to<>()"> -]> { - let summary = "atomic rmw"; - - let description = [{ - load data at $ptr, do $rmw_op with $val, and store result to $ptr. - - return old value at $ptr - }]; - - let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, - TT_Type:$val, Optional:$mask, - TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); - - let results = (outs TT_Type:$result); - - // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on - // attr-dict so they're printed as strings rather than opaque integers. - let assemblyFormat = [{ - $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? 
attr-dict `:` - functional-type(operands, $result) - }]; -} - -def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, - MemoryEffects<[MemWrite]>, - SameOperandsAndResultShape, - SameOperandsAndResultEncoding]> { - let summary = "atomic cas"; - - let description = [{ - compare $cmp with data $old at location $ptr, - - if $old == $cmp, store $val to $ptr, - - else store $old to $ptr, - - return $old - }]; - - let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, - TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); - - let results = (outs TT_Type:$result); - - // Explicitly list $sem and $scope rather than relying on attr-dict so - // they're printed as strings rather than opaque integers. - let assemblyFormat = [{ - $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` - functional-type(operands, $result) - }]; -} - -// -// Shape Manipulation Ops -// -def TT_SplatOp : TT_Op<"splat", [Pure, - SameOperandsAndResultElementType, - SameOperandsAndResultEncoding]> { - let summary = "splat"; - - let arguments = (ins TT_Type:$src); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; - - let hasFolder = 1; -} - -def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, - DeclareOpInterfaceMethods, - SameOperandsAndResultElementType]> { - let summary = "expand_dims"; - - let arguments = (ins TT_Tensor:$src, I32Attr:$axis); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; - - let hasCanonicalizeMethod = 1; - let hasFolder = 1; -} - -def TT_ReshapeOp : TT_Op<"reshape", [Pure, - SameOperandsAndResultElementType]> { - let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; - let description = [{ - reinterpret a tensor to a different shape. - - If allow_reorder is set the compiler is free to change the order of - elements to generate more efficient code. - - If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. - The compiler is still free to change it for better performance. - }]; - let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); - let results = (outs TT_Tensor:$result); - let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; - let hasCanonicalizeMethod = 1; - let hasFolder = 1; - let hasVerifier = 1; -} - -def TT_BroadcastOp : TT_Op<"broadcast", [Pure, - SameOperandsAndResultElementType, - SameOperandsAndResultEncoding]> { - let summary = "broadcast a tensor"; - - let description = [{ - For a given tensor, broadcast changes one or more dimensions with size 1 - to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot - change the size of a non-1 dimension. 
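For illustration, a splat feeding a broadcast might print as follows (a sketch; the SSA names are assumptions, and the shapes reuse the description's own example):

```mlir
%s = tt.splat %f : f32 -> tensor<1x32x1xf32>
%b = tt.broadcast %s : tensor<1x32x1xf32> -> tensor<2x32x4xf32>
```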
- }]; - - let arguments = (ins TT_Tensor:$src); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; - - let hasCanonicalizeMethod = 1; - let hasFolder = 1; - let hasVerifier = 1; -} - -// cat is not `pure` because it may reorder elements -def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, - SameTypeOperands, - SameOperandsAndResultElementType]> { - let summary = "concatenate 2 tensors"; - - let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; -} - -def TT_JoinOp : TT_Op<"join", [ - NoMemoryEffect, SameTypeOperands, - DeclareOpInterfaceMethods, -]> { - let summary = "join two tensors along a new, minor dimension"; - let description = [{ - For example, if the two input tensors are 4x8xf32, returns a tensor of - shape 4x8x2xf32. - - Because Triton tensors always have a power-of-two number of elements, - the two input tensors must have the same shape. - }]; - - let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); - let results = (outs TT_Tensor:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; -} - -def TT_SplitOp : TT_Op<"split", [ - NoMemoryEffect, - DeclareOpInterfaceMethods, - TypesMatchWith<"outLHS and outRHS types match", - "outLHS", "outRHS", "$_self">, -]> { - let summary = "splits a tensor into two, along its last dimension"; - let description = [{ - The input must be a tensor whose last dimension has size 2. Returns two - tensors, src[..., 0] and src[..., 1]. - - For example, if the input shape is 4x8x2xf32, returns two tensors of - shape 4x8xf32. - }]; - - let arguments = (ins TT_Tensor:$src); - let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; -} - -def TT_TransOp : TT_Op<"trans", [Pure, - DeclareOpInterfaceMethods, - SameOperandsAndResultElementType]> { - - let summary = "rearrange the dimensions of a tensor"; - let description = [{ - For example, given a tensor x with shape [1,2,4], transpose(x) with - order=[2,0,1] rearranges the tensor to have shape [4,1,2]. - - Although this op is called "trans", it implements both tl.trans() and - tl.permute(). ("permute" might be a better name, but it's called "trans" - because originally it only supported 2D tensors.) - - ## Implementation note on encodings: - - In the TritonGPU dialect (and probably others), an encoding is chosen for - this op's output so it's a nop from the perspective of code generation. - - For example, suppose tensor x has an encoding such that GPU thread [i,j,k] - has a register containing element [i,j,k] of the tensor. Now we transpose - x with order [2,1,0], i.e. we reverse the order of its dimensions. In - TritonGPU, we will choose a layout for the output of the transpose so that - GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the - same element it had before! All we've done is "rename" the element that - thread [i,j,k] has. - - The "real" transpose -- i.e. moving data between GPU threads -- occurs in - convertLayout ops that appear before and/or after the operation. - - We do this so that you can chain multiple data-movement ops (e.g. - transpose+reshape+concat) without going to shared memory after each one. 
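Following the description's own numbers, such a transpose would print roughly as (a sketch; the SSA names are assumptions):

```mlir
%t = tt.trans %x {order = array<i32: 2, 0, 1>} : tensor<1x2x4xf32> -> tensor<4x1x2xf32>
```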
- }]; - - let arguments = ( - ins TT_TensorOrMemDesc:$src, - DenseI32ArrayAttr:$order - ); - - let results = (outs TT_TensorOrMemDesc:$result); - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; - - let hasFolder = 1; - let hasVerifier = 1; -} - -// -// SPMD Ops -// -def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { - let arguments = (ins TT_ProgramDim:$axis); - - let results = (outs I32:$result); - - let assemblyFormat = "$axis attr-dict `:` type($result)"; - - let builders = [ - OpBuilder<(ins "int":$axis), [{ - build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); - }]> - ]; - - let extraClassDeclaration = [{ - int32_t getAxisAsInt() { - return static_cast(getAxis()); - } - }]; -} - -def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { - let arguments = (ins TT_ProgramDim:$axis); - - let results = (outs I32:$result); - - let assemblyFormat = "$axis attr-dict `:` type($result)"; - let builders = [ - OpBuilder<(ins "int":$axis), [{ - build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); - }]> - ]; - - let extraClassDeclaration = [{ - int32_t getAxisAsInt() { - return static_cast(getAxis()); - } - }]; -} - -// -// Dot Op -// -def TT_DotOp : TT_Op<"dot", [Pure, - DeclareOpInterfaceMethods, - DotLike, - TypesMatchWith<"result's type matches accumulator's type", - "d", "c", "$_self">]> { - let summary = "dot"; - - let description = [{ - $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC - when the inputs are f32. It can be one of: tf32, tf32x3, ieee. - tf32: use TC with tf32 ops. - tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp - ieee: don't use TC, implement dot in software. - If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. - }]; - - let arguments = ( - ins - TT_FpIntTensor:$a, - TT_FpIntTensor:$b, - TT_FpIntTensor:$c, - DefaultValuedAttr:$inputPrecision, - DefaultValuedAttr:$maxNumImpreciseAcc - ); - - let results = (outs TT_FpIntTensor:$d); - - // attr-dict prints enums as integers. To get inputPrecision printed as a - // string, we need to specify it explicitly. - let assemblyFormat = [{ - $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` - type($a) `*` type($b) `->` type($d) - }]; - let hasVerifier = 1; -} - - -// -// DotScaled Op -// -def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, - DotLike, - TypesMatchWith<"result's type matches accumulator's type", - "d", "c", "$_self">]> { - let summary = "dot_scaled"; - - let description = [{ - $d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c. - Where scale(x, s) is a function that applies the scale per block following microscaling spec. - }]; - - let arguments = ( - ins - // inputs are integer types as they are packed types and we currently - // don't have a representation for those. - TT_IntTensor:$lhs, - TT_IntTensor:$rhs, - TT_FloatTensor:$c, - TT_IntTensor:$lhs_scale, - Optional:$rhs_scale, - TT_F8F6F4TypeAttr:$lhs_type, - TT_F8F6F4TypeAttr:$rhs_type - ); - - let results = (outs TT_FloatTensor:$d); - - // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file - let assemblyFormat = [{ - $lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? 
$c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict - `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d) - }]; -} - -// -// Reduce Op -// -def TT_ReduceOp: TT_Op<"reduce", - [Pure, - SameOperandsShape, - SameOperandsEncoding, - SingleBlock, - DeclareOpInterfaceMethods]> { - let summary = "Reduction using generic combination algorithm"; - let arguments = (ins Variadic:$srcs, I32Attr:$axis); - let results = (outs Variadic:$result); - let regions = (region SizedRegion<1>:$combineOp); - let builders = [ - OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>, - ]; - let hasVerifier = 1; - let hasRegionVerifier = 1; - let extraClassDeclaration = [{ - llvm::SmallVector getInputTypes(); - llvm::SmallVector getElementTypes(); - unsigned getNumOperands(); - }]; -} - -def TT_ReduceReturnOp: TT_Op<"reduce.return", - [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { - let summary = "terminator for reduce operator"; - let arguments = (ins Variadic:$result); - let assemblyFormat = "$result attr-dict `:` type($result)"; -} - -// -// Scan Op -// -def TT_ScanOp: TT_Op<"scan", - [Pure, - SameOperandsAndResultEncoding, - SameOperandsAndResultShape, - SingleBlock, - DeclareOpInterfaceMethods]> { - let summary = "Associative scan using generic combination algorithm"; - let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); - let results = (outs Variadic:$result); - let regions = (region SizedRegion<1>:$combineOp); - let builders = [ - OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, - ]; - let hasVerifier = 1; - let hasRegionVerifier = 1; - let extraClassDeclaration = [{ - llvm::SmallVector getInputTypes(); - llvm::SmallVector getElementTypes(); - unsigned getNumOperands(); - }]; -} - -def TT_ScanReturnOp: TT_Op<"scan.return", - [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { - let summary = "terminator for scan operator"; - let arguments = (ins Variadic:$result); - let assemblyFormat = "$result attr-dict `:` type($result)"; -} - - -// -// External Elementwise op -// -def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, - SameOperandsAndResultEncoding, - SameVariadicOperandSize, - DeclareOpInterfaceMethods, - ConditionallySpeculatable]> { - - let description = [{ - call an external function $symbol implemented in $libpath/$libname with $args - return $libpath/$libname:$symbol($args...) - }]; - - let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); - - let results = (outs TT_Type:$result); - - let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; - - let extraClassDeclaration = [{ - // Interface method for ConditionallySpeculatable. - Speculation::Speculatability getSpeculatability(); - }]; - -} - -// -// Make Range Op -// -def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { - let summary = "make range"; - - let description = [{ - Returns an 1D int32 tensor. - - Values span from $start to $end (exclusive), with step = 1 - }]; - - // WARNING: MLIR generates getStart()/getEnd() functions which return - // uint32_t, even though these arguments are to be interpreted as *signed* - // int32 values. If this matters, use get{Start,End}Attr().getInt(), which - // return int64_t. 
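Since the format below keeps $start and $end in the attr-dict, a use prints roughly as (a sketch; the values, and the alphabetical attribute order, are assumptions):

```mlir
%r = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
```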
-    let arguments = (ins I32Attr:$start, I32Attr:$end);
-
-    let results = (outs TT_IntTensor:$result);
-
-    let assemblyFormat = "attr-dict `:` type($result)";
-
-    let hasFolder = 1;
-    let hasVerifier = 1;
-}
-
-//
-// ElementwiseInlineAsm Op
-//
-def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
-    Elementwise,
-    SameOperandsAndResultEncoding,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
-]> {
-    let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
-    let description = [{
-        Runs an inline asm block to generate one or more tensors.
-
-        The asm block is given `packed_element` elements at a time. Exactly which
-        elems it receives is unspecified.
-    }];
-
-    let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args);
-    let results = (outs Variadic<TT_Type>:$result);
-
-    let assemblyFormat = [{
-        $asm_string attr-dict ($args^ `:` type($args))? `->` type($result)
-    }];
-
-    let hasVerifier = 1;
-}
-
-//
-// Histogram Op
-//
-def TT_HistogramOp : TT_Op<"histogram", [Pure]> {
-    let summary = "return a histogram of the inputs.";
-    let description = [{
-        Return the histogram of the input tensor. The number of bins is equal to
-        the dimension of the output tensor. Each bin has a width of 1 and bins
-        start at 0.
-    }];
-
-    let arguments = (ins TT_IntTensor:$src);
-    let results = (outs TT_IntTensor:$result);
-
-    let assemblyFormat = [{
-        $src attr-dict `:` type($src) `->` type($result)
-    }];
-}
-
-//
-// Gather Op
-//
-def TT_GatherOp : TT_Op<"gather", [Pure,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-    let summary = "local gather operation";
-    let description = [{
-        Gather elements from the input tensor using the indices tensor along a
-        single specified axis. The output tensor has the same shape as the indices
-        tensor. The input and indices tensors must have the same number of
-        dimensions, and each dimension of the indices tensor that is not the gather
-        dimension cannot be greater than the corresponding dimension in the input
-        tensor.
-    }];
-
-    let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis);
-    let results = (outs TT_Tensor:$result);
-
-    let assemblyFormat = [{
-        $src `[` $indices `]` attr-dict `:`
-        functional-type(operands, results)
-    }];
-
-    let hasVerifier = 1;
-}
-
-//
-// Print Op
-//
-def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrite]>]> {
-    let arguments = (
-        ins
-        StrAttr:$prefix,
-        BoolAttr:$hex,
-        Variadic>:$args,
-        DenseI32ArrayAttr:$isSigned
-    );
-    let summary = "Device-side print, as in CUDA for debugging";
-    let description = [{
-        `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
-        The format string is generated automatically from the arguments.
-    }];
-    let assemblyFormat = [{
-        $prefix attr-dict (`:` $args^ `:` type($args))?
-    }];
-}
-
-//
-// Assert Op
-//
-def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
-    let summary = "Device-side assert, as in CUDA for correctness checking";
-    let description = [{
-        `tt.assert` takes a condition tensor and a message string.
-        If the condition is false, the message is printed, and the program is aborted.
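For illustration (a sketch; the condition value and message are assumptions):

```mlir
tt.assert %cond, "expected sorted input" : tensor<32xi1>
```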
- }]; - let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); - let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; -} - -// -// Make Tensor Pointer Op -// -def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", - [Pure, - SameVariadicOperandSize, - TypesMatchWith<"infer pointer type from the result type", - "result", "base", - "getPointerType(getElementTypeOfTensorPointerType($_self), getAddressSpace($_self))">]> { - let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; - - let description = [{ - `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a - pointer to the block tensor, e.g. returns a type of `tt.ptr>`. - }]; - - // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. - let arguments = (ins - TT_Ptr:$base, - Variadic:$shape, - Variadic:$strides, - Variadic:$offsets, - DenseI32ArrayAttr:$order - ); - - let results = (outs TT_TensorPtr:$result); - - // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly - // Add additional `[]` to increase readability and split variadic lists - let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; - - let builders = [ - OpBuilder<(ins - "Value":$base, - "ValueRange":$shape, - "ValueRange":$strides, - "ValueRange":$offsets, - "ArrayRef":$tensorShape, - "ArrayRef":$order - )> - ]; -} - -// The following ops, including `call`, `func`, and `return` are copied and modified from -// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td -// We could revert it back once MLIR has a better inliner interface. -// -// Function Ops -// -def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { - let summary = "call operation"; - let description = [{ - The `tt.call` operation represents a direct call to a function that is - within the same symbol scope as the call. The operands and result types of - the call must match the specified function type. The callee is encoded as a - symbol reference attribute named "callee". 
- - Example: - - ```mlir - %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 - ``` - }]; - - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); - let results = (outs Variadic); - - let builders = [ - OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ - $_state.addOperands(operands); - $_state.addAttribute("callee", SymbolRefAttr::get(callee)); - $_state.addTypes(callee.getFunctionType().getResults()); - }]>, - OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, - CArg<"ValueRange", "{}">:$operands), [{ - $_state.addOperands(operands); - $_state.addAttribute("callee", callee); - $_state.addTypes(results); - }]>, - OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, - CArg<"ValueRange", "{}">:$operands), [{ - build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); - }]>, - OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, - CArg<"ValueRange", "{}">:$operands), [{ - build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), - results, operands); - }]>]; - - let extraClassDeclaration = [{ - FunctionType getCalleeType() { - return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); - } - - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - - /// Return the callee of this operation. - CallInterfaceCallable getCallableForCallee() { - return (*this)->getAttrOfType("callee"); - } - - /// Set the callee for this operation. - void setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); - } - - // Required by CallOpInterface. - MutableOperandRange getArgOperandsMutable() { - return getOperandsMutable(); - } - - }]; - - let assemblyFormat = [{ - $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) - }]; -} - -def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> { - let summary = "An operation with a name containing a single `SSACFG` region"; - let description = [{ - Operations within the function cannot implicitly capture values defined - outside of the function, i.e. Functions are `IsolatedFromAbove`. All - external references must use function arguments or attributes that establish - a symbolic connection (e.g. symbols referenced by name via a string - attribute like SymbolRefAttr). An external function declaration (used when - referring to a function declared in some other module) has no body. While - the MLIR textual form provides a nice inline syntax for function arguments, - they are internally represented as “block arguments” to the first block in - the region. - - Only dialect attribute names may be specified in the attribute dictionaries - for function arguments, results, or the function itself. - - Example: - - ```mlir - // External function definitions. 
- tt.func @abort() - tt.func @scribble(i32, i64, memref) -> f64 - - // A function that returns its argument twice: - tt.func @count(%x: i64) -> (i64, i64) - attributes {fruit: "banana"} { - return %x, %x: i64, i64 - } - - // A function with an argument attribute - tt.func @example_fn_arg(%x: i32 {swift.self = unit}) - - // A function with a result attribute - tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) - - // A function with an attribute - tt.func @example_fn_attr() attributes {dialectName.attrName = false} - ``` - }]; - - let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs); - let regions = (region AnyRegion:$body); - - let builders = [OpBuilder<(ins - "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs, - CArg<"ArrayRef", "{}">:$argAttrs) - >]; - let extraClassDeclaration = [{ - //===------------------------------------------------------------------===// - // CallableOpInterface - //===------------------------------------------------------------------===// - - /// Returns the region on the current operation that is callable. This may - /// return null in the case of an external callable object, e.g. an external - /// function. - ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } - - /// Returns the results types that the callable region produces when - /// executed. - ArrayRef getCallableResults() { return getFunctionType().getResults(); } - - /// Returns the argument attributes for all callable region arguments or - /// null if there are none. - ::mlir::ArrayAttr getCallableArgAttrs() { - return getArgAttrs().value_or(nullptr); - } - - /// Returns the result attributes for all callable region results or - /// null if there are none. - ::mlir::ArrayAttr getCallableResAttrs() { - return getResAttrs().value_or(nullptr); - } - - //===------------------------------------------------------------------===// - // FunctionOpInterface Methods - //===------------------------------------------------------------------===// - - /// Returns the argument types of this function. - ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } - - /// Returns the result types of this function. - ArrayRef getResultTypes() { return getFunctionType().getResults(); } - - //===------------------------------------------------------------------===// - // SymbolOpInterface Methods - //===------------------------------------------------------------------===// - - bool isDeclaration() { return isExternal(); } - }]; - let hasCustomAssemblyFormat = 1; -} - -def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { - let summary = "Function return operation"; - let description = [{ - The `tt.return` operation represents a return operation within a function. - The operation takes variable number of operands and produces no results. - The operand number and types must match the signature of the function - that contains the operation. - - Example: - - ```mlir - tt.func @foo() : (i32, f8) { - ... 
- tt.return %0, %1 : i32, f8 - } - ``` - }]; - - let arguments = (ins Variadic:$srcs); - - let builders = [OpBuilder<(ins), [{ - build($_builder, $_state, std::nullopt); - }]>]; - - let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; - let hasVerifier = 1; -} - - -def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ - MemoryEffects<[MemRead]>]> { - let summary = "Load from descriptor"; - let description = [{ - This operation will be lowered to Nvidia TMA load operation on targets supporting it. - `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. - The destination tensor type and shape must match the descriptor otherwise the result is undefined. - - This is an escape hatch and is only there for testing/experimenting. - This op will be removed in the future. - }]; - let arguments = ( - ins - TT_PtrType:$desc_ptr, - Variadic:$indices, - DefaultValuedAttr:$cache, - DefaultValuedAttr:$evict - ); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = [{ - $desc_ptr `[` $indices `]` - oilist( - `cacheModifier` `=` $cache | - `evictionPolicy` `=` $evict - ) - attr-dict `:` qualified(type($desc_ptr)) `->` type($result) - }]; -} - -def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ - MemoryEffects<[MemRead, MemWrite]>]> { - let summary = "store value based on descriptor"; - let description = [{ - This operation will be lowered to Nvidia TMA store operation on targets supporting it. - `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. - The shape and types of `src` must match the descriptor otherwise the result is undefined. - - This is an escape hatch and is only there for testing/experimenting. - This op will be removed in the future. - }]; - let arguments = ( - ins - TT_PtrType:$desc_ptr, - TT_Tensor:$src, - Variadic:$indices - ); - - let assemblyFormat = [{ - $desc_ptr `[` $indices `]` `,` $src - attr-dict `:` qualified(type($desc_ptr)) `,` type($src) - }]; -} - -def TT_ExperimentalTensormapCreateOp: TT_Op< - "experimental_tensormap_create", - [ - MemoryEffects<[MemRead, MemWrite]>, - AttrSizedOperandSegments, - ] -> { - let summary = "Create a new TMA descriptor on device"; - let arguments = ( - ins - TT_PtrType:$desc_ptr, - TT_PtrType:$global_address, - Variadic:$box_dim, - Variadic:$global_dim, - Variadic:$global_stride, - Variadic:$element_stride, - ConfinedAttr]>:$elem_type, - ConfinedAttr]>:$interleave_layout, - ConfinedAttr]>:$swizzle_mode, - ConfinedAttr]>:$fill_mode - ); - let extraClassDeclaration = [{ - int32_t getRank() { - return getBoxDim().size(); - } - }]; - let assemblyFormat = [{ - $desc_ptr `,` $global_address `,` - `[` $box_dim `]` `,` - `[` $global_dim `]` `,` - `[` $global_stride `]` `,` - `[` $element_stride `]` - attr-dict `:` functional-type(operands, results) - }]; - - let hasVerifier = 1; -} - -def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< - "experimental_tensormap_fenceproxy_acquire", - [MemoryEffects<[MemWrite]>] -> { - let summary = "Acquire fence on a tensormap object"; - let arguments = (ins TT_PtrType:$desc_ptr); - let assemblyFormat = [{ - $desc_ptr attr-dict `:` qualified(type($desc_ptr)) - }]; -} - - -#endif // Triton_OPS diff --git a/third_party/ascend/triton_patch/lib/CMakeLists.txt b/third_party/ascend/triton_patch/lib/CMakeLists.txt deleted file mode 100644 index 0ca0f41c5..000000000 --- a/third_party/ascend/triton_patch/lib/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Dialect) diff --git 
a/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt deleted file mode 100644 index 5e601271e..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt deleted file mode 100644 index f33061b2d..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt deleted file mode 100644 index 3b7c3746a..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_triton_library(Patched_TritonIR - Dialect.cpp - Ops.cpp - Traits.cpp - Types.cpp - - DEPENDS - TritonTableGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRArithDialect - MLIRMathDialect - MLIRSCFDialect -) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp deleted file mode 100644 index dc2417712..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp +++ /dev/null @@ -1,139 +0,0 @@ -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" - -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/UB/IR/UBOps.h" -#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/raw_ostream.h" - -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/IR/DialectImplementation.h" - -#include "mlir/Transforms/InliningUtils.h" -#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" -#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" - -using namespace mlir; -using namespace mlir::triton; - -//===----------------------------------------------------------------------===// -// TritonDialect Dialect Interfaces -//===----------------------------------------------------------------------===// - -namespace { -struct TritonInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const final { - auto funcOp = dyn_cast(callable); - if (!funcOp) - return true; - if (funcOp->hasAttr("noinline")) - return !funcOp->getAttrOfType("noinline").getValue(); - return true; - } - - bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - IRMapping &valueMapping) const final { - return true; - } - - bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, - IRMapping &) const final { - return true; - } - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, Block *newDest) const final { - // Only return needs to be handled here. - auto returnOp = dyn_cast(op); - if (!returnOp) - return; - - // Replace the return with a branch to the dest. 
- OpBuilder builder(op); - builder.create(op->getLoc(), newDest, - returnOp.getOperands()); - op->erase(); - } - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - // Only return needs to be handled here. - auto returnOp = cast(op); - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } -}; - -struct TensorModel - : public TensorOrMemDesc::ExternalModel { - Type getElementType(Type pointer) const { - return cast(pointer).getElementType(); - } - Attribute getEncoding(Type pointer) const { - return cast(pointer).getEncoding(); - } - ArrayRef getShape(Type pointer) const { - return cast(pointer).getShape(); - } - int64_t getRank(Type pointer) const { - return cast(pointer).getRank(); - } - int64_t getElementTypeBitWidth(Type pointer) const { - return cast(pointer).getElementTypeBitWidth(); - } -}; - -struct MemDescModel - : public TensorOrMemDesc::ExternalModel { - Type getElementType(Type pointer) const { - return cast(pointer).getElementType(); - } - Attribute getEncoding(Type pointer) const { - return cast(pointer).getEncoding(); - } - ArrayRef getShape(Type pointer) const { - return cast(pointer).getShape(); - } - int64_t getRank(Type pointer) const { - return cast(pointer).getShape().size(); - } - int64_t getElementTypeBitWidth(Type pointer) const { - return cast(pointer).getElementType().getIntOrFloatBitWidth(); - } -}; - -} // namespace - -void TritonDialect::initialize() { - registerTypes(); - - addOperations< -#define GET_OP_LIST -#include "triton/Dialect/Triton/IR/Ops.cpp.inc" - >(); - - // We can also add interface here. 
- addInterfaces(); - - RankedTensorType::attachInterface(*getContext()); - MemDescType::attachInterface(*getContext()); -} - -Operation *TritonDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { - return arith::ConstantOp::materialize(builder, value, type, loc); -} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp deleted file mode 100644 index 87aca769f..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp +++ /dev/null @@ -1,1092 +0,0 @@ -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/Interfaces/FunctionImplementation.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Support/LLVM.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" - -namespace mlir { -namespace triton { - -void LoadOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), - triton::GlobalMemory::get()); - if (getIsVolatile()) - effects.emplace_back(MemoryEffects::Write::get(), - SideEffects::DefaultResource::get()); -} - -} // namespace triton -} // namespace mlir - -#define GET_OP_CLASSES -#include "triton/Dialect/Triton/IR/Ops.cpp.inc" - -// enum attribute definitions -#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" - -namespace mlir { -namespace triton { - -//-- LoadOp -- -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - CacheModifier cache, EvictionPolicy evict, bool isVolatile) { - LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, - /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, - cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - ArrayRef boundaryCheck, - std::optional padding, CacheModifier cache, - EvictionPolicy evict, bool isVolatile) { - LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, - padding, cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, CacheModifier cache, EvictionPolicy evict, - bool isVolatile) { - LoadOp::build(builder, state, ptr, mask, /*other=*/{}, - /*boundaryCheck=*/ArrayRef{}, - /*padding=*/std::nullopt, cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, Value other, CacheModifier cache, - EvictionPolicy evict, bool isVolatile) { - LoadOp::build(builder, state, ptr, mask, other, - /*boundaryCheck=*/ArrayRef{}, - /*padding=*/std::nullopt, cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, Value other, ArrayRef boundaryCheck, - std::optional padding, CacheModifier cache, - EvictionPolicy evict, bool isVolatile) { - auto paddingAttr = - padding.has_value() - ? PaddingOptionAttr::get(builder.getContext(), padding.value()) - : PaddingOptionAttr(); - LoadOp::build(builder, state, ptr, mask, other, - builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, - evict, isVolatile); -} - -// load(ptr, splat(1), ...) -> load(ptr, ...) -// load(ptr, splat(0), other, ...) 
-> other
-struct CanonicalizeMaskedLoadPattern : public OpRewritePattern<LoadOp> {
-  CanonicalizeMaskedLoadPattern(MLIRContext *context)
-      : OpRewritePattern<LoadOp>(context, 1) {}
-
-  LogicalResult matchAndRewrite(LoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto mask = loadOp.getMask();
-    if (!mask)
-      return failure();
-
-    auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
-    if (!constantMask)
-      return failure();
-
-    auto splatMask = mlir::dyn_cast<SplatElementsAttr>(constantMask.getValue());
-    if (!splatMask)
-      return failure();
-
-    if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
-      // mask = splat(1)
-      rewriter.replaceOpWithNewOp<LoadOp>(
-          loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
-          loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
-          loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
-    } else {
-      // mask = splat(0)
-
-      // If there's no "other", the value is "undef". Perhaps we want to
-      // optimize it in the future.
-      auto otherVal = loadOp.getOther();
-      if (!otherVal)
-        return failure();
-      rewriter.replaceOp(loadOp, otherVal);
-    }
-    return success();
-  }
-};
-
-void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                         MLIRContext *context) {
-  results.add<CanonicalizeMaskedLoadPattern>(context);
-}
-
-//-- StoreOp --
-void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr,
-                    Value value, CacheModifier cache, EvictionPolicy evict) {
-  return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
-                        /*boundaryCheck=*/{}, cache, evict);
-}
-
-void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr,
-                    Value value, Value mask, CacheModifier cache,
-                    EvictionPolicy evict) {
-  return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{},
-                        cache, evict);
-}
-
-void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr,
-                    Value value, ArrayRef<int32_t> boundaryCheck,
-                    CacheModifier cache, EvictionPolicy evict) {
-  return StoreOp::build(builder, state, ptr, value, /*mask=*/{},
-                        builder.getDenseI32ArrayAttr(boundaryCheck), cache,
-                        evict);
-}
-
-// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
-// store(ptr, value, splat(0), ...) -> [none]
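The pattern below implements this store rewrite; its effect, sketched in IR (the SSA names and types are assumptions, not taken from the patch):

```mlir
// before: the mask is a constant all-true splat
%m = arith.constant dense<true> : tensor<8xi1>
tt.store %p, %v, %m : tensor<8x!tt.ptr<f32>>
// after canonicalization: the mask is dropped
tt.store %p, %v : tensor<8x!tt.ptr<f32>>
```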
-struct CanonicalizeMaskedStorePattern : public OpRewritePattern<StoreOp> {
-  CanonicalizeMaskedStorePattern(MLIRContext *context)
-      : OpRewritePattern<StoreOp>(context, 1) {}
-
-  LogicalResult matchAndRewrite(StoreOp storeOp,
-                                PatternRewriter &rewriter) const override {
-    auto mask = storeOp.getMask();
-    if (!mask)
-      return failure();
-
-    auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
-    if (!constantMask)
-      return failure();
-
-    auto splatMask = mlir::dyn_cast<SplatElementsAttr>(constantMask.getValue());
-    if (!splatMask)
-      return failure();
-
-    if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
-      // mask = splat(1)
-      rewriter.replaceOpWithNewOp<StoreOp>(
-          storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(),
-          storeOp.getEvict());
-    } else {
-      // mask = splat(0)
-      rewriter.eraseOp(storeOp);
-    }
-    return success();
-  }
-};
-
-void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                          MLIRContext *context) {
-  results.add<CanonicalizeMaskedStorePattern>(context);
-}
-
-//-- TransOp --
-OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
-  // transpose(x, order=[0, 1, ...]) -> x
-  if (isIota(getOrder())) {
-    return getSrc();
-  }
-
-  // transpose(transpose(x)) -> transpose(x)
-  if (auto innerTrans = getSrc().getDefiningOp<TransOp>()) {
-    setOrder(applyPermutation(innerTrans.getOrder(), getOrder()));
-    setOperand(innerTrans.getSrc());
-    return getResult();
-  }
-
-  return {};
-}
-
-LogicalResult TransOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  // type is the same as the input
-  auto argTy = cast<TensorOrMemDesc>(operands[0].getType());
-  auto order = properties.as<Properties *>()->order.asArrayRef();
-  SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);
-
-  auto retEltTy = argTy.getElementType();
-  Attribute argEncoding = argTy.getEncoding();
-  Attribute retEncoding;
-  if (argEncoding) {
-    Dialect &dialect = argEncoding.getDialect();
-    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
-    if (inferLayoutInterface
-            ->inferTransOpEncoding(argEncoding, order, retEncoding)
-            .failed()) {
-      return failure();
-    }
-  }
-  if (auto memDescTy = dyn_cast<MemDescType>(argTy)) {
-    inferredReturnTypes.push_back(MemDescType::get(
-        retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(),
-        memDescTy.getMutableMemory()));
-  } else {
-    inferredReturnTypes.push_back(
-        RankedTensorType::get(retShape, retEltTy, retEncoding));
-  }
-  return success();
-}
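The fold above also collapses chained transposes; sketched in IR (the SSA names and shapes are assumptions):

```mlir
%t1 = tt.trans %x {order = array<i32: 1, 0>} : tensor<2x4xf32> -> tensor<4x2xf32>
%t2 = tt.trans %t1 {order = array<i32: 1, 0>} : tensor<4x2xf32> -> tensor<2x4xf32>
// %t2 folds to %x: the composed order is the identity [0, 1]
```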
-LogicalResult TransOp::verify() {
-  // Check that the op's `order` attribute is a permutation of the right
-  // length.
-  auto srcTy = getSrc().getType();
-
-  ArrayRef<int32_t> order = getOrder();
-  if (order.size() != srcTy.getRank()) {
-    return emitError("order must have the same size as the rank of the "
-                     "operand and result");
-  }
-
-  SmallVector<int32_t> sortedOrder(order);
-  llvm::sort(sortedOrder);
-  for (int32_t i = 0; i < sortedOrder.size(); i++) {
-    if (sortedOrder[i] != i) {
-      return emitError("order must be a permutation of [0, ..., rank - 1]");
-    }
-  }
-
-  return success();
-}
-
-//-- DotOp --
-LogicalResult
-DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
-                        ValueRange operands, DictionaryAttr attributes,
-                        OpaqueProperties properties, RegionRange regions,
-                        SmallVectorImpl<Type> &inferredReturnTypes) {
-  // type is the same as the accumulator
-  auto accTy = cast<RankedTensorType>(operands[2].getType());
-  inferredReturnTypes.push_back(accTy);
-
-  // verify encodings
-  auto aEnc = cast<TensorOrMemDesc>(operands[0].getType()).getEncoding();
-  auto bEnc = cast<TensorOrMemDesc>(operands[1].getType()).getEncoding();
-  auto retEnc = accTy.getEncoding();
-  if (aEnc) {
-    assert(bEnc && retEnc);
-    Dialect &dialect = retEnc.getDialect();
-    auto interface = dyn_cast<DialectInferLayoutInterface>(&dialect);
-    if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
-      return failure();
-    if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
-      return failure();
-  }
-  return success();
-}
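In other words, the result of tt.dot is typed exactly like its accumulator, and any operand encodings are checked against the accumulator's encoding. A minimal sketch in illustrative IR:

    // result type == type of %c: same shape, element type, and encoding
    %d = tt.dot %a, %b, %c : tensor<128x64xf16> * tensor<64x128xf16> -> tensor<128x128xf32>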
-LogicalResult DotOp::verify() {
-  auto aTy = getA().getType();
-  auto bTy = getB().getType();
-  if (aTy.getElementType().getIntOrFloatBitWidth() !=
-      bTy.getElementType().getIntOrFloatBitWidth())
-    return emitError(
-        "element types of operands A and B must have the same bit width");
-  auto aEncoding = aTy.getEncoding();
-  auto bEncoding = bTy.getEncoding();
-  if (!aEncoding && !bEncoding)
-    return success();
-  // Verify that the encodings are valid.
-  if (!aEncoding || !bEncoding)
-    return emitError("mismatching encoding between A and B operands");
-  auto accTy = getC().getType();
-  auto retEnc = accTy.getEncoding();
-  if (!retEnc)
-    return emitError("missing encoding of C operand");
-  Dialect &dialect = retEnc.getDialect();
-  auto interface = cast<DialectInferLayoutInterface>(&dialect);
-  return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding,
-                                                     bEncoding);
-}
-
-//-- MakeRangeOp --
-OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
-  // make_range(start, start + 1) -> constant(start)
-  if (adaptor.getStart() + 1 == adaptor.getEnd()) {
-    auto shapedType = cast<ShapedType>(getType());
-    return SplatElementsAttr::get(shapedType, adaptor.getStartAttr());
-  }
-  return {};
-}
-
-LogicalResult MakeRangeOp::verify() {
-  int64_t start = getStartAttr().getInt();
-  int64_t end = getEndAttr().getInt();
-  if (start > end) {
-    return this->emitOpError() << "start must be less than or equal to end";
-  }
-  auto ty = getType();
-  if (ty.getShape().size() != 1) {
-    return this->emitOpError() << "return type must be a 1D tensor";
-  }
-  if (end - start != ty.getShape()[0]) {
-    return this->emitOpError()
-           << "number of elements in returned tensor, " << ty.getShape()[0]
-           << ", must match size of range [" << start << ", " << end
-           << "), which has " << end - start << " elements";
-  }
-  if (!ty.getElementType().isInteger(32)) {
-    return this->emitOpError() << "returned tensor must have i32 elements";
-  }
-  return success();
-}
-
-//-- ReduceOp --
-static LogicalResult
-inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
-                       SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto retShape = argTy.getShape().vec();
-  retShape.erase(retShape.begin() + axis);
-  if (retShape.empty()) {
-    // 0d-tensor -> scalar
-    inferredReturnTypes.push_back(retEltTy);
-  } else {
-    // nd-tensor where n >= 1
-    // infer encoding
-    Attribute argEncoding = argTy.getEncoding();
-    Attribute retEncoding;
-    if (argEncoding) {
-      Dialect &dialect = argEncoding.getDialect();
-      auto inferLayoutInterface =
-          dyn_cast<DialectInferLayoutInterface>(&dialect);
-      if (inferLayoutInterface
-              ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
-              .failed()) {
-        llvm::report_fatal_error("failed to infer layout for ReduceOp");
-        return failure();
-      }
-    }
-    // create type
-    inferredReturnTypes.push_back(
-        RankedTensorType::get(retShape, retEltTy, retEncoding));
-  }
-  return success();
-}
-
-void ReduceOp::build(OpBuilder &builder, OperationState &state,
-                     ValueRange operands, int axis) {
-  SmallVector<Type> inferredReturnTypes;
-  for (unsigned i = 0; i < operands.size(); ++i) {
-    auto argTy = cast<RankedTensorType>(operands[i].getType());
-    auto retEltTy = argTy.getElementType();
-    (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes);
-  }
-
-  ReduceOp::build(builder, state, inferredReturnTypes, operands, axis);
-}
-
-LogicalResult ReduceOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  Properties *prop = properties.as<Properties *>();
-  int axis = prop->axis.getInt();
-  for (auto arg : operands) {
-    auto argTy = cast<RankedTensorType>(arg.getType());
-    auto retEltTy = argTy.getElementType();
-    if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
-            .failed()) {
-      return failure();
-    }
-  }
-  return success();
-}
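The helpers below spell out the region contract that reductions and scans share: a combine block taking 2 * numOperands scalar arguments, terminated by tt.reduce.return. Schematically, in illustrative generic-form IR, a row-sum along axis 1 looks like:

    %s = "tt.reduce"(%x) <{axis = 1 : i32}> ({
    ^bb0(%lhs: f32, %rhs: f32):
      %sum = arith.addf %lhs, %rhs : f32
      tt.reduce.return %sum : f32
    }) : (tensor<128x64xf32>) -> tensor<128xf32>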
operand"; - } - if (op.getNumOperands() != op.getNumResults()) { - return op.emitOpError() << "must have the same number of inputs as outputs"; - } - - auto getElementType = [](Type ty) { - if (auto tensorType = dyn_cast(ty)) { - return tensorType.getElementType(); - } - return ty; - }; - - for (auto [opElemTy, resTy] : - llvm::zip(op.getElementTypes(), op.getResultTypes())) { - if (opElemTy != getElementType(resTy)) { - return op.emitOpError() << "operand types and result types must agree"; - } - } - return success(); -} - -template -static LogicalResult verifyRegionsImpl(Op &op) { - auto argElementTypes = op.getElementTypes(); - const auto &operands = op.getOperands(); - const auto numArgs = 2 * operands.size(); - auto &block = *op.getBody(); - if (block.getNumArguments() != numArgs) { - return op.emitOpError() << "nested block must take " << numArgs - << " arguments, but given block with " - << block.getNumArguments() << " arguments"; - } - unsigned i = 0; - const auto &blockArgTypes = block.getArgumentTypes(); - for (unsigned i = 0; i < numArgs; ++i) { - const auto &blockArgTy = blockArgTypes[i]; - const auto &argElemTy = argElementTypes[i % operands.size()]; - if (blockArgTy != argElemTy) { - return op.emitOpError() - << "type mismatch on combine operation. Expected argument " << i - << " to have type " << argElemTy << " but got " << blockArgTy; - } - } - - auto terminator = dyn_cast(block.getTerminator()); - if (!terminator) { - return op.emitOpError() - << "combine operation must be terminated " - << "with a ReduceReturnOp but got " << block.getTerminator(); - } - const auto &combineResults = terminator->getOperands(); - if (combineResults.size() != operands.size()) { - return op.emitOpError() - << "expected combine operation to return " << operands.size() - << " values but got " << combineResults.size(); - } - for (unsigned i = 0; i < combineResults.size(); ++i) { - const auto &resultTy = combineResults[i].getType(); - const auto &argElemTy = argElementTypes[i]; - if (resultTy != argElemTy) { - return op.emitOpError() - << "type mismatch on combine operation. 
Expected argument " << i - << " to have type " << argElemTy << " but got " << resultTy; - } - } - return success(); -} - -static llvm::SmallVector -getInputTypesImpl(const Operation::operand_range &operands) { - llvm::SmallVector srcTys; - srcTys.reserve(operands.size()); - for (const auto &ty : operands.getTypes()) { - srcTys.push_back(cast(ty)); - } - return srcTys; -} - -static llvm::SmallVector -getElementTypesImpl(const Operation::operand_range &operands) { - llvm::SmallVector srcElemTys; - srcElemTys.reserve(operands.size()); - for (const auto &op : operands) { - srcElemTys.push_back(cast(op.getType()).getElementType()); - } - return srcElemTys; -} - -LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } - -LogicalResult ReduceOp::verifyRegions() { - return verifyRegionsImpl(*this); -} - -llvm::SmallVector ReduceOp::getInputTypes() { - return getInputTypesImpl(this->getOperands()); -} - -llvm::SmallVector ReduceOp::getElementTypes() { - return getElementTypesImpl(this->getOperands()); -} - -unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } - -//-- ScanOp -- -void ScanOp::build(OpBuilder &builder, OperationState &state, - ValueRange operands, int axis, bool reverse) { - SmallVector inferredReturnTypes; - state.addAttribute("reverse", builder.getBoolAttr(reverse)); - for (auto arg : operands) - inferredReturnTypes.push_back(arg.getType()); - ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); -} - -LogicalResult -ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - for (auto arg : operands) - inferredReturnTypes.push_back(arg.getType()); - return success(); -} - -LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } - -LogicalResult ScanOp::verifyRegions() { - return verifyRegionsImpl(*this); -} - -llvm::SmallVector ScanOp::getInputTypes() { - return getInputTypesImpl(this->getOperands()); -} - -llvm::SmallVector ScanOp::getElementTypes() { - return getElementTypesImpl(this->getOperands()); -} - -unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } - -//-- SplatOp -- -OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { - auto value = adaptor.getSrc(); - if (!value) - return {}; - if (!isa(value)) - return {}; - auto shapedType = cast(getType()); - auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); - return ret; -} - -//-- ExpandDimsOp -- -LogicalResult ExpandDimsOp::inferReturnTypes( - MLIRContext *context, std::optional loc, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // infer shape - auto arg = operands[0]; - auto argTy = cast(arg.getType()); - auto retShape = argTy.getShape().vec(); - Properties *prop = properties.as(); - int axis = prop->axis.getInt(); - retShape.insert(retShape.begin() + axis, 1); - // infer encoding - Attribute argEncoding = argTy.getEncoding(); - Attribute retEncoding; - if (argEncoding) { - Dialect &dialect = argEncoding.getDialect(); - auto inferLayoutInterface = dyn_cast(&dialect); - if (inferLayoutInterface - ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) - .failed()) - return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); - } - // create type - auto argEltTy = argTy.getElementType(); - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, argEltTy, 
retEncoding)); - return success(); -} - -LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, - PatternRewriter &rewriter) { - auto definingOp = op.getSrc().getDefiningOp(); - if (!definingOp) { - return failure(); - } - // expand_dims(splat) -> splat - if (auto splat = dyn_cast(definingOp)) { - rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); - return success(); - } - // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) - // - // On its own this doesn't do much, but consider - // broadcast(expand_dims(broadcast)) - // -> broadcast(broadcast(expand_dims)) - // -> broadcast(expand_dims) - if (auto broadcast = dyn_cast(definingOp)) { - auto src = broadcast.getSrc(); - auto srcTy = src.getType(); - SmallVector newExpandShape(srcTy.getShape()); - newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); - - // Infer the encoding of the new expand op, if encodings are present. - Attribute newExpandEnc; - if (auto srcEnc = srcTy.getEncoding()) { - if (dyn_cast(&srcEnc.getDialect()) - ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, - op.getLoc()) - .failed()) { - return emitOptionalError(op.getLoc(), - "failed to infer layout for ExpandDimsOp"); - } - } - - auto newExpandTy = RankedTensorType::get( - newExpandShape, srcTy.getElementType(), newExpandEnc); - auto newExpand = rewriter.create(op.getLoc(), newExpandTy, - src, op.getAxis()); - auto newBroadcast = rewriter.create( - broadcast.getLoc(), op.getType(), newExpand.getResult()); - rewriter.replaceOp(op, {newBroadcast.getResult()}); - return success(); - } - - return failure(); -} - -template -static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { - if (!value) - return {}; - - auto shapedType = cast(op.getType()); - if (auto denseElemsAttr = dyn_cast(value)) { - if (denseElemsAttr.isSplat()) { - return denseElemsAttr.resizeSplat(shapedType); - } else { - return denseElemsAttr.reshape(shapedType); - } - } - return {}; -} - -OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { - return foldViewLikeOp(*this, adaptor.getSrc()); -} - -//-- ReshapeOp -- -template -LogicalResult canonicalizeViewOrBroadcast(OpType op, - PatternRewriter &rewriter) { - auto definingOp = op.getSrc().getDefiningOp(); - if (!definingOp) { - return failure(); - } - - // view(view) -> view - if (auto parentView = dyn_cast(definingOp)) { - rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), - parentView->getOperands(), - parentView->getAttrs()); - return success(); - } - - // view(splat) -> splat - if (auto splat = dyn_cast(definingOp)) { - rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); - return success(); - } - - return failure(); -} - -LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { - if (!op.getAllowReorder() || op.getEfficientLayout()) - return failure(); - return canonicalizeViewOrBroadcast(op, rewriter); -} - -OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { - if (getType() == getSrc().getType()) { - // no-op - return getSrc(); - } - - return foldViewLikeOp(*this, adaptor.getSrc()); -} - -LogicalResult ReshapeOp::verify() { - auto dstTy = getType(); - auto srcTy = getSrc().getType(); - if (getType().getNumElements() != srcTy.getNumElements()) { - return emitError( - "number of src and dst elements of reshape must be the same"); - } - - Attribute srcEnc = srcTy.getEncoding(); - Attribute dstEnc = dstTy.getEncoding(); - if (!!srcEnc != !!dstEnc) { - return emitError("Op requires that either (a) src and dst both have " - "encodings, or (b) neither 
does."); - } - - if (srcEnc && !getAllowReorder()) { - Attribute inferredDstEnc; - if (cast(&srcEnc.getDialect()) - ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, - dstTy.getShape(), inferredDstEnc, - getLoc()) - .failed()) { - return emitError("This reshape is impossible without reordering, but " - "reordering is not allowed. Try choosing a different " - "encoding for the input tensor (or allow reordering)."); - } - if (inferredDstEnc != dstEnc) { - return emitError("Expected result encoding ") - << inferredDstEnc << " but was " << dstEnc; - } - } - - return success(); -} - -//-- FpToFpOp -- -LogicalResult FpToFpOp::verify() { - auto dstType = getType().getElementType(); - auto srcType = getSrc().getType().getElementType(); - if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && - (!getRounding().has_value())) { - return emitError("Rounding mode is required for FP downcast"); - } - return success(); -} - -//-- BroadcastOp -- -LogicalResult BroadcastOp::canonicalize(BroadcastOp op, - PatternRewriter &rewriter) { - return canonicalizeViewOrBroadcast(op, rewriter); -} - -OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { - if (getType() == getSrc().getType()) { - // no-op - return getSrc(); - } - - auto value = adaptor.getSrc(); - if (!value) - return {}; - - if (auto denseElemsAttr = dyn_cast(value)) { - auto shapedType = cast(getType()); - return denseElemsAttr.resizeSplat(shapedType); - } - return {}; -} - -LogicalResult BroadcastOp::verify() { - auto src = getSrc(); - auto srcTensorType = cast(src.getType()); - auto srcShape = srcTensorType.getShape(); - auto result = getResult(); - auto resultTensorType = cast(result.getType()); - auto resultShape = resultTensorType.getShape(); - if (srcShape.size() != resultShape.size()) { - return emitError("rank of source must be same as rank of result"); - } - for (int i = 0; i < srcShape.size(); i++) { - if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { - return emitError("Different dimensions at index ") - << i << " between source and result. " - << "Broadcast requires the source dimension to be 1."; - } - } - return success(); -} - -//-- MakeTensorPtrOp -- -void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, - Value base, ValueRange shape, ValueRange strides, - ValueRange offsets, ArrayRef tensorShape, - ArrayRef order) { - // Get pointer type from `base` - auto pointerType = cast(base.getType()); - assert(pointerType != nullptr); - - // Build type `tt.ptr>` - auto tensorType = RankedTensorType::get( - SmallVector(tensorShape.begin(), tensorShape.end()), - pointerType.getPointeeType()); - auto result = PointerType::get(tensorType, 1); - - return build(builder, state, result, base, shape, strides, offsets, - builder.getDenseI32ArrayAttr(order)); -} - -//-- AdvanceOp -- -OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { - // advance(ptr, 0, 0) -> ptr - SmallVector rawOffsets = getOffsets(); - auto offsets = getConstantIntValues(rawOffsets); - if (!offsets.has_value()) - return {}; - for (int64_t offset : offsets.value()) - if (offset != 0) - return {}; - return getPtr(); -} - -// The following ops, including `call`, `func`, and `return` are copied and -// modified from -// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp -// We could revert it back once MLIR has a better inliner interface. 
-//-- FuncOp -- -void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, - FunctionType type, ArrayRef attrs, - ArrayRef argAttrs) { - state.addAttribute(SymbolTable::getSymbolAttrName(), - builder.getStringAttr(name)); - state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); - state.attributes.append(attrs.begin(), attrs.end()); - state.addRegion(); - - if (argAttrs.empty()) - return; - assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( - builder, state, argAttrs, /*resultAttrs=*/std::nullopt, - getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); -} - -ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { - auto buildFuncType = - [](Builder &builder, ArrayRef argTypes, ArrayRef results, - function_interface_impl::VariadicFlag, - std::string &) { return builder.getFunctionType(argTypes, results); }; - - return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); -} - -void FuncOp::print(OpAsmPrinter &printer) { - function_interface_impl::printFunctionOp( - printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); -} - -// -- CallOp -- -LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // Check that the callee attribute was specified. - auto fnAttr = (*this).getProperties().callee; - if (!fnAttr) - return emitOpError("requires a 'callee' symbol reference attribute"); - FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); - if (!fn) - return emitOpError() << "'" << fnAttr.getValue() - << "' does not reference a valid function"; - - // Verify that the operand and result types match the callee. - auto fnType = fn.getFunctionType(); - if (fnType.getNumInputs() != getNumOperands()) - return emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (getOperand(i).getType() != fnType.getInput(i)) - return emitOpError("operand type mismatch: expected operand type ") - << fnType.getInput(i) << ", but provided " - << getOperand(i).getType() << " for operand number " << i; - - if (fnType.getNumResults() != getNumResults()) - return emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (getResult(i).getType() != fnType.getResult(i)) { - auto diag = emitOpError("result type mismatch at index ") << i; - diag.attachNote() << " op result types: " << getResultTypes(); - diag.attachNote() << "function result types: " << fnType.getResults(); - return diag; - } - - return success(); -} - -// -- ReturnOp -- -LogicalResult ReturnOp::verify() { - auto function = cast((*this)->getParentOp()); - - // The operand number and types must match the function signature. 
- const auto &results = function.getFunctionType().getResults(); - if (getNumOperands() != results.size()) - return emitOpError("has ") - << getNumOperands() << " operands, but enclosing function (@" - << function.getName() << ") returns " << results.size(); - - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (getOperand(i).getType() != results[i]) - return emitError() << "type of return operand " << i << " (" - << getOperand(i).getType() - << ") doesn't match function result type (" - << results[i] << ")" - << " in function @" << function.getName(); - - return success(); -} - -// -- JoinOp -- -LogicalResult -JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // These should have been checked by tablegen-generated code. - assert(operands.size() == 2); - assert(operands[0].getType() == operands[1].getType()); - assert(isa(operands[0].getType())); - assert(isa(operands[1].getType())); - - Value lhs = operands[0]; - Value rhs = operands[1]; - auto srcTy = cast(lhs.getType()); - - SmallVector retShape(srcTy.getShape()); - retShape.push_back(2); - - Attribute srcEnc = srcTy.getEncoding(); - Attribute retEnc; - if (srcEnc) { - if (dyn_cast(&srcEnc.getDialect()) - ->inferJoinOpEncoding(srcEnc, retEnc, location) - .failed()) { - return failure(); - } - } - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); - return success(); -} - -// -- SplitOp -- -LogicalResult SplitOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // These should have been checked by tablegen-generated code. - assert(operands.size() == 1); - assert(isa(operands[0].getType())); - - Value src = operands[0]; - auto srcTy = cast(src.getType()); - auto srcShape = srcTy.getShape(); - - if (srcShape.empty() || srcShape.back() != 2) { - return emitOptionalError(location, - "last dimension of input tensor must be 2"); - } - ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); - - Attribute srcEnc = srcTy.getEncoding(); - Attribute retEnc; - if (srcEnc) { - if (dyn_cast(&srcEnc.getDialect()) - ->inferSplitOpEncoding(srcEnc, retEnc, location) - .failed()) { - return failure(); - } - } - auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); - inferredReturnTypes.push_back(retTy); - inferredReturnTypes.push_back(retTy); - return success(); -} - -// -- ElementwiseInlineAsmOp -- -void ElementwiseInlineAsmOp::getEffects( - SmallVectorImpl> - &effects) { - if (getPure()) - return; - effects.emplace_back(MemoryEffects::Write::get(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Read::get(), - SideEffects::DefaultResource::get()); -} - -LogicalResult ElementwiseInlineAsmOp::verify() { - if (getNumOperands() >= 1) { - auto tensorType = dyn_cast(getOperand(0).getType()); - size_t numInputElems = tensorType ? 
tensorType.getNumElements() : 0; - if (numInputElems % this->getPackedElement() != 0) { - return emitError("number of input elements ") - << numInputElems - << " must be a multiple of the op's packed_element attribute, " - << getPackedElement(); - } - } - return success(); -} - -// -- ExternElementwiseOp -- -void ExternElementwiseOp::getEffects( - SmallVectorImpl> - &effects) { - if (getPure()) - return; - effects.emplace_back(MemoryEffects::Write::get(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Read::get(), - SideEffects::DefaultResource::get()); -} - -Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { - if (getPure()) - return Speculation::Speculatable; - return Speculation::NotSpeculatable; -} - -// -- ExperimentalTensormapCreateOp -- -LogicalResult ExperimentalTensormapCreateOp::verify() { - auto rank = getBoxDim().size(); - if (getGlobalDim().size() != rank) { - return emitError("Rank mismatch for global dim. Got") - << getGlobalDim().size() << " but expected " << rank; - } - if (getGlobalStride().size() + 1 != rank) { - return emitError("Rank mismatch for global stride. Got") - << getGlobalStride().size() << " but expected " << rank - 1; - } - if (getElementStride().size() != rank) { - return emitError("Rank mismatch for element stride. Got") - << getElementStride().size() << " but expected " << rank; - } - return success(); -} - -// -- GatherOp -- -LogicalResult GatherOp::verify() { - RankedTensorType indicesTy = getIndices().getType(); - RankedTensorType srcTy = getSrc().getType(); - RankedTensorType resTy = getResult().getType(); - - if (indicesTy.getShape() != resTy.getShape()) { - return emitOpError("indices and output shapes must match"); - } - if (indicesTy.getEncoding() != resTy.getEncoding()) { - return emitOpError("indices and output encodings must match"); - } - if (srcTy.getElementType() != resTy.getElementType()) { - return emitOpError("input and output element types must match"); - } - if (srcTy.getRank() != indicesTy.getRank()) { - return emitOpError("input and indices ranks must match"); - } - if (getAxis() >= srcTy.getRank()) { - return emitOpError("gather dimension must be less than the input rank"); - } - for (int dim = 0; dim < indicesTy.getRank(); ++dim) { - if (dim == getAxis()) - continue; - if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { - return emitOpError("indices dimension ") - << dim << " must match the corresponding input dimension"; - } - } - - return success(); -} - -LogicalResult GatherOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - GatherOpAdaptor adaptor(operands, attributes, properties, regions); - auto indicesType = cast(adaptor.getIndices().getType()); - auto srcType = cast(adaptor.getSrc().getType()); - - // Shape and encoding of the indices with the element type of the src. 
- inferredReturnTypes.push_back( - RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), - indicesType.getEncoding())); - return success(); -} - -} // namespace triton -} // namespace mlir diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp deleted file mode 100644 index b43a9b56c..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp +++ /dev/null @@ -1,239 +0,0 @@ -#include "triton/Dialect/Triton/IR/Traits.h" - -#include - -#include "mlir/IR/TypeUtilities.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "llvm/Support/ErrorHandling.h" - -using namespace mlir; -namespace ttg = mlir::triton::gpu; - -static LogicalResult verifySameEncoding(Type typeA, Type typeB, - bool allowTensorPointerType) { - // TODO(Keren): the allowTensorPointerType argument is a hack to allow. - // The type checking code is kind of a mess with the current design. - auto getEncoding = [=](Type type) -> Attribute { - Attribute ret; - if (auto tensorType = dyn_cast(type)) { - ret = tensorType.getEncoding(); - } - if (!allowTensorPointerType) { - assert(!triton::isTensorPointerType(type)); - } - return ret; - }; - auto encodingA = getEncoding(typeA); - auto encodingB = getEncoding(typeB); - if (!encodingA || !encodingB) - return success(); - return encodingA == encodingB ? success() : failure(); -} - -LogicalResult -OpTrait::impl::verifySameOperandsEncoding(Operation *op, - bool allowTensorPointerType) { - if (failed(verifyAtLeastNOperands(op, 1))) - return failure(); - - auto type = op->getOperand(0).getType(); - for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) - if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) - return op->emitOpError() << "requires the same encoding for all operands"; - - return success(); -} - -LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( - Operation *op, bool allowTensorPointerType) { - if (op->getNumOperands() == 0) - return success(); - - if (failed(verifyAtLeastNOperands(op, 1)) || - failed(verifyAtLeastNResults(op, 1))) - return failure(); - - auto type = op->getOperand(0).getType(); - for (auto resultType : op->getResultTypes()) - if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) - return op->emitOpError() - << "requires the same encoding for all operands and results"; - - return verifySameOperandsEncoding(op, allowTensorPointerType); -} - -LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { - for (auto opType : op->getOperandTypes()) { - if (auto tensorType = dyn_cast(opType)) { - int64_t numElements = 1; - for (int64_t s : tensorType.getShape()) - numElements *= s; - if (numElements > maxTensorNumElements) - return op->emitError("Maximum allowed number of elements is ") - << maxTensorNumElements << ", but " << *op - << " has more than that"; - // if ((numElements & (numElements - 1)) != 0) - // return op->emitError("Number of elements must be power-of-two, but ") - // << *op << " doesn't follow the rule (" << numElements << ")" - // << " elements"; - } - } - for (auto opType : op->getResultTypes()) { - if (auto tensorType = dyn_cast(opType)) { - int64_t numElements = 1; - for (int64_t s : tensorType.getShape()) - numElements *= s; - if (numElements > maxTensorNumElements) - return op->emitError("Maximum allowed number of elements is ") - << maxTensorNumElements << 
", but " << *op - << " has more than that"; - // if ((numElements & (numElements - 1)) != 0) - // return op->emitError("Number of elements must be power-of-two, but ") - // << *op << " doesn't follow the rule (" << numElements << ")" - // << " elements"; - } - } - return success(); -} - -// Check that the Triton layouts on op's operands and return types are valid. -// For example, we check that the number of warps per block in a Triton GPU -// blocked layout matches that of its module. -// -// It's a little weird to check these properties of a layout only when the -// layout is used in an op, since most of the properties don't actually depend -// on the op. They do depend on the *module*, though, and a layout is attached -// to a module only by virtue of being used in one of the module's ops. -LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { - auto module = op->getParentOfType(); - auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { - // Only ranked tensors can have layouts. - auto rankedTy = dyn_cast(val.getType()); - if (!rankedTy) - return success(); - - mlir::Attribute layout = rankedTy.getEncoding(); - if (!layout) - return success(); - - if (isa(layout)) - return makeErr() << "Shared layout is not allowed on tensor type."; - // TODO(jlebar): Currently this only checks blocked layouts, but other - // layouts also have invariants! - - // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. - if (auto blocked = dyn_cast(layout)) { - // A different verifier should have checked that the layout itself is - // valid, including that threads-per-warp has the same rank as - // warps-per-block etc. - auto layoutRank = blocked.getThreadsPerWarp().size(); - if (layoutRank != rankedTy.getRank()) { - return makeErr() << layout << ".\nLayout has rank " << layoutRank - << ", but the tensor it's attached to has rank " - << rankedTy.getRank() << "."; - } - - int moduleThreadsPerWarp = - ttg::TritonGPUDialect::getThreadsPerWarp(module); - int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); - if (layoutThreadsPerWarp != moduleThreadsPerWarp) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutThreadsPerWarp - << " threads per warp, but the module specifies " - << moduleThreadsPerWarp << " threads per warp."; - } - - int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); - int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); - if (layoutWarpsPerCTA != moduleWarpsPerCTA) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutWarpsPerCTA - << " warps per CTA, but the module specifies " - << moduleWarpsPerCTA << " warps per CTA."; - } - - if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { - int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); - int64_t layoutCTAsPerCGA = - product(blocked.getCTALayout().getCTAsPerCGA()); - if (layoutCTAsPerCGA != moduleCTAsPerCGA) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutCTAsPerCGA - << " CTAs per CGA, but the module specifies " - << moduleCTAsPerCGA << " CTAs per CGA."; - } - } - } - - return success(); - }; - - for (size_t i = 0; i < op->getNumOperands(); i++) { - auto operand = op->getOperand(i); - auto err = checkLayout(operand, [&]() { - // Stringify the operand using `printAsOperand`. This prints e.g. "%42" - // rather than the full definition. 
- std::string operandStr; - llvm::raw_string_ostream os(operandStr); - // If we don't assume verified, dump() will recursively call this - // function! - operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); - - return op->emitError("Operand ") - << i << " (" << operand << ") has an invalid layout: "; - }); - if (!err.succeeded()) - return err; - } - - for (size_t i = 0; i < op->getNumResults(); i++) { - auto result = op->getResult(i); - auto err = checkLayout(result, [&]() { - if (op->getNumResults() == 1) { - return op->emitError("Result has an invalid layout: "); - } else { - return op->emitError("Result ") << i << " has an invalid layout: "; - } - }); - if (!err.succeeded()) - return err; - } - - return success(); -} - -static ArrayRef getTypeShape(Type type) { - auto rankedType = dyn_cast(type); - if (auto ptrType = dyn_cast(type)) - rankedType = dyn_cast(ptrType.getPointeeType()); - return rankedType ? rankedType.getShape() : ArrayRef(); -} - -LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { - if (failed(verifyAtLeastNOperands(op, 1))) - return failure(); - - auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); - for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) - if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) - return op->emitOpError() << "requires the same shape for all operands"; - - return success(); -} - -LogicalResult -OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { - if (failed(verifyAtLeastNOperands(op, 1)) || - failed(verifyAtLeastNResults(op, 1))) - return failure(); - - auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); - for (auto type : op->getResultTypes()) - if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) - return op->emitOpError() - << "requires the same shape for all operands and results"; - - return verifySameLoadStoreOperandsShape(op); -} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp deleted file mode 100644 index 6e41e70a8..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp +++ /dev/null @@ -1,197 +0,0 @@ -#include "triton/Dialect/Triton/IR/Types.h" - -#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LLVM.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` - -using namespace mlir; -using namespace mlir::triton; - -#define GET_TYPEDEF_CLASSES -#include "triton/Dialect/Triton/IR/Types.cpp.inc" - -//===----------------------------------------------------------------------===// -// Triton Dialect -//===----------------------------------------------------------------------===// -void TritonDialect::registerTypes() { - addTypes< -#define GET_TYPEDEF_LIST -#include "triton/Dialect/Triton/IR/Types.cpp.inc" - >(); -} - -Type PointerType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - - Type pointeeType; - if (parser.parseType(pointeeType)) - return Type(); - - int addressSpace = 1; - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseInteger(addressSpace)) - return Type(); - } - - if (parser.parseGreater()) - return Type(); - - return PointerType::get(pointeeType, addressSpace); -} - -void PointerType::print(AsmPrinter &printer) const { - if (getAddressSpace() == 1) { - printer << "<" << getPointeeType() << 
">"; - } else { - printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; - } -} - -static constexpr llvm::StringRef kMutableMemory = "mutable"; - -Type MemDescType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - - SmallVector dimensions; - if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) - return Type(); - - // Parse the element type. - Type elementType; - if (parser.parseType(elementType)) - return Type(); - - Attribute encoding; - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseAttribute(encoding)) - return Type(); - } - bool mutableMemory = false; - Attribute memorySpace; - if (succeeded(parser.parseOptionalComma())) { - if (failed(parser.parseOptionalKeyword(kMutableMemory))) { - if (parser.parseAttribute(memorySpace)) - return Type(); - } else { - mutableMemory = true; - } - } - if (mutableMemory == false && succeeded(parser.parseOptionalComma())) { - if (parser.parseOptionalKeyword(kMutableMemory)) - return Type(); - mutableMemory = true; - } - if (parser.parseGreater()) - return Type(); - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, memorySpace, mutableMemory); -} - -void MemDescType::print(AsmPrinter &printer) const { - printer << "<"; - for (auto dim : getShape()) - printer << dim << "x"; - printer << getElementType(); - if (getEncoding()) - printer << ", " << getEncoding(); - if (getMemorySpace()) - printer << ", " << getMemorySpace(); - if (getMutableMemory()) - printer << ", " << kMutableMemory; - printer << ">"; -} - -namespace mlir { - -namespace triton { - -unsigned getPointeeBitWidth(Type type) { - auto pointeeType = getPointeeType(type); - if (auto tensorTy = dyn_cast(pointeeType)) - return tensorTy.getElementType().getIntOrFloatBitWidth(); - return pointeeType.getIntOrFloatBitWidth(); -} - -Type getI1SameShape(Type type) { - auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto tensorTy = dyn_cast(type)) - return RankedTensorType::get(tensorTy.getShape(), i1Type, - tensorTy.getEncoding()); - return i1Type; -} - -Type getPointeeType(Type type) { - if (auto tensorTy = dyn_cast(type)) { - // Tensor of pointers - auto shape = tensorTy.getShape(); - auto ptrType = dyn_cast(tensorTy.getElementType()); - Type pointeeType = ptrType.getPointeeType(); - return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); - } else if (auto ptrType = dyn_cast(type)) { - // scalar pointer - Type pointeeType = ptrType.getPointeeType(); - return pointeeType; - } - return type; -} - -Type getI32SameShape(Type type) { - auto i32Type = IntegerType::get(type.getContext(), 32); - if (auto tensorTy = dyn_cast(type)) - return RankedTensorType::get(tensorTy.getShape(), i32Type, - tensorTy.getEncoding()); - return i32Type; -} - -Type getPointerTypeSameShape(Type type) { - if (auto tensorTy = dyn_cast(type)) { - Type elementType = tensorTy.getElementType(); - auto shape = tensorTy.getShape(); - PointerType ptrType = PointerType::get(elementType, 1); - return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); - } else { - return PointerType::get(type, 1); - } -} - -Type getPointerTypeToElement(Type type) { - Type elementType = getElementTypeOrSelf(type); - PointerType ptrType = PointerType::get(elementType, 1); - return ptrType; -} - -// upstream Triton only uses address space 1 for Pointer Type -Type getPointerType(Type type, int addressSpace) { - return PointerType::get(type, addressSpace); -} - -int getAddressSpace(Type type) { - if (auto ptrType = 
dyn_cast(type)) - return ptrType.getAddressSpace(); - return 1; -} - -bool isTensorPointerType(Type type) { - if (auto ptrType = dyn_cast(type)) - return isa(ptrType.getPointeeType()); - return false; -} - -bool isTensorOrTensorPointerType(Type type) { - return isa(type) || isTensorPointerType(type); -} - -Type getElementTypeOfTensorPointerType(Type type) { - if (auto ptrType = dyn_cast(type)) - if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) - return tensorTy.getElementType(); - return {}; -} - -} // namespace triton - -} // namespace mlir diff --git a/third_party/ascend/triton_patch/python/src/ir.cc b/third_party/ascend/triton_patch/python/src/ir.cc deleted file mode 100644 index 637e15c42..000000000 --- a/third_party/ascend/triton_patch/python/src/ir.cc +++ /dev/null @@ -1,1771 +0,0 @@ -#include -#include -#include -#include - -#include "mlir/Bytecode/BytecodeWriter.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" -#include "mlir/Dialect/UB/IR/UBOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/FileUtilities.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Transforms/LocationSnapshot.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Tools/Sys/GetEnv.hpp" -#include "llvm/Support/SourceMgr.h" - -namespace { - -namespace py = pybind11; -using namespace mlir; -using namespace triton; - -// A custom op builder that keeps track of the last location -class TritonOpBuilder { -public: - TritonOpBuilder(MLIRContext *context) { - builder = std::make_unique(context); - lastLoc = std::make_unique(builder->getUnknownLoc()); - } - - OpBuilder &getBuilder() { return *builder; } - - bool isLineInfoEnabled() { return lineInfoEnabled; } - - void setLastLoc(Location loc) { - if (lineInfoEnabled) - lastLoc = std::make_unique(loc); - } - - void setLastLoc(const std::string &fileName, int line, int column) { - auto context = builder->getContext(); - setLastLoc(FileLineColLoc::get(context, fileName, line, column)); - } - - Location getLastLoc() { - assert(lastLoc); - return *lastLoc; - } - - void setInsertionPointToStart(Block &block) { - if (!block.empty()) - setLastLoc(block.begin()->getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - builder->setInsertionPointToStart(&block); - } - - void setInsertionPointToEnd(Block &block) { - if (!block.empty()) - setLastLoc(block.back().getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - builder->setInsertionPointToEnd(&block); - } - - void setInsertionPointAfter(Operation &op) { - setLastLoc(op.getLoc()); - builder->setInsertionPointAfter(&op); - } - - void restoreInsertionPoint(OpBuilder::InsertPoint pt) { - if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) - setLastLoc(pt.getPoint()->getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - 
builder->restoreInsertionPoint(pt); - } - - template OpTy create(Args &&...args) { - auto loc = getLastLoc(); - return builder->create(loc, std::forward(args)...); - } - - // Overload to create or fold a single result operation. - template - std::enable_if_t(), Value> - createOrFold(Args &&...args) { - auto loc = getLastLoc(); - return builder->createOrFold(loc, std::forward(args)...); - } - - // Overload to create or fold a zero result operation. - template - std::enable_if_t(), OpTy> - createOrFold(Args &&...args) { - auto loc = getLastLoc(); - return builder->createOrFold(loc, std::forward(args)...); - } - -private: - std::unique_ptr builder; - std::unique_ptr lastLoc; - bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); -}; - -std::string locationToString(Location loc) { - std::string str; - llvm::raw_string_ostream os(str); - loc.print(os); - os.flush(); // Make sure all the content is dumped into the 'str' string - return str; -} - -void outputWarning(Location loc, const std::string &msg) { - std::string locStr = locationToString(loc); - - PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), - /*stack_level=*/2); -} - -} // anonymous namespace - -/*****************************************************************************/ -/* Python bindings for ir */ -/*****************************************************************************/ - -void init_triton_ir(py::module &&m) { - using ret = py::return_value_policy; - using namespace pybind11::literals; - - py::enum_(m, "PADDING_OPTION", py::module_local()) - .value("PAD_ZERO", PaddingOption::PAD_ZERO) - .value("PAD_NAN", PaddingOption::PAD_NAN) - .export_values(); - - py::enum_(m, "CACHE_MODIFIER", py::module_local()) - .value("NONE", CacheModifier::NONE) - .value("CA", CacheModifier::CA) - .value("CG", CacheModifier::CG) - .value("WB", CacheModifier::WB) - .value("CS", CacheModifier::CS) - .value("WT", CacheModifier::WT) - .value("CV", CacheModifier::CV) - .export_values(); - - py::enum_(m, "MEM_SEMANTIC", py::module_local()) - .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) - .value("ACQUIRE", MemSemantic::ACQUIRE) - .value("RELEASE", MemSemantic::RELEASE) - .value("RELAXED", MemSemantic::RELAXED) - .export_values(); - - py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) - .value("GPU", MemSyncScope::GPU) - .value("CTA", MemSyncScope::CTA) - .value("SYSTEM", MemSyncScope::SYSTEM) - .export_values(); - - py::enum_(m, "EVICTION_POLICY", py::module_local()) - .value("NORMAL", EvictionPolicy::NORMAL) - .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) - .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) - .export_values(); - - py::enum_(m, "ATOMIC_OP", py::module_local()) - .value("ADD", RMWOp::ADD) - .value("FADD", RMWOp::FADD) - .value("AND", RMWOp::AND) - .value("OR", RMWOp::OR) - .value("XOR", RMWOp::XOR) - .value("XCHG", RMWOp::XCHG) - .value("MAX", RMWOp::MAX) - .value("MIN", RMWOp::MIN) - .value("UMIN", RMWOp::UMIN) - .value("UMAX", RMWOp::UMAX); - - py::enum_(m, "ROUNDING_MODE", py::module_local()) - .value("RTZ", RoundingMode::RTZ) - .value("RTNE", RoundingMode::RTNE); - - py::enum_(m, "PROPAGATE_NAN", py::module_local()) - .value("NONE", PropagateNan::NONE) - .value("ALL", PropagateNan::ALL); - - py::enum_(m, "INPUT_PRECISION", py::module_local()) - .value("TF32", InputPrecision::TF32) - .value("TF32x3", InputPrecision::TF32x3) - .value("IEEE", InputPrecision::IEEE) - .value("HF32", InputPrecision::HF32) - .export_values(); - - py::enum_(m, "F8F6F4TY", py::module_local()) - 
.value("E4M3", F8F6F4Type::E4M3) - .value("E5M2", F8F6F4Type::E5M2) - .value("E2M3", F8F6F4Type::E2M3) - .value("E3M2", F8F6F4Type::E3M2) - .value("E2M1", F8F6F4Type::E2M1) - .export_values(); - - py::class_(m, "context", py::module_local()) - .def(py::init<>()) - .def("printOpOnDiagnostic", - [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) - .def("printStackTraceOnDiagnostic", - [](MLIRContext &self, bool v) { - self.printStackTraceOnDiagnostic(v); - }) - .def("disable_multithreading", - [](MLIRContext &self) { self.disableMultithreading(); }); - - py::class_(m, "source_mgr_diag", - py::module_local()) - .def(py::init()); - - m.def("load_dialects", [](MLIRContext &context) { - DialectRegistry registry; - registry.insert(); - mlir::LLVM::registerInlinerInterface(registry); - registerBuiltinDialectTranslation(registry); - registerLLVMDialectTranslation(registry); - mlir::LLVM::registerInlinerInterface(registry); - context.appendDialectRegistry(registry); - context.loadAllAvailableDialects(); - }); - - py::class_(m, "type", py::module_local()) - .def("is_integer", - [](Type &self, unsigned width) { return self.isInteger(width); }) - .def("is_fp16", &Type::isF16) - .def("__str__", [](Type &self) { - std::string str; - llvm::raw_string_ostream os(str); - self.print(os); - return os.str(); - }); - - py::class_(m, "function_type", py::module_local()) - .def("param_types", [](FunctionType &self) { - return std::vector(self.getInputs().begin(), - self.getInputs().end()); - }); - - py::class_(m, "location", py::module_local()) - .def("__str__", [](Location &self) { - std::string str; - llvm::raw_string_ostream os(str); - self.print(os); - return os.str(); - }); - - py::class_(m, "value", py::module_local()) - .def("set_attr", - [](Value &self, std::string &name, Attribute &attr) -> void { - if (Operation *definingOp = self.getDefiningOp()) - definingOp->setAttr(name, attr); - else { - auto arg = mlir::cast(self); - int id = arg.getArgNumber(); - std::string attrName = name + "_arg" + std::to_string(id); - Block *owner = arg.getOwner(); - if (owner->isEntryBlock() && - !isa(owner->getParentOp())) { - owner->getParentOp()->setAttr(attrName, attr); - } - } - }) - .def("get_context", &Value::getContext) - .def("replace_all_uses_with", - [](Value &self, Value &newValue) { - self.replaceAllUsesWith(newValue); - }) - .def("get_type", &Value::getType) - .def("id", [](Value &self) { - // The Value is identified by and compared with - // other Values via the underlying ValueImpl - return (uint64_t)self.getImpl(); - }); - - py::class_(m, "op_result", py::module_local()); - - py::class_(m, "block_argument", py::module_local()); - - py::class_(m, "region", py::module_local()) - .def("get_parent_region", &Region::getParentRegion, ret::reference) - .def("size", [](Region &self) { return self.getBlocks().size(); }) - .def("empty", &Region::empty) - .def("id", [](Region &self) { return (uint64_t)&self; }); - - py::class_(m, "block", py::module_local()) - .def("arg", - [](Block &self, int index) -> BlockArgument { - if (index >= self.getNumArguments()) - throw pybind11::index_error("Block argument index out of range"); - return self.getArgument(index); - }) - .def("add_argument", - [](Block &self, Type ty) { - auto loc = UnknownLoc::get(ty.getContext()); - self.addArgument(ty, loc); - }) - .def("get_num_arguments", &Block::getNumArguments) - .def("get_argument", &Block::getArgument) - .def("dump", &Block::dump) - .def("move_before", - [](Block &self, Block &dst) { self.moveBefore(&dst); }) - 
.def("insert_before", &Block::insertBefore) - .def("get_parent", &Block::getParent, ret::reference) - .def("merge_block_before", - [](Block &self, Block &dst) { - // ref: RewriterBase::mergeBlocks() - if (self.getNumArguments() != 0) - throw std::runtime_error( - "This block has arguments, don't merge"); - dst.getOperations().splice(dst.begin(), self.getOperations()); - self.dropAllUses(); - self.erase(); - }) - .def("replace_use_in_block_with", - [](Block &self, Value &v, Value &newVal) { - v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { - Operation *user = operand.getOwner(); - Block *currentBlock = user->getBlock(); - while (currentBlock) { - if (currentBlock == &self) - return true; - // Move up one level - currentBlock = - currentBlock->getParent()->getParentOp()->getBlock(); - } - return false; - }); - }) - .def("__str__", - [](Block &self) { - std::string str; - llvm::raw_string_ostream os(str); - self.print(os); - return str; - }) - .def("has_terminator", - [](Block &self) { - return !self.empty() && - self.back().hasTrait(); - }) - .def("has_return", - [](Block &self) { - return !self.empty() && - self.back().hasTrait(); - }) - .def("erase", [](Block &self) { self.erase(); }) - .def("id", [](Block &self) { return (uint64_t)&self; }); - - py::class_(m, "attribute", py::module_local()); - py::class_(m, "integer_attr", py::module_local()); - py::class_(m, "bool_attr", py::module_local()); - - // Ops - py::class_(m, "OpState", py::module_local()) - .def("set_attr", - [](OpState &self, std::string &name, Attribute &attr) -> void { - self->setAttr(name, attr); - }) - .def("get_num_results", - [](OpState &self) -> unsigned { return self->getNumResults(); }) - .def("get_result", - [](OpState &self, unsigned idx) -> Value { - if (idx >= self->getNumResults()) - throw pybind11::index_error("Op result index out of range"); - return self->getResult(idx); - }) - .def( - "get_region", - [](OpState &self, unsigned idx) -> Region & { - if (idx >= self->getNumRegions()) - throw pybind11::index_error("Op region index out of range"); - return self->getRegion(idx); - }, - ret::reference) - .def( - "get_body", - [](scf::ForOp &self, unsigned idx) -> Block * { - if (idx >= self->getNumRegions()) - throw pybind11::index_error("Op region index out of range"); - return self.getBody(idx); - }, - ret::reference) - .def("dump", [](OpState &self) { self->dump(); }) - .def("__str__", - [](OpState &self) -> std::string { - std::string str; - llvm::raw_string_ostream os(str); - auto printingFlags = OpPrintingFlags(); - printingFlags.enableDebugInfo(); - self->print(os, printingFlags); - return str; - }) - .def("append_operand", - [](OpState &self, Value &val) { - self->insertOperands(self->getNumOperands(), val); - }) - .def("verify", [](OpState &self) -> bool { - return succeeded(verify(self.getOperation())); - }); - // scf Ops - py::class_(m, "ForOp", py::module_local()) - .def("get_induction_var", &scf::ForOp::getInductionVar); - - py::class_(m, "IfOp", py::module_local()) - .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) - .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) - .def("get_then_yield", &scf::IfOp::thenYield) - .def("get_else_yield", &scf::IfOp::elseYield); - py::class_(m, "YieldOp", py::module_local()); - py::class_(m, "WhileOp", py::module_local()) - .def("get_before", &scf::WhileOp::getBefore, ret::reference) - .def("get_after", &scf::WhileOp::getAfter, ret::reference); - py::class_(m, "ConditionOp", py::module_local()); - - py::class_>( - m, "operation", 
py::module_local()) - .def("get_name", - [](Operation &self) { - llvm::StringRef opName = self.getName().getStringRef(); - return opName.str(); - }) - .def("get_num_operands", &Operation::getNumOperands) - .def("get_operand", &Operation::getOperand) - .def("get_num_results", &Operation::getNumResults) - .def("get_result", &Operation::getResult) - .def("get_num_regions", &Operation::getNumRegions) - .def("get_region", &Operation::getRegion, ret::reference) - .def("get_block", &Operation::getBlock, ret::reference) - .def("get_str_attr", - [](Operation &self, const std::string &name) -> py::object { - auto ret = self.getAttrOfType(name); - if (!ret) - return py::none(); - return py::str(ret.getValue().str()); - }) - .def("get_bool_attr", - [](Operation &self, const std::string &name) -> py::object { - auto ret = self.getAttrOfType(name); - if (!ret) - return py::none(); - return py::bool_(ret.getValue()); - }) - .def("get_flat_symbol_ref_attr", - [](Operation &self, const std::string &name) -> py::object { - auto ret = self.getAttrOfType(name); - if (!ret) - return py::none(); - return py::str(ret.getValue().str()); - }); - - // dynamic_attr is used to transfer ownership of the MLIR context to the - // module - py::class_(m, "module", py::module_local(), - py::dynamic_attr()) - .def("dump", &ModuleOp::dump) - .def("str", - [](ModuleOp &self) -> std::string { - std::string str; - llvm::raw_string_ostream os(str); - auto printingFlags = OpPrintingFlags(); - printingFlags.enableDebugInfo(); - self.print(os, printingFlags); - return str; - }) - .def("push_back", - [](ModuleOp &self, FuncOp &funcOp) -> void { - self.push_back(funcOp); - }) - .def("has_function", - [](ModuleOp &self, std::string &funcName) -> bool { - if (self.lookupSymbol(funcName)) - return true; - return false; - }) - .def("get_function", - [](ModuleOp &self, std::string &funcName) -> FuncOp { - return self.lookupSymbol(funcName); - }) - .def("get_int_attr", - [](ModuleOp &self, std::string name) -> py::object { - auto ret = self->getAttrOfType(name); - if (!ret) - return py::none(); - return py::int_(ret.getInt()); - }) - .def("create_location_snapshot", - [](ModuleOp &self, const std::string &fileName) -> void { - generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), - /*fileName=*/fileName, - /*op=*/self, /*flags=*/{}); - }) - .def("walk", - [](ModuleOp &self, const std::function &fn) { - self.walk(fn); - }); - - m.def("make_attr", [](const std::vector &values, MLIRContext &context) { - return mlir::cast(DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(values.size())}, - IntegerType::get(&context, 32)), - values)); - }); - - m.def( - "parse_mlir_module", - [](const std::string &inputFilename, MLIRContext &context) { - // parse module - OwningOpRef module = - parseSourceFile(inputFilename, &context); - if (!module) - throw std::runtime_error("Parse MLIR file failed."); - return module->clone(); - }, - ret::take_ownership); - - py::class_(m, "function", py::module_local()) - // .def_property_readonly("attrs", &ir::function::attrs) - // .def("add_attr", &ir::function::add_attr); - .def("args", - [](FuncOp &self, unsigned idx) -> BlockArgument { - if (idx >= self.getNumArguments()) - throw pybind11::index_error( - "Function argument index out of range"); - return self.getArgument(idx); - }) - .def( - "add_entry_block", - [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, - ret::reference) - .def( - "set_arg_attr", - [](FuncOp &self, int arg_no, const std::string &name, int val) { - if (arg_no >= 
self.getNumArguments()) - throw pybind11::index_error( - "Function argument index out of range"); - // set arg attributes "name" to value "val" - auto attrTy = IntegerType::get(self.getContext(), 32); - self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); - }, - ret::reference) - // .def("has_attr", &::FuncOp::hasAttr) - .def("finalize", - [](FuncOp &self) -> void { - // Check if the result of tl.advance is used - self.walk([&](AdvanceOp op) { - if (op->getResult(0).use_empty()) - outputWarning(op->getLoc(), "The result of tl.advance is not " - "being used. Note that tl.advance " - "does not have any side effects. " - "To move the block pointer, you " - "need to assign the result of " - "tl.advance to a variable."); - }); - }) - .def_property_readonly("type", &FuncOp::getFunctionType) - .def("reset_type", &FuncOp::setType); - - py::class_(m, "InsertPoint", py::module_local()); - - py::class_(m, "builder", py::module_local(), - py::dynamic_attr()) - .def(py::init()) - // getters - .def("create_module", - [](TritonOpBuilder &self) -> ModuleOp { - return self.create(); - }) - // insertion block/point - .def("set_insertion_point_to_start", - [](TritonOpBuilder &self, Block &block) -> void { - self.setInsertionPointToStart(block); - }) - .def("set_insertion_point_to_end", - [](TritonOpBuilder &self, Block &block) { - self.setInsertionPointToEnd(block); - }) - .def("set_insertion_point_after", - [](TritonOpBuilder &self, Operation &op) { - self.setInsertionPointAfter(op); - }) - .def( - "get_insertion_block", - [](TritonOpBuilder &self) -> Block * { - return self.getBuilder().getInsertionBlock(); - }, - ret::reference) - .def("get_insertion_point", - [](TritonOpBuilder &self) { - return self.getBuilder().saveInsertionPoint(); - }) - .def("restore_insertion_point", - [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { - self.restoreInsertionPoint(pt); - }) - // Attr - .def("get_bool_attr", - [](TritonOpBuilder &self, bool value) { - return self.getBuilder().getBoolAttr(value); - }) - .def("get_int32_attr", - [](TritonOpBuilder &self, int32_t value) { - return self.getBuilder().getI32IntegerAttr(value); - }) - // Use arith.ConstantOp to create constants - // Constants - .def("get_int1", - [](TritonOpBuilder &self, bool v) -> Value { - return Value(self.create( - v, self.getBuilder().getI1Type())); - }) - .def("get_int8", - [](TritonOpBuilder &self, int64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI8Type())); - }) - .def("get_int16", - [](TritonOpBuilder &self, int64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI16Type())); - }) - .def("get_int32", - [](TritonOpBuilder &self, int64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI32Type())); - }) - .def("get_int64", - [](TritonOpBuilder &self, int64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI64Type())); - }) - .def("get_uint8", - [](TritonOpBuilder &self, uint64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI8Type())); - }) - .def("get_uint16", - [](TritonOpBuilder &self, uint64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI16Type())); - }) - .def("get_uint32", - [](TritonOpBuilder &self, uint64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI32Type())); - }) - .def("get_uint64", - [](TritonOpBuilder &self, uint64_t v) -> Value { - return Value(self.create( - v, self.getBuilder().getI64Type())); - }) - .def("get_bf16", - [](TritonOpBuilder &self, float v) -> Value { 
- auto type = self.getBuilder().getBF16Type(); - return self.create( - APFloat(type.getFloatSemantics(), std::to_string(v)), type); - }) - .def("get_fp16", - [](TritonOpBuilder &self, float v) -> Value { - return self.create( - self.getBuilder().getF16FloatAttr(v)); - }) - .def("get_fp32", - [](TritonOpBuilder &self, float v) -> Value { - return self.create( - self.getBuilder().getF32FloatAttr(v)); - }) - .def("get_fp64", - [](TritonOpBuilder &self, double v) -> Value { - return self.create( - self.getBuilder().getF64FloatAttr(v)); - }) - .def("get_null_value", - [](TritonOpBuilder &self, Type type) -> Value { - if (auto floatTy = dyn_cast(type)) - return self.create( - APFloat(floatTy.getFloatSemantics(), 0), floatTy); - else if (auto intTy = dyn_cast(type)) - return self.create(0, intTy); - else - throw std::runtime_error("Not implemented"); - }) - .def("get_all_ones_value", - [](TritonOpBuilder &self, Type type) -> Value { - uint64_t val = 0xFFFFFFFFFFFFFFFF; - if (auto intTy = dyn_cast(type)) - return self.create(val, intTy); - else - throw std::runtime_error("Not implemented"); - }) - - // Types - .def("get_void_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getNoneType(); - }) - .def("get_int1_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getI1Type(); - }) // or ret::copy? - .def("get_int8_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getI8Type(); - }) - .def("get_int16_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getType(16); - }) - .def("get_int32_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getI32Type(); - }) - .def("get_int64_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getI64Type(); - }) - .def("get_fp8e4nv_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getType(); - }) - .def("get_fp8e4b8_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getType(); - }) - .def("get_fp8e4b15_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getI8Type(); - }) - .def("get_fp8e5_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getType(); - }) - .def("get_fp8e5b16_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getType(); - }) - .def("get_half_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getF16Type(); - }) - .def("get_bf16_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getBF16Type(); - }) - .def("get_float_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getF32Type(); - }) - .def("get_double_ty", - [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getF64Type(); - }) - .def("get_ptr_ty", - [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { - return PointerType::get(type, addrSpace); - }) - .def("get_block_ty", - [](TritonOpBuilder &self, Type &elementType, - std::vector &shape) -> Type { - return RankedTensorType::get(shape, elementType); - }) - .def("get_function_ty", - [](TritonOpBuilder &self, std::vector inTypes, - std::vector outTypes) -> Type { - return self.getBuilder().getFunctionType(inTypes, outTypes); - }) - // locs - .def("set_loc", - [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) - .def("set_loc", - [](TritonOpBuilder &self, const std::string &fileName, int line, - int column) { self.setLastLoc(fileName, line, column); }) - .def("get_loc", - [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) - - // Ops - 
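
For reference, the type and constant getters above are the whole surface
needed to assemble a module by hand from Python. A minimal sketch, assuming
the extension module imports as in upstream Triton; ir.context() and
ir.load_dialects() are not part of this excerpt and are assumed to behave as
they do upstream:

    from triton._C.libtriton import ir

    ctx = ir.context()
    ir.load_dialects(ctx)                    # assumed helper, as upstream
    builder = ir.builder(ctx)
    builder.set_loc("demo.py", 1, 0)         # every op needs a location

    module = builder.create_module()
    i32 = builder.get_int32_ty()
    fn_ty = builder.get_function_ty([i32], [i32])
    fn = builder.get_or_insert_function(module, "demo", fn_ty, "public", False)
    module.push_back(fn)
    entry = fn.add_entry_block()
    builder.set_insertion_point_to_start(entry)
    c42 = builder.get_int32(42)              # arith.constant 42 : i32
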
.def("get_or_insert_function", - [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, - Type &funcType, std::string &visibility, - bool noinline) -> FuncOp { - if (Operation *funcOperation = module.lookupSymbol(funcName)) - return llvm::dyn_cast(funcOperation); - if (auto funcTy = dyn_cast(funcType)) { - llvm::SmallVector attrs = { - NamedAttribute( - self.getBuilder().getStringAttr("sym_visibility"), - self.getBuilder().getStringAttr(visibility)), - NamedAttribute(self.getBuilder().getStringAttr("noinline"), - self.getBuilder().getBoolAttr(noinline))}; - return self.create(funcName, funcTy, attrs); - } - throw std::invalid_argument("invalid function type"); - }) - .def( - "create_block", - [](TritonOpBuilder &self) -> Block * { - Region *parent = self.getBuilder().getBlock()->getParent(); - return self.getBuilder().createBlock(parent); - }, - ret::reference) - .def( - "create_block_with_parent", - [](TritonOpBuilder &self, Region &parent, - std::vector &argTypes) -> Block * { - // TODO: update arg loc - auto loc = self.getBuilder().getUnknownLoc(); - llvm::SmallVector argLocs(argTypes.size(), loc); - return self.getBuilder().createBlock(&parent, {}, argTypes, - argLocs); - }, - ret::reference) - .def( - "new_block", - [](TritonOpBuilder &self) -> Block * { return new Block(); }, - ret::reference) - // Function - .def("ret", - [](TritonOpBuilder &self, std::vector &vals) -> OpState { - return self.create(vals); - }) - .def("call", - [](TritonOpBuilder &self, FuncOp &func, std::vector &args) - -> OpState { return self.create(func, args); }) - // Unstructured control flow - .def("create_cond_branch", - [](TritonOpBuilder &self, Value condition, Block *trueDest, - Block *falseDest) -> OpState { - return self.create(condition, trueDest, - falseDest); - }) - .def("create_branch", - [](TritonOpBuilder &self, Block *dest, std::vector &args) - -> OpState { return self.create(dest, args); }) - // Structured control flow - .def("create_for_op", - [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, - std::vector &initArgs) -> scf::ForOp { - return self.create(lb, ub, step, initArgs); - }) - .def("create_if_op", - [](TritonOpBuilder &self, std::vector &retTypes, - Value &condition, bool withElse) -> scf::IfOp { - return self.create(retTypes, condition, withElse); - }) - .def("create_yield_op", - [](TritonOpBuilder &self, std::vector &yields) - -> scf::YieldOp { return self.create(yields); }) - .def("create_while_op", - [](TritonOpBuilder &self, std::vector &retTypes, - std::vector &initArgs) -> scf::WhileOp { - return self.create(retTypes, initArgs); - }) - .def("create_condition_op", - [](TritonOpBuilder &self, Value &cond, - std::vector &args) -> scf::ConditionOp { - return self.create(cond, args); - }) - - // miscellaneous - .def("create_make_range", - [](TritonOpBuilder &self, int start, int end) -> Value { - auto retType = RankedTensorType::get( - {end - start}, self.getBuilder().getI32Type()); - return self.create(retType, start, end); - }) - - // Cast instructions - // Conversions for custom FP types (FP8 and non-standard rounding modes) - .def("create_fp_to_fp", - [](TritonOpBuilder &self, Value &src, Type &dstType, - std::optional roundingMode) -> Value { - if (roundingMode.has_value()) - return self.create( - dstType, src, - RoundingModeAttr::get(self.getBuilder().getContext(), - roundingMode.value())); - else - return self.create(dstType, src); - }) - // Conversions for standard LLVM builtin types - .def("create_bitcast", - [](TritonOpBuilder &self, Value &src, Type 
&dstType) -> Value { - return self.create(dstType, src); - }) - .def("create_si_to_fp", - [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { - return self.create(dstType, src); - }) - .def("create_ui_to_fp", - [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { - return self.create(dstType, src); - }) - .def("create_fp_to_si", - [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { - return self.create(dstType, src); - }) - .def("create_fp_to_ui", - [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { - return self.create(dstType, src); - }) - .def("create_fp_ext", - [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { - return self.create(dstType, src); - }) - .def("create_fp_trunc", - [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { - return self.create(dstType, src); - }) - .def("create_int_cast", - [](TritonOpBuilder &self, Value &src, Type &dstType, - bool isSigned) -> Value { - // get element type if necessary - Type srcType = src.getType(); - auto srcTensorType = dyn_cast(srcType); - auto dstTensorType = dyn_cast(dstType); - Type srcEltType = srcType; - Type dstEltType = dstType; - if (dstTensorType && srcTensorType) { - dstEltType = dstTensorType.getElementType(); - srcEltType = srcTensorType.getElementType(); - } - unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); - unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); - if (srcWidth == dstWidth) - return self.create(dstType, src); - else if (srcWidth > dstWidth) - return self.create(dstType, src); - else if (isSigned) - return self.create(dstType, src); - else - return self.create(dstType, src); - }) - .def("create_to_index", - [](TritonOpBuilder &self, Value &input) -> Value { - return self.create( - self.getBuilder().getIndexType(), input); - }) - .def("create_index_to_si", - [](TritonOpBuilder &self, Value &input) -> Value { - return self.create( - self.getBuilder().getI64Type(), input); - }) - .def("create_fmul", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_fdiv", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_frem", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_fadd", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_fsub", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_mul", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_umulhi", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_sdiv", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_udiv", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_srem", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_urem", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_add", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_sub", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); 
- }) - .def("create_fma", - [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { - return Value(self.create(a, b, c)); - }) - .def("create_shl", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - .def("create_lshr", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - .def("create_ashr", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - .def("create_minsi", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - .def("create_minui", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - // minimumf follows the torch.minimum convention and returns NaN if either - // operand is NaN - .def("create_minimumf", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - // minnumf follows the torch.fmin convention and returns the non-NaN - // operand - .def("create_minnumf", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - .def("create_maxsi", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - .def("create_maxui", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - // maximumf follows the torch.maximum convention and returns NaN if either - // operand is NaN - .def("create_maximumf", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - // maxnumf follows the torch.fmax convention and returns the non-NaN - // operand - .def("create_maxnumf", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - .def("create_clampf", - [](TritonOpBuilder &self, Value &input, Value &min, Value &max, - PropagateNan propagateNan) -> Value { - return Value(self.create(input, min, max, propagateNan)); - }) - .def("create_precise_sqrt", - [](TritonOpBuilder &self, Value &input) -> Value { - return Value(self.create(input)); - }) - .def("create_precise_divf", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return Value(self.create(lhs, rhs)); - }) - // AddPtr (similar to GEP) - .def("create_addptr", - [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { - return self.create(ptr.getType(), ptr, offset); - }) - // Comparison (int) - .def("create_icmpSLE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::sle, lhs, - rhs); - }) - .def("create_icmpSLT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::slt, lhs, - rhs); - }) - .def("create_icmpSGE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::sge, lhs, - rhs); - }) - .def("create_icmpSGT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::sgt, lhs, - rhs); - }) - .def("create_icmpULE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::ule, lhs, - rhs); - }) - .def("create_icmpULT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::ult, lhs, - rhs); - }) - .def("create_icmpUGE", - 
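
Signedness lives in the op rather than the type, which is why i32 values get
separate signed/unsigned comparison and min/max builders here. Continuing the
sketch:

    cond = builder.create_icmpSLT(arg0, c42)  # arith.cmpi slt -> i1
    lo = builder.create_minsi(arg0, c42)      # arith.minsi
    hi = builder.create_maxsi(arg0, c42)      # arith.maxsi
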
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::uge, lhs, - rhs); - }) - .def("create_icmpUGT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::ugt, lhs, - rhs); - }) - .def("create_icmpEQ", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::eq, lhs, - rhs); - }) - .def("create_icmpNE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpIPredicate::ne, lhs, - rhs); - }) - // Comparison (float) - .def("create_fcmpOLT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::OLT, lhs, - rhs); - }) - .def("create_fcmpOGT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::OGT, lhs, - rhs); - }) - .def("create_fcmpOLE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::OLE, lhs, - rhs); - }) - .def("create_fcmpOGE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::OGE, lhs, - rhs); - }) - .def("create_fcmpOEQ", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::OEQ, lhs, - rhs); - }) - .def("create_fcmpONE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::ONE, lhs, - rhs); - }) - .def("create_fcmpULT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::ULT, lhs, - rhs); - }) - .def("create_fcmpUGT", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::UGT, lhs, - rhs); - }) - .def("create_fcmpULE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::ULE, lhs, - rhs); - }) - .def("create_fcmpUGE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::UGE, lhs, - rhs); - }) - .def("create_fcmpUEQ", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::UEQ, lhs, - rhs); - }) - .def("create_fcmpUNE", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(arith::CmpFPredicate::UNE, lhs, - rhs); - }) - // // Logical - .def("create_and", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_xor", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - .def("create_or", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - return self.create(lhs, rhs); - }) - // Input/Output - .def("create_load", - [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, - EvictionPolicy evictionPolicy, bool isVolatile) -> Value { - return self.create(ptrs, cacheModifier, evictionPolicy, - isVolatile); - }) - .def("create_store", - [](TritonOpBuilder &self, Value &ptrs, Value &value, - CacheModifier cacheModifier, - EvictionPolicy evictionPolicy) -> void { - self.create(ptrs, value, cacheModifier, evictionPolicy); - }) - .def("create_tensor_pointer_load", - [](TritonOpBuilder &self, Value &ptr, - std::vector &boundaryCheck, - std::optional paddingOption, - CacheModifier cacheModifier, EvictionPolicy evictionPolicy, - bool isVolatile) -> Value { - return 
self.create(ptr, boundaryCheck, paddingOption, - cacheModifier, evictionPolicy, - isVolatile); - }) - .def("create_tensor_pointer_store", - [](TritonOpBuilder &self, Value &ptr, Value &val, - std::vector &boundaryCheck, CacheModifier cacheModifier, - EvictionPolicy evictionPolicy) -> void { - self.create(ptr, val, boundaryCheck, cacheModifier, - evictionPolicy); - }) - .def("create_masked_load", - [](TritonOpBuilder &self, Value &ptrs, Value &mask, - std::optional &other, CacheModifier cacheModifier, - EvictionPolicy evictionPolicy, bool isVolatile) -> Value { - return self.create(ptrs, mask, other.value_or(Value()), - cacheModifier, evictionPolicy, - isVolatile); - }) - .def("create_masked_store", - [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, - CacheModifier cacheModifier, - EvictionPolicy evictionPolicy) -> void { - self.create(ptrs, val, mask, cacheModifier, - evictionPolicy); - }) - .def("create_descriptor_load", - [](TritonOpBuilder &self, Value desc_ptr, - std::vector &indices, Type type, - CacheModifier cacheModifier, - EvictionPolicy evictionPolicy) -> Value { - return self.create( - type, desc_ptr, indices, cacheModifier, evictionPolicy); - }) - .def("create_descriptor_store", - [](TritonOpBuilder &self, Value desc_ptr, Value value, - std::vector &indices) -> void { - self.create(desc_ptr, value, - indices); - }) - .def("create_tensormap_create", - [](TritonOpBuilder &self, Value desc_ptr, Value global_address, - std::vector box_dim, std::vector global_dim, - std::vector global_stride, - std::vector element_stride, int32_t elem_type, - int32_t interleave_layout, int32_t swizzle_mode, - int32_t fill_mode) { - self.create( - desc_ptr, global_address, box_dim, global_dim, global_stride, - element_stride, elem_type, interleave_layout, swizzle_mode, - fill_mode); - }) - .def("create_tensormap_fenceproxy_acquire", - [](TritonOpBuilder &self, Value desc_ptr) { - self.create(desc_ptr); - }) - .def("create_reshape", - [](TritonOpBuilder &self, Value &arg, std::vector &shape, - bool allowReorder) -> Value { - auto argType = - cast(arg.getType()).getElementType(); - return self.create( - RankedTensorType::get(shape, argType), arg, allowReorder); - }) - .def("create_expand_dims", - [](TritonOpBuilder &self, Value &arg, int axis) -> Value { - auto argType = dyn_cast(arg.getType()); - auto argEltType = argType.getElementType(); - std::vector retShape = argType.getShape(); - retShape.insert(retShape.begin() + axis, 1); - return self.create( - RankedTensorType::get(retShape, argEltType), arg, axis); - }) - .def("create_cat", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - auto lhsType = dyn_cast(lhs.getType()); - auto rhsType = dyn_cast(rhs.getType()); - if (!(lhsType.getShape().size() == 1 && - rhsType.getShape().size() == 1)) - throw std::invalid_argument( - "shape not supported by cat. 
Expecting rank-1 inputs"); - std::vector shape{lhsType.getShape()[0] + - rhsType.getShape()[0]}; - return self.create( - RankedTensorType::get(shape, lhsType.getElementType()), lhs, - rhs); - }) - .def("create_join", - [](TritonOpBuilder &self, Value &a, Value &b) -> Value { - return self.create(a, b); - }) - .def("create_split", - [](TritonOpBuilder &self, Value &a) -> std::vector { - auto op = self.create(a); - return std::vector(op->result_begin(), op->result_end()); - }) - .def("create_slice", - [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, - std::vector &sizs_vec, std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get( - retSizes, - cast(ful.getType()).getElementType()); - auto ret = self.create(retTy, ful, offsets, - sizes, strides); - return ret; - }) - .def("create_insert", - [](TritonOpBuilder &self, Value &ful, Value &sub, - std::vector &offs_vec, std::vector &sizs_vec, - std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get( - retSizes, - cast(ful.getType()).getElementType()); - auto ret = self.create(sub, ful, offsets, - sizes, strides); - return ret; - }) - // Implements tl.trans and tl.permute. 
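
In create_slice/create_insert above, offsets arrive as SSA values and are
index-cast on the fly, while sizes and strides are Python ints baked into
index constants; that reading is inferred from the conversion code above, so
treat this as a hedged sketch. Assuming full is a rank-1 tensor value of
length 128:

    off = [builder.get_int32(16)]                          # value offset
    window = builder.create_slice(full, off, [32], [1])    # 32-wide view
    full2 = builder.create_insert(full, window, off, [32], [1])
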
- .def("create_trans", - [](TritonOpBuilder &self, Value &arg, - std::vector &order) -> Value { - auto argType = dyn_cast(arg.getType()); - auto argEltType = argType.getElementType(); - auto retShape = applyPermutation(argType.getShape(), order); - return self.create( - RankedTensorType::get(retShape, argEltType), arg, order); - }) - .def("create_broadcast", - [](TritonOpBuilder &self, Value &arg, - std::vector &shape) -> Value { - if (auto argType = dyn_cast(arg.getType())) - return self.createOrFold( - RankedTensorType::get(shape, argType.getElementType()), arg); - throw std::invalid_argument( - "arg is not of RankedTensorType, use create_splat"); - }) - .def("create_splat", - [](TritonOpBuilder &self, Value &arg, - std::vector &shape) -> Value { - auto argType = arg.getType(); - auto ret = self.createOrFold( - RankedTensorType::get(shape, argType), arg); - return ret; - }) - // // atomic - .def("create_atomic_cas", - [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, - MemSemantic sem, MemSyncScope scope) -> Value { - Type dstType; - if (auto srcTensorType = - dyn_cast(ptr.getType())) { - Type dstElemType = - cast(srcTensorType.getElementType()) - .getPointeeType(); - dstType = - RankedTensorType::get(srcTensorType.getShape(), dstElemType); - } else { - auto ptrType = cast(getElementTypeOrSelf(ptr)); - dstType = ptrType.getPointeeType(); - } - return self.create(dstType, ptr, cmp, val, sem, - scope); - }) - .def("create_atomic_rmw", - [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, - Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { - Type dstType; - if (auto srcTensorType = - dyn_cast(ptr.getType())) { - Type dstElemType = - cast(srcTensorType.getElementType()) - .getPointeeType(); - dstType = - RankedTensorType::get(srcTensorType.getShape(), dstElemType); - } else { - auto ptrType = cast(getElementTypeOrSelf(ptr)); - dstType = ptrType.getPointeeType(); - } - return self.create(dstType, rmwOp, ptr, val, mask, - sem, scope); - }) - // External - .def("create_extern_elementwise", - [](TritonOpBuilder &self, const std::string &libName, - const std::string &libPath, const std::string &symbol, - std::vector &argList, Type retType, bool isPure) -> Value { - return self.create(retType, argList, libName, - libPath, symbol, isPure); - }) - // Built-in instruction - .def("create_get_program_id", - [](TritonOpBuilder &self, int axis) -> Value { - if (axis < 0 || axis > 3) - throw pybind11::index_error("program_id must be in [0,3]"); - return self.create(axis); - }) - .def("create_get_num_programs", - [](TritonOpBuilder &self, int axis) -> Value { - if (axis < 0 || axis > 3) - throw pybind11::index_error("program_id must be in [0,3]"); - return self.create(axis); - }) - .def("create_dot", - [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, - mlir::Value &c, InputPrecision inputPrecision, - int maxNumImpreciseAcc) -> mlir::Value { - return self.create(c.getType(), a, b, c, inputPrecision, - maxNumImpreciseAcc); - }) - .def("create_dot_scaled", - [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, - F8F6F4Type lhs_format, mlir::Value &rhs, - std::optional &rhs_scale, F8F6F4Type rhs_format, - mlir::Value &c) -> mlir::Value { - return self.create( - c.getType(), lhs, rhs, c, lhs_scale, - rhs_scale.value_or(Value()), lhs_format, rhs_format); - }) - .def("create_floor", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_ceil", - [](TritonOpBuilder &self, Value &val) -> Value { - return 
self.create(val); - }) - .def("create_exp", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_exp2", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_cos", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_sin", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_log", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_log2", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_erf", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_tanh", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_sqrt", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_rsqrt", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_fabs", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_iabs", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - .def("create_reduce", - [](TritonOpBuilder &self, std::vector operands, int axis) - -> OpState { return self.create(operands, axis); }) - .def("create_reduce_ret", - [](TritonOpBuilder &self, py::args args) -> OpState { - llvm::SmallVector return_values; - for (const auto &arg : args) { - return_values.push_back(py::cast(arg)); - } - return self.create(return_values); - }) - .def("create_scan", - [](TritonOpBuilder &self, std::vector operands, int axis, - bool reverse) -> OpState { - return self.create(operands, axis, reverse); - }) - .def("create_scan_ret", - [](TritonOpBuilder &self, py::args args) -> OpState { - llvm::SmallVector return_values; - for (const auto &arg : args) { - return_values.push_back(py::cast(arg)); - } - return self.create(return_values); - }) - .def("create_ptr_to_int", - [](TritonOpBuilder &self, Value &val, Type &type) -> Value { - return self.create(type, val); - }) - .def("create_int_to_ptr", - [](TritonOpBuilder &self, Value &val, Type &type) -> Value { - return self.create(type, val); - }) - .def("create_select", - [](TritonOpBuilder &self, Value &condition, Value &trueValue, - Value &falseValue) -> Value { - return self.create(condition, trueValue, - falseValue); - }) - .def("create_inline_asm", - [](TritonOpBuilder &self, const std::string &inlineAsm, - const std::string &constraints, const std::vector &values, - const std::vector &types, bool isPure, - int pack) -> OpState { - return self.create( - types, inlineAsm, constraints, isPure, pack, values); - }) - .def("create_print", - [](TritonOpBuilder &self, const std::string &prefix, bool hex, - const std::vector &values, - const std::vector &isSigned) -> void { - auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), - llvm::StringRef(prefix)); - self.create(prefixAttr, hex, values, isSigned); - }) - .def("create_assert", - [](TritonOpBuilder &self, Value &condition, - const std::string &message) -> void { - auto messageAttr = StringAttr::get(self.getBuilder().getContext(), - llvm::StringRef(message)); - self.create(condition, messageAttr); - }) - .def("create_assume", - [](TritonOpBuilder &self, Value &condition) { - self.create(condition); - }) - .def("create_poison", - [](TritonOpBuilder &self, Type &type) -> Value { - return 
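
create_reduce only creates the op; the combine region is stitched in by hand
with create_block_with_parent and create_reduce_ret. A hedged continuation
that sums the i32 offsets along axis 0 (a real caller would save and restore
the insertion point around this, as the compiler does):

    red = builder.create_reduce([offs], 0)
    region = red.get_region(0)
    blk_args = [builder.get_int32_ty(), builder.get_int32_ty()]
    body = builder.create_block_with_parent(region, blk_args)
    builder.set_insertion_point_to_start(body)
    acc = builder.create_add(body.arg(0), body.arg(1))
    builder.create_reduce_ret(acc)
    total = red.get_result(0)
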
self.create(type); - }) - .def("create_histogram", - [](TritonOpBuilder &self, Value operand, int numBins) -> Value { - return self.create( - RankedTensorType::get( - {static_cast(numBins)}, - IntegerType::get(operand.getContext(), 32)), - operand); - }) - .def("create_gather", - [](TritonOpBuilder &self, Value src, Value indices, int axis) - -> Value { return self.create(src, indices, axis); }) - // Force GPU barrier - .def("create_barrier", - [](TritonOpBuilder &self) { self.create(); }) - // Make a block pointer (tensor pointer in Triton IR) - .def("create_make_block_ptr", - [](TritonOpBuilder &self, Value &base, std::vector &shape, - std::vector &strides, std::vector &offsets, - std::vector &tensorShape, - std::vector &order) -> Value { - return self.create(base, shape, strides, offsets, - tensorShape, order); - }) - // Advance a block pointer - .def("create_advance", - [](TritonOpBuilder &self, Value &ptr, - std::vector &offsets) -> Value { - return self.create(ptr.getType(), ptr, offsets); - }); - - py::class_(m, "pass_manager", py::module_local()) - .def(py::init()) - .def("enable_debug", - [](PassManager &self) { - auto *context = self.getContext(); - bool haveDiagnostics = - ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); - bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); - std::string funcToDump; - if (!haveDump) { - funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP"); - if (!funcToDump.empty()) - haveDump = true; - } - if (haveDiagnostics || haveDump) { - context->disableMultithreading(); - } - if (haveDiagnostics) { - context->printOpOnDiagnostic(true); - context->printStackTraceOnDiagnostic(true); - context->getDiagEngine().registerHandler([](Diagnostic &diag) { - llvm::outs() << diag << "\n"; - return success(); - }); - } - if (haveDump) { - auto printingFlags = OpPrintingFlags(); - printingFlags.elideLargeElementsAttrs(16); - printingFlags.enableDebugInfo(); - auto printAlways = [funcToDump](Pass *, Operation *op) -> bool { - if (funcToDump.empty()) - return true; - if (auto mod = dyn_cast(op)) { - return mod.lookupSymbol(funcToDump); - } - if (auto func = dyn_cast(op)) { - return SymbolTable::getSymbolName(func).getValue() == - funcToDump; - } - - return false; - }; - self.enableIRPrinting( - /*shouldPrintBeforePass=*/printAlways, - /*shouldPrintAfterPass=*/printAlways, - /*printModuleScope=*/true, - /*printAfterOnlyOnChange=*/false, - /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), - printingFlags); - } - }) - .def("run", [](PassManager &self, ModuleOp &mod) { - // TODO: maybe dump module to file and print error for better - // diagnostics - - auto reproducerPath = - triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); - if (!reproducerPath.empty()) { - auto anchorName = self.getOpAnchorName(); - auto passes = self.getPasses(); - Operation *op = mod.getOperation(); - makeReproducer(anchorName, passes, op, reproducerPath); - } - - if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { - ::llvm::DebugFlag = true; - } - - if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); - !debugOnly.empty()) { - llvm::SmallVector split; - llvm::SmallVector storage; - llvm::SmallVector debugTypes; - - StringRef(debugOnly.c_str()).split(split, ','); - llvm::transform(split, std::back_inserter(debugTypes), - [&storage](StringRef str) { - // StringRefs are not always null-terminated. - // The purpose for this storage pattern is to - // produce a collection of C-strings that are. 
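
The pass_manager bound above is the driver the Python side uses; enable_debug
wires in the MLIR_ENABLE_DIAGNOSTICS / MLIR_ENABLE_DUMP hooks shown above,
and run (its body continues below) raises a RuntimeError when a pass fails.
Continuing the sketch with an empty pipeline:

    pm = ir.pass_manager(ctx)
    pm.enable_debug()
    pm.run(module)
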
- storage.push_back(str.str()); - return storage.back().c_str(); - }); - - ::llvm::DebugFlag = true; - using namespace llvm; - setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); - } - - bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); - if (haveTiming) { - self.enableTiming(); - } - - if (failed(self.run(mod.getOperation()))) - throw std::runtime_error("PassManager::run failed"); - }); -} - -void init_triton_env_vars(py::module &m) { - m.def("get_cache_invalidating_env_vars", - []() -> std::map { - std::map ret; - for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { - auto strVal = triton::tools::getStrEnv(envVar); - if (strVal.empty()) - continue; - auto boolV = triton::tools::isEnvValueBool(strVal); - if (boolV.has_value()) - ret[envVar] = boolV.value() ? "true" : "false"; - else - ret[envVar] = strVal; - } - return ret; - }); -} diff --git a/third_party/ascend/triton_patch/python/triton_patch/__init__.py b/third_party/ascend/triton_patch/python/triton_patch/__init__.py deleted file mode 100644 index b2a5aa214..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# import triton -# from .compiler.errors import MLIRCompilationError as ascend_MLIRCompilationError -# triton.compiler.errors.MLIRCompilationError = ascend_MLIRCompilationError -# from .language._utils import validate_block_shape as ascend_validate_block_shape -# triton.language._utils.validate_block_shape = ascend_validate_block_shape diff --git a/third_party/ascend/triton_patch/python/triton_patch/compiler/code_generator.py b/third_party/ascend/triton_patch/python/triton_patch/compiler/code_generator.py deleted file mode 100644 index d9cb07d64..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/compiler/code_generator.py +++ /dev/null @@ -1,1303 +0,0 @@ -import ast -import inspect -import re -import sys -import warnings -import os -import textwrap -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union -from triton import language -from triton._C.libtriton import ir -from triton.language import constexpr, tensor, str_to_ty -from triton.language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value -from triton.runtime.jit import _normalize_ty, get_jit_fn_file_line -# ideally we wouldn't need any runtime component -from triton.runtime import JITFunction -from triton.compiler.errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) -from types import ModuleType - - -def mangle_ty(ty): - if ty.is_ptr(): - return 'P' + mangle_ty(ty.element_ty) - if ty.is_int(): - SIGNED = language.dtype.SIGNEDNESS.SIGNED - prefix = 'i' if ty.int_signedness == SIGNED else 'u' - return prefix + str(ty.int_bitwidth) - if ty.is_floating(): - return str(ty) - if ty.is_block(): - elt = mangle_ty(ty.scalar) - shape = '_'.join(map(str, ty.shape)) - return f'{elt}S{shape}S' - if ty.is_void(): - return 'V' - raise TypeError(f'Unsupported type {ty}') - - -def mangle_fn(name, arg_tys, constants): - # doesn't mangle ret type, which must be a function of arg tys - mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) - mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) - mangled_constants = mangled_constants.replace('.', '_d_') - mangled_constants = mangled_constants.replace("'", '_sq_') - # [ and ] are not allowed in LLVM identifiers - mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') - ret = 
f'{name}__{mangled_arg_names}__{mangled_constants}' - return ret - - -def _is_triton_value(o: Any) -> bool: - return isinstance(o, _value) - - -def _is_triton_tensor(o: Any) -> bool: - return isinstance(o, tensor) - - -def _is_constexpr(o: Any) -> bool: - return isinstance(o, constexpr) - - -def _is_triton_scalar(o: Any) -> bool: - return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) - - -def _is_list_like(o: Any) -> bool: - return isinstance(o, (list, tuple)) - - -def _check_fn_args(node, fn, args): - if fn.noinline: - for idx, arg in enumerate(args): - if not _is_constexpr(arg) and not _is_triton_scalar(arg): - raise UnsupportedLanguageConstruct( - fn.src, node, - f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' - ) - - -_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels - - -class enter_sub_region: - - def __init__(self, generator): - self.generator = generator - - def __enter__(self): - # record lscope & local_defs in the parent scope - self.liveins = self.generator.lscope.copy() - self.prev_defs = self.generator.local_defs.copy() - self.generator.local_defs = {} - self.insert_block = self.generator.builder.get_insertion_block() - self.insert_point = self.generator.builder.get_insertion_point() - return self.liveins, self.insert_block - - def __exit__(self, *args, **kwargs): - self.generator.builder.restore_insertion_point(self.insert_point) - self.generator.lscope = self.liveins - self.generator.local_defs = self.prev_defs - - -# Check if the given syntax node has an "early" return -class ContainsReturnChecker(ast.NodeVisitor): - - def __init__(self, gscope): - self.gscope = gscope - - def _visit_stmts(self, body) -> bool: - return any(self.visit(s) for s in body) - - def _visit_function(self, fn) -> bool: - # Currently we only support JITFunctions defined in the global scope - if isinstance(fn, JITFunction) and not fn.noinline: - fn_node = fn.parse() - return ContainsReturnChecker(self.gscope).visit(fn_node) - return False - - def generic_visit(self, node) -> bool: - ret = False - for _, value in ast.iter_fields(node): - if isinstance(value, list): - for item in value: - if isinstance(item, ast.AST): - ret = ret or self.visit(item) - elif isinstance(value, ast.AST): - ret = ret or self.visit(value) - return ret - - def visit_Attribute(self, node: ast.Attribute) -> bool: - # If the left part is a name, it's possible that - # we call triton native function or a jit function from another module. - # If the left part is not a name, it must return a tensor or a constexpr - # whose methods do not contain return statements - # e.g., (tl.load(x)).to(y) - # So we only check if the expressions within value have return or not - if isinstance(node.value, ast.Name): - if node.value.id in self.gscope: - value = self.gscope[node.value.id] - fn = getattr(value, node.attr) - return self._visit_function(fn) - return False - return self.visit(node.value) - - def visit_Name(self, node: ast.Name) -> bool: - if type(node.ctx) is ast.Store: - return False - if node.id in self.gscope: - fn = self.gscope[node.id] - return self._visit_function(fn) - return False - - def visit_Return(self, node: ast.Return) -> bool: - return True - - def visit_Assign(self, node: ast.Assign) -> bool: - # There couldn't be an early return - # x = ... - return False - - def visit_AugAssign(self, node: ast.AugAssign) -> bool: - # There couldn't be an early return - # x += ... 
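
The remaining ContainsReturnChecker visitors continue below; once the class
is complete, checking a kernel body is a single visit over its parsed AST. A
self-contained, hypothetical usage sketch:

    import ast
    import textwrap

    src = textwrap.dedent("""
        def f(x):
            if x:
                return 1
            return 0
    """)
    # the any() over statement visits finds the early return inside the `if`
    assert ContainsReturnChecker(gscope={}).visit(ast.parse(src))
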
- return False - - def visit_Module(self, node: ast.Module) -> bool: - return self._visit_stmts(node.body) - - def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: - return self._visit_stmts(node.body) - - def visit_If(self, node: ast.If) -> bool: - # TODO: optimize the following case in which we actually don't have - # a return when static_cond is false: - # if dynamic_cond - # if static_cond - # func_with_return - # else - # func_without_return - ret = self._visit_stmts(node.body) - if node.orelse: - ret = ret or self._visit_stmts(node.orelse) - return ret - - def visit_IfExp(self, node: ast.IfExp) -> bool: - return self.visit(node.body) or self.visit(node.orelse) - - def visit_Call(self, node: ast.Call) -> bool: - return self.visit(node.func) - - -class CodeGenerator(ast.NodeVisitor): - - def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, - codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, - noinline=False, file_name: Optional[str] = None, begin_line=0): - self.context = context - self.builder = ir.builder(context) - self.file_name = file_name - # node.lineno starts from 1, so we need to subtract 1 - self.begin_line = begin_line - 1 - self.builder.set_loc(file_name, begin_line, 0) - self.builder.options = options - # dict of functions provided by the backend. Below are the list of possible functions: - # Convert custom types not natively supported on HW. - # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) - self.builder.codegen_fns = codegen_fns - self.builder.module_map = {} if module_map is None else module_map - self.module = self.builder.create_module() if module is None else module - self.function_ret_types = {} if function_types is None else function_types - self.prototype = prototype - - self.gscope = {} - for k, v in gscope.items(): - if isinstance(v, ModuleType): - self.gscope[k] = module_map.get(v.__name__, v) - continue - - module_name = getattr(v, "__module__", "") - if module_name in module_map: - self.gscope[k] = getattr(module_map[module_name], v.__name__) - else: - self.gscope[k] = v - - self.lscope = {} - self.attributes = attributes - self.constants = constants - self.jit_fn = jit_fn - self.function_name = function_name - self.is_kernel = is_kernel - self.cur_node = None - self.noinline = noinline - self.scf_stack = [] - self.ret_type = None - # SSA-construction - # name => language.tensor - self.local_defs: Dict[str, tensor] = {} - self.dereference_name: Callable[[str], Any] = self._define_name_lookup() - self.fn = None - # Are we currently visiting an ast.arg's default value? These have some - # special handling. 
- self.visiting_arg_default_value = False - - builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} - builtin_namespace.update(( - ('print', language.core.device_print), - ('min', language.minimum), - ('max', language.maximum), - )) - - def _unsupported(self, node, message): - return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) - - def _is_constexpr_global(self, name): - absent_marker = object() - val = self.gscope.get(name, absent_marker) - if val is absent_marker: - return False - - if _is_constexpr(val): - return True - - if a := self.gscope.get("__annotations__", {}).get(name): - return _normalize_ty(a) == "constexpr" - - return False - - def _define_name_lookup(self): - - def local_lookup(name: str, absent): - # this needs to be re-fetched from `self` every time, because it gets switched occasionally - return self.lscope.get(name, absent) - - def global_lookup(name: str, absent): - val = self.gscope.get(name, absent) - # The high-level rule is that only constexpr globals are allowed. - # But actually a bunch of other things, such as module imports, are - # technically Python globals. We have to allow these too! - if any([ - val is absent, name in self.builtin_namespace, # - type(val) is ModuleType, # - isinstance(val, JITFunction), # - getattr(val, "__triton_builtin__", False), # - getattr(val, "__module__", "").startswith("triton.language"), # - isinstance(val, language.dtype), # - self._is_constexpr_global(name), # - # Allow accesses to globals while visiting an ast.arg - # because you should be able to do - # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... - self.visiting_arg_default_value, # - os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1" - ]): - return val - raise NameError( - textwrap.dedent(f"""\ - Cannot access global variable {name} from within @jit'ed - function. Triton kernels can only access global variables that - are annotated as constexpr (`x: triton.language.constexpr = 42` - or `x = triton.language.constexpr(42)`). Alternatively, set the - envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not - promise to support this forever.""").replace("\n", " ")) - - absent_marker = object() - - def name_lookup(name: str) -> Any: - absent = absent_marker - for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: - value = lookup_function(name, absent) - if value is not absent: - return value - raise NameError(f'{name} is not defined') - - return name_lookup - - def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: - ''' This function: - called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) - 1. record local defined name (FIXME: should consider control flow) - 2. store tensor in self.lvalue - ''' - self.lscope[name] = value - self.local_defs[name] = value - - def _get_insertion_point_and_loc(self): - # XXX: this is a hack to get the location of the insertion point. 
- # The insertion point's location could be invalid sometimes, - # so we need to explicitly set the location - loc = self.builder.get_loc() - ip = self.builder.get_insertion_point() - return ip, loc - - def _set_insertion_point_and_loc(self, ip, loc): - self.builder.restore_insertion_point(ip) - self.builder.set_loc(loc) - - # - # AST visitor - # - def visit_compound_statement(self, stmts): - # Ensure that stmts is iterable - if not _is_list_like(stmts): - stmts = [stmts] - for stmt in stmts: - self.visit(stmt) - - # Stop parsing as soon as we hit a `return` statement; everything - # after this is dead code. - if isinstance(stmt, ast.Return): - break - - def visit_Module(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_List(self, node): - ctx = self.visit(node.ctx) - assert ctx is None - elts = [self.visit(elt) for elt in node.elts] - return elts - - # By design, only non-kernel functions can return - def visit_Return(self, node): - ret_value = self.visit(node.value) - if ret_value is None: - self.builder.ret([]) - ret_ty = language.void - elif isinstance(ret_value, tuple): - ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] - ret_types = [v.type for v in ret_values] - self.builder.ret([v.handle for v in ret_values]) - ret_ty = tuple(ret_types) - else: - ret = language.semantic.to_tensor(ret_value, self.builder) - self.builder.ret([ret.handle]) - ret_ty = ret.type - - if self.ret_type is None: - self.ret_type = ret_ty - elif self.ret_type != ret_ty: - raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') - - # A return op must always terminate the basic block, so we create a dead - # basic block in case there are any ops after the return. - post_ret_block = self.builder.create_block() - self.builder.set_insertion_point_to_end(post_ret_block) - - def visit_FunctionDef(self, node): - arg_names, kwarg_names = self.visit(node.args) - if self.fn: - raise self._unsupported(node, "nested function definition is not supported.") - # initialize defaults - for i, default_value in enumerate(node.args.defaults[::-1]): - arg_node = node.args.args[-i - 1] - annotation = arg_node.annotation - name = arg_node.arg - st_target = ast.Name(id=name, ctx=ast.Store()) - if annotation is None: - init_node = ast.Assign(targets=[st_target], value=default_value) - else: - init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) - - try: - assert not self.visiting_arg_default_value - self.visiting_arg_default_value = True - self.visit(init_node) - finally: - self.visiting_arg_default_value = False - - # initialize function - visibility = "public" if self.is_kernel else "private" - self.fn = self.builder.get_or_insert_function(self.module, self.function_name, - self.prototype.to_ir(self.builder), visibility, self.noinline) - self.module.push_back(self.fn) - entry = self.fn.add_entry_block() - arg_values = [] - idx = 0 - for i in range(len(arg_names)): - if i in self.constants: - cst = self.constants[i] - if not _is_constexpr(cst): - cst = constexpr(self.constants[i]) - arg_values.append(cst) - continue - else: - if i in self.attributes: - for name, value in self.attributes[i]: - self.fn.set_arg_attr(idx, name, value) - - # Mark this argument as a pass-by-value TMA descriptor (nvidia) - if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): - self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) - - arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) - idx += 1 - - insert_pt = 
self.builder.get_insertion_block() - for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) - self.builder.set_insertion_point_to_start(entry) - # visit function body - self.visit_compound_statement(node.body) - - # finalize function - assert not self.builder.get_insertion_block().has_terminator() - if self.ret_type is None or self.ret_type == language.void: - self.ret_type = language.void - self.builder.ret([]) - else: - self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] - self.fn.reset_type(self.prototype.to_ir(self.builder)) - self.builder.ret([ - self.builder.create_poison(ty.to_ir(self.builder)) - for ty in self.prototype.ret_types - if self.ret_type is not None - ]) - self.fn.finalize() - - if insert_pt: - self.builder.set_insertion_point_to_end(insert_pt) - - def visit_arguments(self, node): - arg_names = [] - for arg in node.args: - arg_names += [self.visit(arg)] - kwarg_names = self.visit(node.kwarg) - return arg_names, kwarg_names - - def visit_arg(self, node): - ast.NodeVisitor.generic_visit(self, node) - return node.arg - - def visit_AnnAssign(self, node): - # extract attributes - annotation = self.visit(node.annotation) - target = self.visit(node.target) - value = self.visit(node.value) - # constexpr - if annotation == constexpr: - if target in self.lscope: - raise ValueError(f'{target} is already defined.' - f' constexpr cannot be reassigned.') - if not _is_constexpr(value): - value = constexpr(value) - self.lscope[target] = value - return self.lscope[target] - # default: call visit_Assign - return self.visit_Assign(node) - - def visit_Assign(self, node): - _names = [] - if isinstance(node, ast.AnnAssign): - _names += [self.visit(node.target)] - else: - for target in node.targets: - _names += [self.visit(target)] - if len(_names) > 1: - raise self._unsupported(node, "simultaneous multiple assignment is not supported.") - names = _names[0] - values = self.visit(node.value) - if not _is_list_like(names): - names = [names] - if not _is_list_like(values): - values = [values] - native_nontensor_types = (language.dtype, ) - for name, value in zip(names, values): - # by default, constexpr are assigned into python variable - value = _unwrap_if_constexpr(value) - if value is not None and \ - not _is_triton_value(value) and \ - not isinstance(value, native_nontensor_types): - value = language.semantic.to_tensor(value, self.builder) - self.set_value(name, value) - - def visit_AugAssign(self, node): - name = node.target.id - lhs = ast.Name(id=name, ctx=ast.Load()) - rhs = ast.BinOp(lhs, node.op, node.value) - assign = ast.Assign(targets=[node.target], value=rhs) - self.visit(assign) - return self.dereference_name(name) - - def visit_Name(self, node): - if type(node.ctx) is ast.Store: - return node.id - return self.dereference_name(node.id) - - def visit_Store(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_Load(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_Tuple(self, node): - args = [self.visit(x) for x in node.elts] - return tuple(args) - - def _apply_binary_method(self, method_name, lhs, rhs): - # TODO: raise something meaningful if getattr fails below, esp for reverse method - if _is_triton_tensor(lhs): - return getattr(lhs, method_name)(rhs, _builder=self.builder) - if _is_triton_tensor(rhs): - reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) - return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) - return 
getattr(lhs, method_name)(rhs) - - def visit_BinOp(self, node): - lhs = self.visit(node.left) - rhs = self.visit(node.right) - method_name = self._method_name_for_bin_op.get(type(node.op)) - if method_name is None: - raise self._unsupported(node, - "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) - return self._apply_binary_method(method_name, lhs, rhs) - - _method_name_for_bin_op: Dict[Type[ast.operator], str] = { - ast.Add: '__add__', - ast.Sub: '__sub__', - ast.Mult: '__mul__', - ast.Div: '__truediv__', - ast.FloorDiv: '__floordiv__', - ast.Mod: '__mod__', - ast.Pow: '__pow__', - ast.LShift: '__lshift__', - ast.RShift: '__rshift__', - ast.BitAnd: '__and__', - ast.BitOr: '__or__', - ast.BitXor: '__xor__', - } - - def visit_then_else_blocks(self, node, liveins, then_block, else_block): - # then block - self.builder.set_insertion_point_to_start(then_block) - self.visit_compound_statement(node.body) - then_block = self.builder.get_insertion_block() - then_defs = self.local_defs.copy() - # else block - else_defs = {} - if node.orelse: - self.builder.set_insertion_point_to_start(else_block) - self.lscope = liveins.copy() - self.local_defs = {} - self.visit_compound_statement(node.orelse) - else_defs = self.local_defs.copy() - else_block = self.builder.get_insertion_block() - - # update block arguments - names = [] - ret_types = [] - ir_ret_types = [] - # variables in livein whose value is updated in `if` - for name in liveins: - # check type - for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: - if name in defs: - assert defs[name].type == liveins[name].type, \ - f'initial value for `{name}` is of type {liveins[name].type}, '\ - f'but the {block_name} block redefines it as {defs[name].type}' - if name in then_defs or name in else_defs: - names.append(name) - ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) - ir_ret_types.append(then_defs[name].handle.get_type() if name in - then_defs else else_defs[name].handle.get_type()) - # variable defined in then but not in else - if name in then_defs and name not in else_defs: - else_defs[name] = liveins[name] - # variable defined in else but not in then - if name in else_defs and name not in then_defs: - then_defs[name] = liveins[name] - # variables that are both in then and else but not in liveins - # TODO: could probably be cleaned up - for name in sorted(then_defs.keys() & else_defs.keys()): - if name in names: - continue - then_ty = then_defs[name].type - else_ty = else_defs[name].type - assert then_ty == else_ty, \ - f'Mismatched type for {name} between then block ({then_ty}) '\ - f'and else block ({else_ty})' - names.append(name) - ret_types.append(then_ty) - ir_ret_types.append(then_defs[name].handle.get_type()) - - return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types - - def visit_if_top_level(self, cond, node): - with enter_sub_region(self) as sr: - liveins, ip_block = sr - then_block = self.builder.create_block() - else_block = self.builder.create_block() - # create branch - self.builder.set_insertion_point_to_end(ip_block) - self.builder.create_cond_branch(cond.handle, then_block, else_block) - # visit then and else blocks - then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ - self.visit_then_else_blocks(node, liveins, then_block, else_block) - # create basic-block after conditional - endif_block = self.builder.create_block() - # then terminator - 
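
The type agreement asserted in visit_then_else_blocks above is the
user-visible contract of `if` in Triton source: a live-in rebound in either
branch must keep its type. A hypothetical kernel fragment that would trip
the assertion:

    # x = tl.zeros([128], dtype=tl.float32)      # live-in: float32
    # if cond:
    #     x = x + 1.0                            # ok, still float32
    # else:
    #     x = tl.zeros([128], dtype=tl.int32)    # type changes: rejected
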
self.builder.set_insertion_point_to_end(then_block) - assert not then_block.has_terminator(), f"{then_block}" - self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) - # else terminator - self.builder.set_insertion_point_to_end(else_block) - assert not else_block.has_terminator(), f"{else_block}" - self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) - for ty in ir_ret_types: - endif_block.add_argument(ty) - - # change block - self.builder.set_insertion_point_to_start(endif_block) - # update value - for i, name in enumerate(names): - new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) - self.set_value(name, new_tensor) - - # TODO: refactor - def visit_if_scf(self, cond, node): - with enter_sub_region(self) as sr: - liveins, _ = sr - ip, last_loc = self._get_insertion_point_and_loc() - then_block = self.builder.create_block() - else_block = self.builder.create_block() if node.orelse else None - then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ - self.visit_then_else_blocks(node, liveins, then_block, else_block) - # create if op - self._set_insertion_point_and_loc(ip, last_loc) - if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) - then_block.merge_block_before(if_op.get_then_block()) - self.builder.set_insertion_point_to_end(if_op.get_then_block()) - if len(names) > 0: - self.builder.create_yield_op([then_defs[n].handle for n in names]) - if not node.orelse: - else_block = if_op.get_else_block() - else: - else_block.merge_block_before(if_op.get_else_block()) - self.builder.set_insertion_point_to_end(if_op.get_else_block()) - if len(names) > 0: - self.builder.create_yield_op([else_defs[n].handle for n in names]) - # update values - for i, name in enumerate(names): - new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) - self.set_value(name, new_tensor) - - def visit_If(self, node): - cond = self.visit(node.test) - - if _is_triton_tensor(cond): - cond = cond.to(language.int1, _builder=self.builder) - contains_return = ContainsReturnChecker(self.gscope).visit(node) - if contains_return: - if self.scf_stack: - raise self._unsupported( - node, "Cannot have `return` statements inside `while` or `for` statements in triton " - "(note that this also applies to `return` statements that are inside functions " - "transitively called from within `while`/`for` statements)") - self.visit_if_top_level(cond, node) - else: - self.visit_if_scf(cond, node) - else: - cond = _unwrap_if_constexpr(cond) - # not isinstance - we insist the real thing, no subclasses and no ducks - if type(cond) not in _condition_types: - raise self._unsupported( - node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( - ', '.join(_.__name__ for _ in _condition_types), - type(cond).__name__)) - - active_block = node.body if cond else node.orelse - self.visit_compound_statement(active_block) - - def visit_IfExp(self, node): - cond = self.visit(node.test) - if _is_triton_tensor(cond): - cond = cond.to(language.int1, _builder=self.builder) - # TODO: Deal w/ more complicated return types (e.g tuple) - with enter_sub_region(self): - ip, last_loc = self._get_insertion_point_and_loc() - - then_block = self.builder.create_block() - self.builder.set_insertion_point_to_start(then_block) - then_val = language.semantic.to_tensor(self.visit(node.body), self.builder) - then_block = self.builder.get_insertion_block() - - else_block = self.builder.create_block() - 
self.builder.set_insertion_point_to_start(else_block) - # do not need to reset lscope since - # ternary expressions cannot define new variables - else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder) - else_block = self.builder.get_insertion_block() - - self._set_insertion_point_and_loc(ip, last_loc) - - assert then_val.type == else_val.type, \ - f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' - ret_type = then_val.type - - ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] - if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) - then_block.merge_block_before(if_op.get_then_block()) - if ret_type_ir: - self.builder.set_insertion_point_to_end(if_op.get_then_block()) - self.builder.create_yield_op([then_val.handle]) - - self.builder.set_insertion_point_to_end(if_op.get_then_block()) - else_block.merge_block_before(if_op.get_else_block()) - if ret_type_ir: - self.builder.set_insertion_point_to_end(if_op.get_else_block()) - self.builder.create_yield_op([else_val.handle]) - return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None - else: - cond = _unwrap_if_constexpr(cond) - - # not isinstance - we insist the real thing, no subclasses and no ducks - if type(cond) not in _condition_types: - raise self._unsupported( - node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( - ', '.join(_.__name__ for _ in _condition_types), - type(cond).__name__)) - if cond: - return self.visit(node.body) - else: - return self.visit(node.orelse) - - def visit_Pass(self, node): - pass - - def visit_Compare(self, node): - if not (len(node.comparators) == 1 and len(node.ops) == 1): - raise self._unsupported(node, "simultaneous multiple comparison is not supported") - lhs = self.visit(node.left) - rhs = self.visit(node.comparators[0]) - lhs_value = _unwrap_if_constexpr(lhs) - rhs_value = _unwrap_if_constexpr(rhs) - if type(node.ops[0]) is ast.Is: - return constexpr(lhs_value is rhs_value) - if type(node.ops[0]) is ast.IsNot: - return constexpr(lhs_value is not rhs_value) - method_name = self._method_name_for_comp_op.get(type(node.ops[0])) - if method_name is None: - raise self._unsupported( - node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) - return self._apply_binary_method(method_name, lhs, rhs) - - _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { - ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' - } - - def visit_UnaryOp(self, node): - operand = self.visit(node.operand) - fn = self._method_name_for_unary_op.get(type(node.op)) - if fn is None: - raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") - if _is_triton_tensor(operand): - return getattr(operand, fn)(_builder=self.builder) - try: - return getattr(operand, fn)() - except AttributeError: - raise self._unsupported( - node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") - - _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { - ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' - } - - def _verify_loop_carried_variable(self, name, loop_val, live_val): - assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop' - assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the
loop' - assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type' - assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ - f'Loop-carried variable {name} has initial type {live_val.type} '\ - f'but is re-assigned to {loop_val.type} in loop! '\ - f'Please make sure that the type stays consistent.' - - def visit_While(self, node): - with enter_sub_region(self) as sr: - liveins, insert_block = sr - ip, last_loc = self._get_insertion_point_and_loc() - - # loop body (the after region) - # loop_block = self.builder.create_block() - dummy = self.builder.create_block() - self.builder.set_insertion_point_to_start(dummy) - self.scf_stack.append(node) - self.visit_compound_statement(node.body) - self.scf_stack.pop() - loop_defs = self.local_defs - dummy.erase() - - # collect loop-carried values - names = [] - ret_types = [] - init_args = [] - for name in loop_defs: - if name in liveins: - # We should not def new constexpr - loop_val = loop_defs[name] - live_val = liveins[name] - self._verify_loop_carried_variable(name, loop_val, live_val) - - # these are loop-carried values - names.append(name) - ret_types.append(loop_val.type) - init_args.append(live_val) - - self._set_insertion_point_and_loc(ip, last_loc) - while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], - [arg.handle for arg in init_args]) - # merge the condition region - before_block = self.builder.create_block_with_parent(while_op.get_before(), - [ty.to_ir(self.builder) for ty in ret_types]) - self.builder.set_insertion_point_to_start(before_block) - for i, name in enumerate(names): - self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) - self.local_defs[name] = self.lscope[name] - cond = self.visit(node.test) - self.builder.set_insertion_point_to_end(before_block) - # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... 
- self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) - # merge the loop body - after_block = self.builder.create_block_with_parent(while_op.get_after(), - [ty.to_ir(self.builder) for ty in ret_types]) - - # generate loop body - self.builder.set_insertion_point_to_start(after_block) - for i, name in enumerate(names): - self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) - self.local_defs[name] = self.lscope[name] - self.scf_stack.append(node) - self.visit_compound_statement(node.body) - self.scf_stack.pop() - loop_defs = self.local_defs - yields = [] - for name in loop_defs: - if name in liveins: - yields.append(loop_defs[name]) - self.builder.create_yield_op([y.handle for y in yields]) - - # WhileOp defines new values, update the symbol table (lscope, local_defs) - for i, name in enumerate(names): - new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) - self.lscope[name] = new_def - self.local_defs[name] = new_def - - for stmt in node.orelse: - assert False, "Not implemented" - ast.NodeVisitor.generic_visit(self, stmt) - - def visit_Subscript(self, node): - assert node.ctx.__class__.__name__ == "Load" - lhs = self.visit(node.value) - slices = self.visit(node.slice) - if _is_triton_tensor(lhs): - return lhs.__getitem__(slices, _builder=self.builder) - return lhs[slices] - - def visit_ExtSlice(self, node): - return [self.visit(dim) for dim in node.dims] - - def visit_For(self, node): - IteratorClass = self.visit(node.iter.func) - iter_args = [self.visit(arg) for arg in node.iter.args] - iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) - if IteratorClass == language.static_range: - iterator = IteratorClass(*iter_args, **iter_kwargs) - static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) - for i in static_range: - self.lscope[node.target.id] = constexpr(i) - self.visit_compound_statement(node.body) - for stmt in node.orelse: - ast.NodeVisitor.generic_visit(self, stmt) - return - num_stages = None - loop_unroll_factor = None - if IteratorClass is language.range: - iterator = IteratorClass(*iter_args, **iter_kwargs) - # visit iterator arguments - # note: only `range` iterator is supported now - # collect lower bound (lb), upper bound (ub), and step - lb = iterator.start - ub = iterator.end - step = iterator.step - num_stages = iterator.num_stages - loop_unroll_factor = iterator.loop_unroll_factor - elif IteratorClass is range: - # visit iterator arguments - # note: only `range` iterator is supported now - # collect lower bound (lb), upper bound (ub), and step - lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) - ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) - step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) - else: - raise RuntimeError('Only `range` and `static_range` iterators are currently supported') - # handle negative constant step (not supported by scf.for in MLIR) - negative_step = False - if _is_constexpr(step) and step.value < 0: - step = constexpr(-step.value) - negative_step = True - lb, ub = ub, lb - lb = language.semantic.to_tensor(lb, self.builder) - ub = language.semantic.to_tensor(ub, self.builder) - step = language.semantic.to_tensor(step, self.builder) - # induction variable type - if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): - raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") - 
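Reviewer note: `visit_While` above promotes exactly those names that are both live before the loop and redefined in its body into scf.while iteration arguments. A hedged pure-Python sketch of that selection (the example names and type strings are stand-ins):

# A name is loop-carried iff it is live-in AND (re)defined in the body.
liveins = {"acc": "tensor<16xf32>", "n": "tensor<i32>", "CFG": "constexpr"}
loop_defs = {"acc": "tensor<16xf32>", "tmp": "tensor<16xf32>"}

carried = [name for name in loop_defs if name in liveins]
assert carried == ["acc"]  # "tmp" is body-local; "CFG" is never reassigned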
iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) - iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) - iv_ir_type = iv_type.to_ir(self.builder) - iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED - # lb/ub/step might be constexpr, we need to cast them to tensor - lb = lb.handle - ub = ub.handle - step = step.handle - # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index - lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) - ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) - step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) - # Create placeholder for the loop induction variable - iv = self.builder.create_poison(iv_ir_type) - self.set_value(node.target.id, language.core.tensor(iv, iv_type)) - - with enter_sub_region(self) as sr: - liveins, insert_block = sr - ip, last_loc = self._get_insertion_point_and_loc() - - # create loop body block - block = self.builder.create_block() - self.builder.set_insertion_point_to_start(block) - # dry visit loop body - self.scf_stack.append(node) - self.visit_compound_statement(node.body) - self.scf_stack.pop() - block.erase() - - # If a variable (name) is defined in both its parent & itself, then it's - # a loop-carried variable. (They must be of the same type) - init_args = [] - yields = [] - names = [] - for name in self.local_defs: - if name in liveins: - loop_val = self.local_defs[name] - live_val = liveins[name] - self._verify_loop_carried_variable(name, loop_val, live_val) - - names.append(name) - init_args.append(live_val) - yields.append(loop_val) - - # create ForOp - self._set_insertion_point_and_loc(ip, last_loc) - for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) - if num_stages is not None: - for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) - if loop_unroll_factor is not None: - for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) - - self.scf_stack.append(node) - self.builder.set_insertion_point_to_start(for_op.get_body(0)) - # reset local scope to not pick up local defs from the previous dry run. 
- self.lscope = liveins.copy() - self.local_defs = {} - for i, name in enumerate(names): - self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) - self.visit_compound_statement(node.body) - self.scf_stack.pop() - yields = [] - for name in self.local_defs: - if name in liveins: - yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder)) - - # create YieldOp - if len(yields) > 0: - self.builder.create_yield_op([y.handle for y in yields]) - for_op_region = for_op.get_body(0).get_parent() - assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" - - # update induction variable with actual value, and replace all uses - self.builder.set_insertion_point_to_start(for_op.get_body(0)) - iv = for_op.get_induction_var() - if negative_step: - iv = self.builder.create_sub(ub, iv) - iv = self.builder.create_add(iv, lb) - self.lscope[node.target.id].handle.replace_all_uses_with(iv) - self.set_value(node.target.id, language.core.tensor(iv, iv_type)) - - # update lscope & local_defs (ForOp defines new values) - for i, name in enumerate(names): - self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) - - for stmt in node.orelse: - assert False, "Don't know what to do with else after for" - ast.NodeVisitor.generic_visit(self, stmt) - - def visit_Slice(self, node): - lower = self.visit(node.lower) - upper = self.visit(node.upper) - step = self.visit(node.step) - return slice(lower, upper, step) - - def visit_Index(self, node): - return self.visit(node.value) - - def visit_keyword(self, node) -> Tuple[str, Any]: - return node.arg, self.visit(node.value) - - def visit_Assert(self, node) -> Any: - test = self.visit(node.test) - msg = self.visit(node.msg) if node.msg is not None else "" - return language.core.device_assert(test, msg, _builder=self.builder) - - def call_JitFunction(self, fn: JITFunction, args, kwargs): - args = inspect.getcallargs(fn.fn, *args, **kwargs) - args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] - # generate function def - attributes = {} - constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) - # generate function def if necessary - if not self.module.has_function(fn_name): - prototype = language.function_type([], arg_types) - gscope = fn.__globals__ - # If the callee is not set, we use the same debug setting as the caller - file_name, begin_line = get_jit_fn_file_line(fn) - generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, - noinline=fn.noinline, file_name=file_name, begin_line=begin_line, - options=self.builder.options, codegen_fns=self.builder.codegen_fns, - module_map=self.builder.module_map) - try: - generator.visit(fn.parse()) - except Exception as e: - # Wrap the error in the callee with the location of the call. 
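Reviewer note: since scf.for rejects negative steps, `visit_For` above negates the step, swaps the bounds, and later recovers the user-visible induction variable with a sub/add pair. A small sketch assumed equivalent to that rewrite (the function name is hypothetical):

def negative_step_ivs(start, stop, step):
    # mirrors: step = -step; lb, ub = ub, lb; iv' = (ub - iv) + lb
    assert step < 0
    lb, ub, s = stop, start, -step
    for iv in range(lb, ub, s):      # what scf.for actually iterates
        yield (ub - iv) + lb         # the value the kernel body observes

assert list(negative_step_ivs(10, 0, -3)) == list(range(10, 0, -3))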
- raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e - - callee_ret_type = generator.ret_type - self.function_ret_types[fn_name] = callee_ret_type - else: - callee_ret_type = self.function_ret_types[fn_name] - symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0 or callee_ret_type is None: - return None - elif call_op.get_num_results() == 1: - return tensor(call_op.get_result(0), callee_ret_type) - else: - # should return a tuple of tl.tensor - results = [] - for i in range(call_op.get_num_results()): - results.append(tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) - - def visit_Call(self, node): - fn = _unwrap_if_constexpr(self.visit(node.func)) - static_implementation = self.statically_implemented_functions.get(fn) - if static_implementation is not None: - return static_implementation(self, node) - - kws = dict(self.visit(keyword) for keyword in node.keywords) - args = [self.visit(arg) for arg in node.args] - if isinstance(fn, JITFunction): - _check_fn_args(node, fn, args) - return self.call_JitFunction(fn, args, kws) - if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): - extra_kwargs = {"_builder": self.builder} - sig = inspect.signature(fn) - if '_generator' in sig.parameters: - extra_kwargs['_generator'] = self - try: - return fn(*args, **extra_kwargs, **kws) - except Exception as e: - # Normally when we raise a CompilationError, we raise it as - # `from None`, because the original fileline from the exception - # is not relevant (and often points into code_generator.py - # itself). But when calling a function, we raise as `from e` to - # preserve the traceback of the original error, which may e.g. - # be in core.py. 
- raise CompilationError(self.jit_fn.src, node, repr(e)) from e - - if fn in self.builtin_namespace.values(): - args = map(_unwrap_if_constexpr, args) - return fn(*args, **kws) - - def visit_Constant(self, node): - return constexpr(node.value) - - def visit_BoolOp(self, node: ast.BoolOp): - if len(node.values) != 2: - raise self._unsupported( - node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") - lhs = self.visit(node.values[0]) - rhs = self.visit(node.values[1]) - method_name = self._method_name_for_bool_op.get(type(node.op)) - if method_name is None: - raise self._unsupported( - node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) - return self._apply_binary_method(method_name, lhs, rhs) - - _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} - - if sys.version_info < (3, 8): - - def visit_NameConstant(self, node): - return constexpr(node.value) - - def visit_Num(self, node): - return constexpr(node.n) - - def visit_Str(self, node): - return constexpr(ast.literal_eval(node)) - - def visit_Attribute(self, node): - lhs = self.visit(node.value) - if _is_triton_tensor(lhs) and node.attr == "T": - return language.semantic.permute(lhs, (1, 0), builder=self.builder) - return getattr(lhs, node.attr) - - def visit_Expr(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_NoneType(self, node): - return None - - def visit_JoinedStr(self, node): - values = list(node.values) - for i, value in enumerate(values): - if isinstance(value, ast.Constant): - values[i] = str(value.value) - elif isinstance(value, ast.FormattedValue): - conversion_code = value.conversion - evaluated = self.visit(value.value) - if not _is_constexpr(evaluated): - raise self._unsupported( - node, - "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " - + str(type(evaluated))) - values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) - else: - raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) - return ''.join(values) - - def visit(self, node): - if node is None: - return - with warnings.catch_warnings(): - # The ast library added visit_Constant and deprecated some other - # methods but we can't move to that without breaking Python 3.6 and 3.7. - warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 - warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 - last_node = self.cur_node - last_loc = self.builder.get_loc() - self.cur_node = node - if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): - self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) - last_loc = self.builder.get_loc() - try: - ret = super().visit(node) - except CompilationError: - raise - except Exception as e: - # Wrap the error in a CompilationError which contains the source - # of the @jit function. 
- raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None - - # Reset the location to the last one before the visit - if last_loc: - self.cur_node = last_node - self.builder.set_loc(last_loc) - return ret - - def generic_visit(self, node): - raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) - - def execute_static_assert(self, node: ast.Call) -> None: - arg_count = len(node.args) - if not (0 < arg_count <= 2) or len(node.keywords): - raise TypeError("`static_assert` requires one or two positional arguments only") - - passed = _unwrap_if_constexpr(self.visit(node.args[0])) - if not isinstance(passed, bool): - raise NotImplementedError( - "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" - ) - if not passed: - if arg_count == 1: - message = "" - else: - try: - message = self.visit(node.args[1]) - except Exception as e: - message = "" - - raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) - return None - - def static_executor(python_fn): - - def ret(self, node: ast.Call): - kws = { - name: _unwrap_if_constexpr(value) - for name, value in (self.visit(keyword) for keyword in node.keywords) - } - args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] - return constexpr(python_fn(*args, **kws)) - - return ret - - statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { - language.core.static_assert: execute_static_assert, - language.core.static_print: static_executor(print), - int: static_executor(int), - len: static_executor(len), - } - - -def kernel_suffix(signature, specialization): - # suffix format: - # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> - suffix = '' - for i, _ in enumerate(signature): - suffix += str(i) - if i in specialization.equal_to_1: - suffix += 'c' - if i in specialization.divisibility_16: - suffix += 'd' - return suffix - - -def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): - attrs = specialization.attrs - # create kernel prototype - cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in specialization.constants.items()} - # visit kernel AST - gscope = fn.__globals__.copy() - function_name = fn.repr(specialization) - tys = list(specialization.signature.values()) - new_constants = attrs.get_constants() - for k in new_constants: - if k in tys and tys[k] == "i1" and new_constants[k] == 1: - new_constants[k] = True - - new_attrs = attrs.filter_out_constants() - fn_attrs = new_attrs.get_fn_attrs() - all_constants = constants.copy() - all_constants.update(new_constants) - arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] - file_name, begin_line = get_jit_fn_file_line(fn) - - prototype = language.function_type([], arg_types) - generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, - begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) - generator.visit(fn.parse()) - - ret = generator.module - # module takes ownership of the context - ret.context = context - return ret diff --git a/third_party/ascend/triton_patch/python/triton_patch/compiler/compiler.py b/third_party/ascend/triton_patch/python/triton_patch/compiler/compiler.py deleted file mode 
100644 index e368a4b23..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/compiler/compiler.py +++ /dev/null @@ -1,447 +0,0 @@ -from __future__ import annotations -import hashlib -import json -from triton._C.libtriton import get_cache_invalidating_env_vars, ir -from pathlib import Path -import re -import functools -import os - -# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, -# and any following whitespace -# - (public\s+)? : optionally match the keyword public and any following whitespace -# - (@\w+) : match an @ symbol followed by one or more word characters -# (letters, digits, or underscores), and capture it as group 1 (the function name) -# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing -# zero or more arguments separated by commas, and capture it as group 2 (the argument list) -# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 -mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" -ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" -prototype_pattern = { - "ttir": mlir_prototype_pattern, - "ttgir": mlir_prototype_pattern, - "ptx": ptx_prototype_pattern, -} - -mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' -ptx_arg_type_pattern = r"\.param\s+\.(\w+)" -arg_type_pattern = { - "ttir": mlir_arg_type_pattern, - "ttgir": mlir_arg_type_pattern, - "ptx": ptx_arg_type_pattern, -} - - -def convert_type_repr(x): - # Currently we only capture the pointer type and assume the pointer is on global memory. - # TODO: Capture and support shared memory space - match = re.search(r'!tt\.ptr<([^,]+)', x) - tma = re.search(r'tt.nv_tma_desc = 1', x) - if tma is not None: - return 'nvTmaDesc' - x = re.sub(r' {[^}]+}', '', x) - if match is not None: - return '*' + convert_type_repr(match.group(1)) - return x - - -def _get_num_warps_from_ir_str(src: str): - ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' - # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if - # e.g. someone has an instruction (not module) attribute named "num-warps". - num_warps_matches = re.findall(ttgir_num_warps_pattern, src) - assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - num_warps = int(num_warps_matches[0]) - return num_warps - - -class ASTSource: - - def __init__(self, fn, signature, constants=None, attrs=None) -> None: - from triton.backends.compiler import AttrsDescriptor - self.fn = fn - self.ext = "ttir" - self.name = fn.__name__ - self.signature = signature - self.constants = constants - self.attrs = attrs - if isinstance(self.signature, str): - self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} - else: - for k in self.signature.keys(): - if not isinstance(k, str): - raise TypeError("Signature keys must be string") - if self.constants is None: - self.constants = {} - else: - for k in self.constants.keys(): - if not isinstance(k, str): - raise TypeError("Constants keys must be string") - if self.attrs is None: - self.attrs = AttrsDescriptor() - - def hash(self): - sorted_sig = [v for k, v in sorted(self.signature.items())] - # Note - we stringify the keys here to allow sorting to work for cases - # where constants have mixed int/str keys. 
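Reviewer note: `ASTSource.__init__` above accepts the signature either as a dict or as a comma-separated string; a string is normalized into an index-keyed dict. A quick worked example of that comprehension (the signature value is a stand-in):

signature = "*fp32, *fp32, i32"
normalized = {k: v.strip() for k, v in enumerate(signature.split(","))}
assert normalized == {0: "*fp32", 1: "*fp32", 2: "i32"}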
- sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) - key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" - return hashlib.sha256(key.encode("utf-8")).hexdigest() - - def make_ir(self, options, codegen_fns, module_map, context): - from triton.compiler.code_generator import ast_to_ttir - return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, - module_map=module_map) - - def parse_options(self): - return dict() - - -class IRSource: - - def __init__(self, path): - self.path = path - path = Path(path) - self.ext = path.suffix[1:] - self.src = path.read_text() - match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) - self.name = match.group(1) - signature = match.group(2) - types = re.findall(arg_type_pattern[self.ext], signature) - self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} - - def hash(self): - return hashlib.sha256(self.src.encode("utf-8")).hexdigest() - - def make_ir(self, options, codegen_fns, module_map, context): - module = ir.parse_mlir_module(self.path, context) - module.context = context - return module - - def parse_options(self): - if self.ext == "ttgir": - return {'num_warps': _get_num_warps_from_ir_str(self.src)} - return dict() - - -@functools.lru_cache() -def triton_key(): - from triton import __version__ - import pkgutil - TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - TRITON_PATH = os.path.dirname(TRITON_PATH) - contents = [] - # frontend - with open(__file__, "rb") as f: - contents += [hashlib.sha256(f.read()).hexdigest()] - # compiler - path_prefixes = [ - (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), - (os.path.join(TRITON_PATH, "backends"), "triton.backends."), - ] - for path, prefix in path_prefixes: - for lib in pkgutil.walk_packages([path], prefix=prefix): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.sha256(f.read()).hexdigest()] - - # backend - libtriton_hash = hashlib.sha256() - with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: - while True: - chunk = f.read(1024**2) - if not chunk: - break - libtriton_hash.update(chunk) - contents.append(libtriton_hash.hexdigest()) - # language - language_path = os.path.join(TRITON_PATH, 'language') - for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.sha256(f.read()).hexdigest()] - return f'{__version__}' + '-'.join(contents) - - -def parse(full_name, ext, context): - if ext == "ttir" or ext == "ttgir": - module = ir.parse_mlir_module(full_name, context) - module.context = context - return module - if ext == "llir" or ext == "ptx": - return Path(full_name).read_text() - if ext == "cubin": - return Path(full_name).read_bytes() - - -def filter_traceback(e: BaseException): - """ - Removes code_generator.py and related files from tracebacks. - - These are uninteresting to the user -- "just show me *my* code!" - """ - if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1": - return - - if e.__cause__ is not None: - filter_traceback(e.__cause__) - if e.__context__ is not None: - filter_traceback(e.__context__) - - # If a user has a file that matches one of these, they're out of luck. 
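Reviewer note: `triton_key()` above hashes libtriton.so in 1 MiB chunks rather than reading it whole, which keeps memory flat for a large shared object. A self-contained sketch of that pattern (the path argument and helper name are hypothetical):

import hashlib

def sha256_file(path, chunk_size=1024**2):
    # stream the file through the hash one chunk at a time
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()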
- BAD_FILES = [ - "/triton/compiler/code_generator.py", - "/ast.py", - ] - - tb = e.__traceback__ - frames = [] - while tb is not None: - if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): - frames.append(tb) - tb = tb.tb_next - - for (cur_frame, next_frame) in zip(frames, frames[1:]): - cur_frame.tb_next = next_frame - - if not frames: - e.__traceback__ = None - else: - frames[-1].tb_next = None - e.__traceback__ = frames[0] - - -def compile(src, target=None, options=None): - from triton.backends.compiler import GPUTarget - from triton.runtime.cache import get_cache_manager, get_dump_manager, get_override_manager - from triton.runtime.driver import driver - from .errors import MLIRCompilationError - if target is None: - target = driver.active.get_current_target() - assert isinstance(target, GPUTarget), "target must be of GPUTarget type" - backend = make_backend(target) - ir_source = not isinstance(src, ASTSource) - # create backend - if ir_source: - assert isinstance(src, str), "source must be either AST or a filepath" - src = IRSource(src) - extra_options = src.parse_options() - options = backend.parse_options(dict(options or dict(), **extra_options)) - # create cache manager - env_vars = get_cache_invalidating_env_vars() - key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" - hash = hashlib.sha256(key.encode("utf-8")).hexdigest() - fn_cache_manager = get_cache_manager(hash) - # For dumping/overriding only hash the source as we want it to be independent of triton - # core changes to make it easier to track kernels by hash. - enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" - enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" - fn_override_manager = get_override_manager(src.hash()) if enable_override else None - fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None - # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. - # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". - # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate - # the file name to 150 characters to be safe. - file_name = src.name[:150] - metadata_filename = f"{file_name}.json" - metadata_group = fn_cache_manager.get_group(metadata_filename) or {} - metadata_path = metadata_group.get(metadata_filename) - always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" - if not always_compile and metadata_path is not None: - # cache hit! - metadata = json.loads(Path(metadata_path).read_text()) - return CompiledKernel(src, metadata_group, hash) - compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') - if (compile_speed_opt): - ttir_path = f"{file_name}.ttir" - if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): - # Already compile once but failed. So directly return - raise Exception("already failed once") - # initialize metadata - metadata = { - "hash": hash, - "target": target, - **options.__dict__, - **env_vars, - } - # run compilation pipeline and populate metadata - stages = dict() - backend.add_stages(stages, options) - first_stage = list(stages.keys()).index(src.ext) - # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. 
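Reviewer note: `compile()` above derives the cache directory from a sha256 over the frontend key, source hash, backend hash, options hash, and the cache-invalidating env vars. Schematically (the component values here are stand-ins for triton_key(), src.hash(), backend.hash(), options.hash(), and sorted env vars):

import hashlib

components = ["<triton_key>", "<src_hash>", "<backend_hash>", "<options_hash>", "<sorted_env_vars>"]
key = "-".join(components)
cache_hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
assert len(cache_hash) == 64  # hex digest, used as the cache directory name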
- if ir_source: - first_stage += 1 - context = ir.context() - ir.load_dialects(context) - backend.load_dialects(context) - codegen_fns = backend.get_codegen_implementation() - module_map = backend.get_module_map() - try: - module = src.make_ir(options, codegen_fns, module_map, context) - except Exception as e: - filter_traceback(e) - raise - use_ir_loc = os.environ.get("USE_IR_LOC", None) - for ext, compile_ir in list(stages.items())[first_stage:]: - try: - next_module = compile_ir(module, metadata) - except Exception as e: - if (ext == "ttadapter"): - stage_name = "ConvertTritonIRToLinalgIR" - elif (ext == "npubin"): - stage_name = "ConvertLinalgIRToBinary" - else: - stage_name = "MLIRCompile" - raise MLIRCompilationError(stage_name, e.stderr.decode('utf-8')) - ir_filename = f"{file_name}.{ext}" - if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): - print(f"\nOverriding kernel with file {full_name}") - next_module = parse(full_name, ext, context) - metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - if fn_dump_manager is not None: - fn_dump_manager.put(next_module, ir_filename) - # use an env variable to parse ir from file - if use_ir_loc == ext: - ir_full_name = fn_cache_manager.get_file(ir_filename) - next_module.create_location_snapshot(ir_full_name) - print(f"Creating new locations for {ir_full_name}") - module = next_module - # write-back metadata - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, - binary=False) - fn_cache_manager.put_group(metadata_filename, metadata_group) - # Compilation completed, disabling multithreading in context. - # This is needed to safely finalize threads pool inside context: if current process forks before - # python GC deletes context object, thread pool in child process will be invalid, which could - # lead to child crash or hang. - context.disable_multithreading() - # return handle to compiled kernel - return CompiledKernel(src, metadata_group, hash) - - -def make_backend(target): - from triton.backends import backends - actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] - if len(actives) != 1: - raise RuntimeError( - f"{len(actives)} compatible backends for target ({target.backend}) ({actives}).
There should only be one.") - return actives[0](target) - - -class LazyDict: - - def __init__(self, data): - self.data = data - self.extras = [] - - def get(self) -> None: - for func, args in self.extras: - self.data = self.data | func(*args) - self.extras.clear() - return self.data - - def add(self, func, args): - self.extras.append((func, args)) - - -class AsmDict(dict): - - def __missing__(self, key): - from triton.tools.disasm import get_sass - if key == "sass": - value = get_sass(self["cubin"]) - else: - raise KeyError("Unknown key: '%s'" % key) - - self[key] = value - return value - - -class CompiledKernel: - - # Hooks for external tools to monitor the execution of triton kernels - # TODO: move out of this namespace since it's a runtime thing - launch_enter_hook = None - launch_exit_hook = None - - def __init__(self, src, metadata_group, hash): - from collections import namedtuple - from triton.backends.compiler import GPUTarget - metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) - metadata = json.loads(metadata_path.read_text()) - metadata['cluster_dims'] = tuple(metadata['cluster_dims']) - # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. - target = metadata['target'] - metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) - KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) - self.metadata = KernelMetadata(**metadata) - backend = make_backend(self.metadata.target) - self.packed_metadata = backend.pack_metadata(self.metadata) - self.src = src - self.hash = hash - self.name = self.metadata.name - # stores the text of each level of IR that was generated during compilation - asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] - binary_ext = backend.binary_ext - self.asm = AsmDict({ - file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() - for file in asm_files - }) - self.kernel = self.asm[binary_ext] - # binaries are lazily initialized - # because it involves doing runtime things - # (e.g., checking amount of shared memory on current device) - self.module = None - self.function = None - - def _init_handles(self): - from triton.runtime.errors import OutOfResources - from triton.runtime.driver import driver - if self.module is not None: - return - # create launcher - self.run = driver.active.launcher_cls(self.src, self.metadata) - # not enough shared memory to run the kernel - # on NPU, get_device_properties in fact does not use the device param - # but we still need to preserve it because triton defines the API - device = driver.active.get_current_device() - max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] - if self.metadata.shared > max_shared: - raise OutOfResources(self.metadata.shared, max_shared, "shared memory") - # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` - self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( - self.name, self.kernel, self.metadata.shared, device) - - def __getattribute__(self, name): - if name == 'run': - self._init_handles() - return super().__getattribute__(name) - - def launch_metadata(self, grid, stream, *args): - if CompiledKernel.launch_enter_hook is None: - return None - ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) - if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: - return ret - arg_dict = {} - 
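Reviewer note: the argument-binding loop that continues just below pairs flat runtime argument values with parameter names, looking constexpr parameters up in src.constants instead. A self-contained sketch with hypothetical names and values:

arg_names = ["x_ptr", "y_ptr", "BLOCK"]
constexprs = {2}                  # positions declared tl.constexpr
constants = {"BLOCK": 128}
runtime_args = [0x1000, 0x2000]   # what the caller actually passed

arg_dict, arg_idx = {}, 0
for i, arg_name in enumerate(arg_names):
    if i in constexprs:
        arg_dict[arg_name] = constants[arg_name]
    else:
        arg_dict[arg_name] = runtime_args[arg_idx]
        arg_idx += 1
assert arg_dict == {"x_ptr": 0x1000, "y_ptr": 0x2000, "BLOCK": 128}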
arg_idx = 0 - for i, arg_name in enumerate(self.src.fn.arg_names): - if i in self.src.fn.constexprs: - arg_dict[arg_name] = self.src.constants[arg_name] - else: - arg_dict[arg_name] = args[arg_idx] - arg_idx += 1 - ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) - return ret - - def __getitem__(self, grid): - self._init_handles() - - def runner(*args, stream=None): - from triton.runtime.driver import driver - if stream is None: - device = driver.active.get_current_device() - stream = driver.active.get_current_stream(device) - launch_metadata = self.launch_metadata(grid, stream, *args) - self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, - CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) - - return runner diff --git a/third_party/ascend/triton_patch/python/triton_patch/compiler/errors.py b/third_party/ascend/triton_patch/python/triton_patch/compiler/errors.py deleted file mode 100644 index 23b6bdbf4..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/compiler/errors.py +++ /dev/null @@ -1,72 +0,0 @@ -import ast -from typing import Optional -from triton.errors import TritonError - - -class CompilationError(TritonError): - """Base class for all errors raised during compilation""" - source_line_count_max_in_message = 12 - - def _format_message(self) -> str: - node = self.node - if self.src is None: - source_excerpt = " " - else: - if hasattr(node, 'lineno'): - source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] - if source_excerpt: - source_excerpt.append(' ' * node.col_offset + '^') - source_excerpt = '\n'.join(source_excerpt) - else: - source_excerpt = " " - else: - source_excerpt = self.src - - message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( - node, 'lineno') else source_excerpt - if self.error_message: - message += '\n' + self.error_message - return message - - def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): - self.src = src - self.node = node - self.error_message = error_message - self.message = self._format_message() - - def __str__(self): - return self.message - - def __reduce__(self): - # this is necessary to make CompilationError picklable - return type(self), (self.src, self.node, self.error_message) - - -class CompileTimeAssertionFailure(CompilationError): - """Specific exception for failed tests in `static_assert` invocations""" - pass - - -class UnsupportedLanguageConstruct(CompilationError): - pass - - -class MLIRCompilationError(TritonError): - - def __init__(self, stage_name: Optional[str], message: Optional[str] = None): - self.stage_name = stage_name - self.message = f"\n" \ - f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \ - f"[{self.stage_name}] encounters error:\n" \ - f"{self.filter_message(message)}" \ - f"{self.format_line_delim('[ERROR][Triton][END]')}" - - def __str__(self): - return self.message - - def filter_message(self, message): - # Content starting from "Stack dump without symbol names" means nothing to the users - return message.split("Stack dump without symbol names")[0] - - def format_line_delim(self, keyword): - return f"///------------------{keyword}------------------\n" diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/__init__.py b/third_party/ascend/triton_patch/python/triton_patch/language/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git 
a/third_party/ascend/triton_patch/python/triton_patch/language/_utils.py b/third_party/ascend/triton_patch/python/triton_patch/language/_utils.py deleted file mode 100644 index f83b18855..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/language/_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import List - -TRITON_MAX_TENSOR_NUMEL = 1048576 - - -def validate_block_shape(shape: List[int]): - numel = 1 - for i, d in enumerate(shape): - if not isinstance(d, int): - raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") - numel *= d - - if numel > TRITON_MAX_TENSOR_NUMEL: - raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") - return numel diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/core.py b/third_party/ascend/triton_patch/python/triton_patch/language/core.py deleted file mode 100644 index a5cdf3e43..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/language/core.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -from typing import List -from triton.language.core import _tensor_member_fn, builtin, _constexpr_to_value, tensor, constexpr -from triton.language.core import dtype as real_dtype -from triton.language import semantic as real_semantic -from triton._C.libtriton import ir -from triton.language.core import float32 -# from triton.language.core import _unwrap_if_constexpr, _unwrap_shape -from . import semantic -# from ._utils import validate_block_shape - -# class dtype(real_dtype): - -# def to_ir(self, builder: ir.builder) -> ir.type: -# if self.name in ("uint8", "uint16", "uint32", "uint64"): -# raise ValueError(f"type {self} not supported in this architecture for now.") - -# if self.name.startswith("fp8"): -# if self.name not in builder.options.supported_fp8_dtypes: -# raise ValueError(f'type {self} not supported in this architecture. 
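Reviewer note: the removed _utils.py above caps any block shape at TRITON_MAX_TENSOR_NUMEL (2**20 = 1048576 elements); the check reduces to a product over the dimensions. A usage sketch (helper name hypothetical):

TRITON_MAX_TENSOR_NUMEL = 1048576

def numel_of(shape):
    n = 1
    for d in shape:
        n *= d
    return n

assert numel_of([64, 128]) == 8192                        # comfortably legal
assert numel_of([1024, 1024]) == TRITON_MAX_TENSOR_NUMEL  # exactly at the cap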
' -# f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') -# if self.name in builder.options.deprecated_fp8_dtypes: -# warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release") - -# if self.name == 'void': -# return builder.get_void_ty() -# elif self.name == 'int1': -# return builder.get_int1_ty() -# elif self.name in ('int8', 'uint8'): -# return builder.get_int8_ty() -# elif self.name in ('int16', 'uint16'): -# return builder.get_int16_ty() -# elif self.name in ('int32', 'uint32'): -# return builder.get_int32_ty() -# elif self.name in ('int64', 'uint64'): -# return builder.get_int64_ty() -# elif self.name == 'fp8e5': -# return builder.get_fp8e5_ty() -# elif self.name == 'fp8e5b16': -# return builder.get_fp8e5b16_ty() -# elif self.name == 'fp8e4nv': -# return builder.get_fp8e4nv_ty() -# elif self.name == 'fp8e4b8': -# return builder.get_fp8e4b8_ty() -# elif self.name == 'fp8e4b15': -# return builder.get_fp8e4b15_ty() -# elif self.name == 'fp16': -# return builder.get_half_ty() -# elif self.name == 'bf16': -# return builder.get_bf16_ty() -# elif self.name == 'fp32': -# return builder.get_float_ty() -# elif self.name == 'fp64': -# return builder.get_double_ty() -# raise ValueError(f'fail to convert {self} to ir type') - -# class pointer_type(dtype): - -# def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False): -# element_ty = _unwrap_if_constexpr(element_ty) -# if not isinstance(element_ty, dtype): -# raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') -# self.element_ty = element_ty -# self.address_space = address_space -# self.const = const -# self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' - -# def to_ir(self, builder: ir.builder): -# return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) - -# def __str__(self): -# return self.name - -# def __repr__(self): -# return self.__str__() - -# def is_ptr(self): -# return True - -# def is_const(self): -# return self.const - -# def __eq__(self, other: pointer_type) -> bool: -# if not isinstance(other, pointer_type): -# return False -# return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const - -# def __ne__(self, other: pointer_type) -> bool: -# return not self.__eq__(other) - -# @property -# def scalar(self): -# return self - -# class block_type(dtype): - -# def __init__(self, element_ty: dtype, shape: List): -# self.element_ty = element_ty - -# # Note that block_type's shape is a list of int -# # while tensor's shape is a list of constexpr. - -# # shape can be empty ([]) when an input is a 0D tensor. 
-# self.shape = _unwrap_shape(shape) -# if not self.shape: -# raise TypeError('0d block_type is forbidden') - -# self.numel = validate_block_shape(self.shape) -# self.name = f'<{self.shape}, {self.element_ty}>' - -# def to_ir(self, builder: ir.builder) -> ir.block_type: -# return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) - -# def __str__(self): -# return self.name - -# def __repr__(self): -# return self.__str__() - -# def is_block(self): -# return True - -# def get_block_shapes(self) -> List[int]: -# return self.shape - -# def __eq__(self, other: block_type) -> bool: -# if not isinstance(other, block_type): -# return False -# return self.element_ty == other.element_ty and self.shape == other.shape - -# def __ne__(self, other: block_type) -> bool: -# return not self.__eq__(other) - -# @property -# def scalar(self): -# return self.element_ty - -# class function_type(dtype): - -# def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: -# self.ret_types = ret_types -# self.param_types = param_types - -# def __str__(self): -# return f'fn ({self.param_types}) -> {self.ret_types}' - -# def to_ir(self, builder: ir.builder): -# ir_param_types = [ty.to_ir(builder) for ty in self.param_types] -# ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] -# return builder.get_function_ty(ir_param_types, ret_types) - - -@builtin -def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, - _builder=None): - assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" - assert not allow_tf32, "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." - if input_precision is None: - supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions - default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" - input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) - else: - assert (input_precision not in [ - "tf32", "tf32x3" - ]), "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." - input_precision = _constexpr_to_value(input_precision) - out_dtype = _constexpr_to_value(out_dtype) - max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) - return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) - - -@_tensor_member_fn -@builtin -def gather(src, index, axis, _builder=None): - """Gather from a tensor along a given dimension. - :param src: the source tensor - :type src: Tensor - :param index: the index tensor - :type index: Tensor - :param axis: the dimension to gather along - :type axis: int - """ - axis = _constexpr_to_value(axis) - return semantic.gather(src, index, axis, _builder) - - -@_tensor_member_fn -@builtin -def insert(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: - """ - Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. - - :param ful: The tensor to receive tensor. - :type ful: Tensor - :param sub: The tensor to be inserted. 
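Reviewer note: the Ascend `dot()` above rejects tf32/tf32x3 outright and only computes a default when input_precision is None. A hedged sketch of that default resolution (the helper name is hypothetical; the logic mirrors the None branch shown above):

import os

def resolve_input_precision(input_precision, allowed_precisions, allow_tf32=None):
    # prefer tf32 where the backend allows it, else ieee;
    # TRITON_F32_DEFAULT overrides either default
    if input_precision is not None:
        return input_precision
    supports_tf32 = "tf32" in allowed_precisions
    default = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee"
    return os.getenv("TRITON_F32_DEFAULT", default)

assert resolve_input_precision(None, allowed_precisions=["ieee", "hf32"]) == os.getenv("TRITON_F32_DEFAULT", "ieee")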
- :type sub: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - assert (len(ful.shape) > 0) - assert (len(ful.shape) == len(sub.shape)) - new_offsets = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - out = semantic.insert(ful, sub, new_offsets, sizes, strides, _builder) - return out - - -@_tensor_member_fn -@builtin -def subview(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: - """ - Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. - - :param ful: The tensor to split. - :type ful: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - assert (len(ful.shape) > 0) - new_offsets = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - sub = semantic.subview(ful, new_offsets, sizes, strides, _builder) - return sub diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/math.py b/third_party/ascend/triton_patch/python/triton_patch/language/math.py deleted file mode 100644 index d381508fd..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/language/math.py +++ /dev/null @@ -1,140 +0,0 @@ -from triton.language import core -from triton.language.math import _check_dtype, _add_math_1arg_docstr, _add_math_2arg_docstr -from triton.language import semantic - - -@core.builtin -@_check_dtype(dtypes=["int32", "int64", "uint32"]) -@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") -def umulhi(x, y, _builder=None): - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) - x, y = core.binary_op_type_legalization(x, y, _builder) - return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("exponential") -@core._tensor_member_fn -def exp(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_exp(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("exponential (base 2)") -@core._tensor_member_fn -def exp2(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_exp2(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("natural logarithm") -@core._tensor_member_fn -def log(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_log(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("logarithm (base 2)") -@core._tensor_member_fn -def log2(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_log2(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("cosine") -@core._tensor_member_fn -def cos(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_cos(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("sine") -@core._tensor_member_fn -def sin(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_sin(x.handle), x.type) - - -@core.builtin 
-@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("fast square root") -@core._tensor_member_fn -def sqrt(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_sqrt(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") -@core._tensor_member_fn -def sqrt_rn(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("inverse square root") -@core._tensor_member_fn -def rsqrt(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_rsqrt(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") -def div_rn(x, y, _builder=None): - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) - x, y = core.binary_op_type_legalization(x, y, _builder) - return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("error function") -@core._tensor_member_fn -def erf(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_erf(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("error function") -@core._tensor_member_fn -def tanh(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_tanh(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("floor") -@core._tensor_member_fn -def floor(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_floor(x.handle), x.type) - - -@core.builtin -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("ceil") -@core._tensor_member_fn -def ceil(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_ceil(x.handle), x.type) diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/semantic.py b/third_party/ascend/triton_patch/python/triton_patch/language/semantic.py deleted file mode 100644 index 1d2db8c83..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/language/semantic.py +++ /dev/null @@ -1,270 +0,0 @@ -from typing import List, Optional, Union -import numbers -import triton.language as tl -from triton._C.libtriton import ir -from triton.language.semantic import wrap_tensor, _str_to_rounding_mode, not_equal, _str_to_dot_input_precision, binary_op_type_checking_impl, integer_promote_impl - - -def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: - if not isinstance(start, int) or not isinstance(end, int): - raise ValueError("arange's arguments must be of type tl.constexpr") - is_start_int64 = bool(start >> 32) - is_end_int64 = bool(end >> 32) - if is_start_int64 or is_end_int64: - raise ValueError("arange must fit in int32") - if end <= start: - raise ValueError("arange's end argument must be greater than the start argument") - range = end - start - # if (range & (range - 1)) != 0: - # raise ValueError("arange's range must be a power of 2") - shape = [range] - ret_ty = tl.block_type(tl.int32, shape) - return 
tl.tensor(builder.create_make_range(start, end), ret_ty) - - -def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, - fp_downcast_rounding: Optional[str] = None) -> tl.tensor: - src_ty = input.type - if isinstance(dst_ty, tl.constexpr): - dst_ty = dst_ty.value - if isinstance(fp_downcast_rounding, tl.constexpr): - fp_downcast_rounding = fp_downcast_rounding.value - if src_ty.is_block(): - dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) - if src_ty == dst_ty: - return input - - src_sca_ty = src_ty.scalar - dst_sca_ty = dst_ty.scalar - - # For fp downcasting default rounding mode should be RTNE, for all other conversions it should - # not be set - fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) - use_custom_rounding = False - if dst_sca_ty.is_floating() and src_sca_ty.is_floating( - ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: - if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE - elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True - else: - if fp_downcast_rounding is not None: - raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) - - if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): - raise ValueError("[fp8, fp64] is unsupported on Ascend for now." - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) - if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): - assert builder.codegen_fns.get( - "convert_custom_types") is not None, "target doesn't provide conversion for this type." - return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) - # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 - # and non-default rounding modes for downcasting - if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ - (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ - use_custom_rounding: - return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) - - # bf16 <=> (not fp32) - if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ - (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): - return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) - - # Standard floating types' casting: truncation - # fp64 => fp32, fp16, bf16 - # fp32 => fp16, bf16 - truncate_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth - if truncate_fp: - return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Standard floating types' casting: extension - # fp32 => fp64 - # fp16 => fp32, fp64 - # bf16 => fp32, fp64 - ext_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth - if ext_fp: - return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting between integer types - if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ - (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): - sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() - if dst_sca_ty.is_bool(): - ty = input.dtype.to_ir(builder) - _0 = tl.tensor(builder.get_null_value(ty), 
input.dtype) - return not_equal(input, _0, builder) - else: - return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) - - # Casting standard floating types to integer types - if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): - if dst_sca_ty.is_bool(): - ty = input.dtype.to_ir(builder) - _0 = tl.tensor(builder.get_null_value(ty), input.dtype) - return not_equal(input, _0, builder) - elif dst_sca_ty.is_int_signed(): - return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) - else: - return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting integer types to standard floating types - if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): - if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): - return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) - else: - return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting pointer types to integer types - if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): - bitwidth = dst_sca_ty.int_bitwidth - if bitwidth == 64: - return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) - if bitwidth == 1: - return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) - - # Casting integer types to pointer types - if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): - return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting pointer types to pointer types - if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): - return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) - - assert False, f'cannot cast {input} to {dst_ty}' - - -def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, - out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - assert lhs.type.is_block() and rhs.type.is_block() - - if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): - # All combinations of supported fp8 x fp8 are permitted - pass - else: - assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, - tl.float32), f"Unsupported lhs dtype {lhs.dtype}" - assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, - tl.float32), f"Unsupported rhs dtype {rhs.dtype}" - assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" - - if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): - lhs = cast(lhs, tl.float16, builder) - rhs = cast(rhs, tl.float16, builder) - - if input_precision is None: - input_precision = builder.options.default_dot_input_precision - - input_precision = _str_to_dot_input_precision(input_precision, builder) - - lhs_rank = len(lhs.shape) - rhs_rank = len(rhs.shape) - assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" - assert lhs.shape[-1].value == rhs.shape[ - -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" - assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." 
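# Illustrative sketch of how the precision and shape rules enforced in this
# function look from the kernel side; kernel name, sizes and pointers below
# are hypothetical, not from this diff:
#
#     import triton
#     import triton.language as tl
#
#     @triton.jit
#     def mm_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
#         offs_m = tl.arange(0, M)
#         offs_n = tl.arange(0, N)
#         offs_k = tl.arange(0, K)
#         a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
#         b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
#         # f32 x f32 -> f32 is the only combination 'hf32' is accepted for;
#         # 'tf32'/'tf32x3' and allow_tf32=True are rejected by the asserts above.
#         c = tl.dot(a, b, input_precision="hf32")
#         tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)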
- min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type)
- assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
- and rhs.shape[-1].value >= min_dot_size[1], \
- f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"
- if lhs.type.scalar.is_int():
- assert lhs.type.scalar == tl.int8, "only int8 supported!"
- _0 = builder.get_int32(0)
- ret_scalar_ty = tl.int32
- elif out_dtype.is_bf16():
- raise ValueError(
- "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
- elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
- _0 = builder.get_fp32(0)
- ret_scalar_ty = tl.float32
- else:
- _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0)
- ret_scalar_ty = out_dtype
-
- M = lhs.type.shape[-2]
- N = rhs.type.shape[-1]
- K = lhs.type.shape[-1]
- B = lhs.type.shape[0] if lhs_rank == 3 else None
- ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
- if acc is None:
- acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N])
- else:
- acc_handle = acc.handle
- assert acc.type == ret_ty
-
- if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")):
- if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()):
- raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend")
-
- if max_num_imprecise_acc is not None:
- tl.static_print("max_num_imprecise_acc is not supported on Ascend yet. Thus it is ignored.")
- max_num_imprecise_acc = 0
- return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc),
- ret_ty)
-
-
-# Use Union instead of |. Because Python 3.9 does not support |.
-# It would report: TypeError: unsupported operand type(s) for |: 'type' and 'ABCMeta'
-def floordiv(input: Union[tl.tensor, numbers.Number], other: Union[tl.tensor, numbers.Number],
- builder: ir.builder) -> tl.tensor:
- input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
- input_scalar_ty = input.type.scalar
- other_scalar_ty = other.type.scalar
- if input_scalar_ty.is_bool() or other_scalar_ty.is_bool():
- raise TypeError(f"unexpected type {input_scalar_ty}")
- if input_scalar_ty.is_int() and other_scalar_ty.is_int():
- ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty)
- input = cast(input, ret_ty, builder)
- other = cast(other, ret_ty, builder)
- if ret_ty.is_int_signed():
- return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type)
- else:
- return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type)
- raise TypeError(f"unexpected type {input_scalar_ty}")
-
-
-def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
- assert index.dtype.is_int(), "index must be an integer tensor"
-
- rank = len(src.type.shape)
- assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
-
- assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})"
- if axis < 0:
- axis += rank
-
- for d in range(rank):
- if d == axis:
- continue
- assert index.type.shape[d] == src.type.shape[d], f"index dim {d} must match the corresponding source dim"
-
- gather = builder.create_gather(src.handle, index.handle, axis)
- return wrap_tensor(gather, src.type.scalar, index.type.shape)
-
-
-def insert(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int],
- builder: ir.builder) -> tl.tensor:
- assert (len(ful.shape) == len(offsets))
- assert (len(ful.shape) == len(sizes))
- assert (len(ful.shape) == len(strides))
- assert (all([s >= 1 for s in sizes]))
- assert (all([s >= 0 for s in strides]))
- new_offsets = [o.handle for o in offsets]
- ret_type = tl.block_type(ful.type.scalar, ful.shape)
- out = builder.create_insert(ful.handle, sub.handle, new_offsets, sizes, strides)
- return tl.tensor(out, ret_type)
-
-
-def subview(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int],
- builder: ir.builder) -> tl.tensor:
- assert (len(ful.shape) == len(offsets))
- assert (len(ful.shape) == len(sizes))
- assert (len(ful.shape) == len(strides))
- assert (all([s >= 1 for s in sizes]))
- assert (all([s >= 0 for s in strides]))
- new_offsets = [o.handle for o in offsets]
- ret_type = tl.block_type(ful.type.scalar, sizes)
- out = builder.create_slice(ful.handle, new_offsets, sizes, strides)
- return tl.tensor(out, ret_type)
diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/standard.py b/third_party/ascend/triton_patch/python/triton_patch/language/standard.py
deleted file mode 100644
index 83e318119..000000000
--- a/third_party/ascend/triton_patch/python/triton_patch/language/standard.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from triton.language import core
-from triton.runtime.jit import jit
-
-
-@core._tensor_member_fn
-@jit
-def flip(x, dim=None):
- """
- Flips a tensor `x` along the dimension `dim`.
-
- :param x: the input tensor
- :type x: Block
- :param dim: the dimension to flip along (currently only final dimension supported)
- :type dim: int
- """
- core.static_print("tl.flip is unsupported for now.
Use libdevice.flip instead.")
- core.static_assert(False)
- return x
diff --git a/third_party/ascend/triton_patch/python/triton_patch/patch.py b/third_party/ascend/triton_patch/python/triton_patch/patch.py
deleted file mode 100644
index 4eb332414..000000000
--- a/third_party/ascend/triton_patch/python/triton_patch/patch.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import sys
-import os
-from importlib.util import spec_from_file_location, module_from_spec
-
-triton_root = os.path.dirname(__file__)
-if triton_root not in sys.path:
- sys.path.append(triton_root)
-triton_patch_init_path = os.path.join(triton_root, "triton_patch/__init__.py")
-spec = spec_from_file_location("triton_patch", triton_patch_init_path)
-module = module_from_spec(spec)
-spec.loader.exec_module(module)
diff --git a/third_party/ascend/triton_patch/python/triton_patch/runtime/autotuner.py b/third_party/ascend/triton_patch/python/triton_patch/runtime/autotuner.py
deleted file mode 100644
index 2c41bcc46..000000000
--- a/third_party/ascend/triton_patch/python/triton_patch/runtime/autotuner.py
+++ /dev/null
@@ -1,410 +0,0 @@
-from __future__ import annotations
-
-import builtins
-import os
-import time
-import inspect
-from typing import Dict
-
-from .jit import KernelInterface
-
-
-class Autotuner(KernelInterface):
-
- def __init__(
- self,
- fn,
- arg_names,
- configs,
- key,
- reset_to_zero,
- restore_value,
- pre_hook=None,
- post_hook=None,
- prune_configs_by: Dict = None,
- warmup=None,
- rep=None,
- use_cuda_graph=False,
- do_bench=None,
- ):
- from triton.runtime.driver import driver
- """
- :param prune_configs_by: a dict of functions that are used to prune configs, fields:
- 'perf_model': performance model used to predict running time with different configs, returns running time
- 'top_k': number of configs to bench
- 'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
- """ - if not configs: - self.configs = [ - Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, - reg_dec_producer=0, reg_inc_consumer=0) - ] - - else: - self.configs = configs - self.keys = key - self.cache = {} - self.arg_names = arg_names - - # Reset to zero or restore values - self.reset_to_zero = [] - if reset_to_zero is not None: - self.reset_to_zero = list(reset_to_zero) - self.restore_value = [] - if restore_value is not None: - self.restore_value = list(restore_value) - - # Hook to reset or restore for required tensors - self.pre_hook = lambda kwargs, reset_only=False: 0 - self.post_hook = lambda kwargs, exception: 0 - self.user_defined_pre_hook = False - self.user_defined_post_hook = False - if pre_hook: - self.pre_hook = pre_hook - self.user_defined_pre_hook = True - elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): - - def _pre_hook(kwargs, reset_only=False): - for name in self.reset_to_zero: - kwargs[name].zero_() - if not reset_only: - self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} - - self.pre_hook = _pre_hook - - if post_hook: - self.post_hook = post_hook - self.user_defined_post_hook = True - elif len(self.restore_value) > 0: - - def _post_hook(kwargs, exception): - for name in self.restore_value: - kwargs[name].copy_(self.restore_copies[name]) - self.restore_copies = {} - - self.post_hook = _post_hook - - self.perf_model = None - self.configs_top_k = 1.0 - self.early_config_prune = None - if prune_configs_by: - self.perf_model = prune_configs_by.get("perf_model", self.perf_model) - self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) - self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) - - self.fn = fn - self.base_fn = fn - while not inspect.isfunction(self.base_fn): - self.base_fn = self.base_fn.fn - - self.num_warmups = warmup - self.num_reps = rep - self.use_cuda_graph = use_cuda_graph - - # If we got explicitly called via the old interface, raise a warning - # and proceed with the old behavior. - if warmup is not None or rep is not None or use_cuda_graph: - import warnings - warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " - "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, - stacklevel=1) - if use_cuda_graph: - from triton.testing import do_bench_cudagraph - self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( - kernel_call, - rep=rep if rep is not None else 100, - quantiles=quantiles, - ) - return - - import triton.testing - self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( - kernel_call, - warmup=warmup if warmup is not None else 25, - rep=rep if rep is not None else 100, - quantiles=quantiles, - ) - return - - if do_bench is None: - self.do_bench = driver.active.get_benchmarker() - else: - self.do_bench = do_bench - - def _bench(self, *args, config, **meta): - from triton.runtime.errors import OutOfResources - from triton.compiler.errors import CompileTimeAssertionFailure - from ..compiler.errors import MLIRCompilationError - - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
- " Make sure that you don't re-define auto-tuned symbols.") - # augment meta-parameters with tunable ones - current = dict(meta, **config.all_kwargs()) - full_nargs = {**self.nargs, **current} - - def kernel_call(): - if config.pre_hook: - config.pre_hook(full_nargs) - self.pre_hook(full_nargs) - try: - self.fn.run( - *args, - **current, - ) - except Exception as e: - try: - self.post_hook(full_nargs, exception=e) - finally: - # Throw exception raised by `self.fn.run` - raise - - self.post_hook(full_nargs, exception=None) - - try: - return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError) as e: - return [float("inf"), float("inf"), float("inf")] - - def run(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - used_cached_result = True - if len(self.configs) > 1: - all_args = {**self.nargs, **kwargs} - _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} - key = [_args[key] for key in self.keys if key in _args] - for _, arg in _args.items(): - if hasattr(arg, "dtype"): - key.append(str(arg.dtype)) - key = tuple(key) - if key not in self.cache: - # prune configs - used_cached_result = False - pruned_configs = self.prune_configs(kwargs) - bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} - self.pre_hook(full_nargs, reset_only=True) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: - print(f"Triton autotuning for function {self.base_fn.__name__} finished after " - f"{self.bench_time:.2f}s; best config selected: {self.best_config};") - if config.pre_hook is not None: - full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} - config.pre_hook(full_nargs) - ret = self.fn.run( - *args, - **kwargs, - **config.all_kwargs(), - ) - self.nargs = None - return ret - - def prune_configs(self, kwargs): - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = { - config: self.perf_model( - **self.nargs, - **kwargs, - **config.all_kwargs(), - ) - for config in pruned_configs - } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] - return pruned_configs - - def warmup(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - ret = [] - for config in self.prune_configs(kwargs): - ret.append(self.fn.warmup( - *args, - **kwargs, - **config.all_kwargs(), - )) - self.nargs = None - return ret - - -class Config: - """ - An object that represents a possible kernel configuration for the auto-tuner to try. - - :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. - :type kwargs: dict[Str, Any] - :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if - `num_warps=8`, then each kernel instance will be automatically parallelized to - cooperatively execute using `8 * 32 = 256` threads. 
- :type num_warps: int
- :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
- Mostly useful for matrix multiplication workloads on SM80+ GPUs.
- :type num_stages: int
- :ivar num_ctas: number of blocks in a block cluster. SM90+ only.
- :type num_ctas: int
- :ivar maxnreg: maximum number of registers one thread can use. Corresponds
- to ptx .maxnreg directive. Not supported on all platforms.
- :type maxnreg: Optional[int]
- :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
- function are args.
- """
-
- def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
- reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None):
- self.kwargs = kwargs
- self.num_warps = num_warps
- self.num_ctas = num_ctas
- self.num_stages = num_stages
- self.num_buffers_warp_spec = num_buffers_warp_spec
- self.num_consumer_groups = num_consumer_groups
- self.reg_dec_producer = reg_dec_producer
- self.reg_inc_consumer = reg_inc_consumer
- self.maxnreg = maxnreg
- self.pre_hook = pre_hook
-
- def all_kwargs(self):
- return {
- **self.kwargs, **{
- k: v
- for (k, v) in (
- ("num_warps", self.num_warps),
- ("num_ctas", self.num_ctas),
- ("num_stages", self.num_stages),
- ("num_buffers_warp_spec", self.num_buffers_warp_spec),
- ("num_consumer_groups", self.num_consumer_groups),
- ("reg_dec_producer", self.reg_dec_producer),
- ("reg_inc_consumer", self.reg_inc_consumer),
- ("maxnreg", self.maxnreg),
- ) if v is not None
- }
- }
-
- def __str__(self):
- res = []
- for k, v in self.kwargs.items():
- res.append(f"{k}: {v}")
- res.append(f"num_warps: {self.num_warps}")
- res.append(f"num_ctas: {self.num_ctas}")
- res.append(f"num_stages: {self.num_stages}")
- res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
- res.append(f"num_consumer_groups: {self.num_consumer_groups}")
- res.append(f"reg_dec_producer: {self.reg_dec_producer}")
- res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
- res.append(f"maxnreg: {self.maxnreg}")
- return ", ".join(res)
-
-
-def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
- warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
- """
- Decorator for auto-tuning a :code:`triton.jit`'d function.
-
- .. highlight:: python
- .. code-block:: python
-
- @triton.autotune(configs=[
- triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
- triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
- ],
- key=['x_size'] # the two above configs will be evaluated anytime
- # the value of x_size changes
- )
- @triton.jit
- def kernel(x_ptr, x_size, **META):
- BLOCK_SIZE = META['BLOCK_SIZE']
- :note: When all the configurations are evaluated, the kernel will run multiple times.
- This means that whatever value the kernel updates will be updated multiple times.
- To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
- resets the value of the provided tensor to `zero` before running any configuration.
-
- If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to
- :code:`"1"`, Triton will print a message to stdout after autotuning each
- kernel, including the time spent autotuning and the best configuration.
-
- :param configs: a list of :code:`triton.Config` objects
- :type configs: list[triton.Config]
- :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
- :type key: list[str]
- :param prune_configs_by: a dict of functions that are used to prune configs, fields:
- 'perf_model': performance model used to predict running time with different configs, returns running time
- 'top_k': number of configs to bench
- 'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
- :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
- :type reset_to_zero: list[str]
- :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
- :type restore_value: list[str]
- :param pre_hook: a function that will be called before the kernel is called.
- This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'.
- 'kwargs': a dict of all arguments passed to the kernel.
- 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook.
- :type pre_hook: lambda args, reset_only
- :param post_hook: a function that will be called after the kernel is called.
- This overrides the default post_hook used for 'restore_value'.
- 'kwargs': a dict of all arguments passed to the kernel.
- 'exception': the exception raised by the kernel in case of a compilation or runtime error.
- :type post_hook: lambda args, exception
- :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
- :type warmup: int
- :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
- :type rep: int
- :param do_bench: a benchmark function to measure the time of each run.
- :type do_bench: lambda fn, quantiles
- """
-
- def decorator(fn):
- return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
- post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
- use_cuda_graph=use_cuda_graph, do_bench=do_bench)
-
- return decorator
-
-
-class Heuristics(KernelInterface):
-
- def __init__(self, fn, arg_names, values) -> None:
- self.fn = fn
- self.values = values
- self.arg_names = arg_names
-
- def run(self, *args, **kwargs):
- for v, heur in self.values.items():
- kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
- return self.fn.run(*args, **kwargs)
-
-
-def heuristics(values):
- """
- Decorator for specifying how the values of certain meta-parameters may be computed.
- This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
-
- .. highlight:: python
- .. code-block:: python
-
- @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
- @triton.jit
- def kernel(x_ptr, x_size, **META):
- BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
- :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
- each such function takes a list of positional arguments as input.
- :type values: dict[str, Callable[[list[Any]], Any]] - """ - - def decorator(fn): - return Heuristics(fn, fn.arg_names, values) - - return decorator diff --git a/third_party/ascend/triton_patch/python/triton_patch/runtime/jit.py b/third_party/ascend/triton_patch/python/triton_patch/runtime/jit.py deleted file mode 100644 index 1b79635e3..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/runtime/jit.py +++ /dev/null @@ -1,952 +0,0 @@ -from __future__ import annotations, division -import ast -import hashlib -import inspect -import itertools -import os -import re -import textwrap -from collections import defaultdict -from functools import cached_property -from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple -from types import ModuleType - -TRITON_MODULE = __name__[:-len(".runtime.jit")] - -T = TypeVar("T") - -# ----------------------------------------------------------------------------- -# Dependencies Finder -# ----------------------------------------------------------------------------- - - -class DependenciesFinder(ast.NodeVisitor): - """ - This AST visitor is used to find dependencies of a JITFunction. This can - be used to invalidate a JITFunction's hash when its source code -- or - that of its dependencies -- changes. - - This visitor also keeps track of the global variables touched by the - JITFunction. When we launch the kernel, we check that these have the same - values as they did when we ran this visitor. If not, we raise an error (or - otherwise we could recompile). - """ - - def __init__(self, name, globals, src) -> None: - super().__init__() - self.name = name - self.hasher = hashlib.sha256(src.encode("utf-8")) - - # This function's __globals__ dict. - self.globals = globals - - # Python builtins that can be accessed from Triton kernels. - self.supported_python_builtins = { - 'float', - 'getattr', - 'int', - 'isinstance', - 'len', - 'list', - 'max', - 'min', - 'print', - 'range', - } - - # used_global_vals tells us which global variables are used by this - # function and all those it transitively calls, plus the values of those - # variables when each function was initially run. (That is, if A calls - # C, and B calls C, then the values for C in used_global_vals will be - # from the first time C was run, either by A or B.) - # - # Each function may have a different __globals__ dict, so the global - # variable `foo` may actually have a different value in the different - # functions. Thus this map is actually - # (var_name, id(__globals__)) -> (var_value, __globals__). - self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} - - self.visiting_arg_default_value = False - - @property - def ret(self): - return self.hasher.hexdigest() - - def _is_triton_builtin(self, node, func): - if inspect.isbuiltin(node.func): - return True - module = getattr(func, "__module__", "") - return module.startswith(TRITON_MODULE) - - def _update_hash(self, func): - if isinstance(func, JITFunction): - # Merge our used_global_vals with those of the called function, - # after checking that all overlapping values are consistent. - for k in self.used_global_vals.keys() & func.used_global_vals.keys(): - var_name, _ = k - v1, _ = self.used_global_vals[k] - v2, _ = func.used_global_vals[k] - if v1 != v2: - raise RuntimeError( - f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." 
- ) - self.used_global_vals.update(func.used_global_vals) - # update hash - func_key = func.cache_key - func_key += str(getattr(func, "noinline", False)) - self.hasher.update(func_key.encode("utf-8")) - - def visit_Name(self, node): - if type(node.ctx) is ast.Store: - return node.id - - if node.id in self.local_names: - # The global name is hidden by the local name. - return None - - val = self.globals.get(node.id, None) - - # Only keep track of "interesting" global variables, that non-evil users - # might change. Don't consider functions, modules, builtins, etc. This - # helps keep the list of vars we have to check small. - if (val is not None # - # Python default arguments are resolved only once, when the - # function is defined. So if you do `foo(a=A)` and the value of - # A changes, foo will still use the old value of A. - and not self.visiting_arg_default_value - # It would be pretty evil if someone did `import x` and then - # `x = blah`. - and type(val) is not ModuleType - # It would be pretty evil if we used function `foo` inside of - # `bar` and then someone did `foo = baz`. - and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # - and node.id not in self.supported_python_builtins): - self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) - - self._update_hash(val) - return val - - def visit_Tuple(self, node): - # We need to explicitly return the tuple values so that visit_Assign can - # access them in the case of `a, b = ...`. - return [self.visit(elt) for elt in node.elts] - - def visit_Attribute(self, node): - lhs = self.visit(node.value) - while isinstance(lhs, ast.Attribute): - lhs = self.visit(lhs.value) - if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): - return None - ret = getattr(lhs, node.attr) - self._update_hash(ret) - return ret - - def visit_FunctionDef(self, node): - # Save the local name, which may hide the global name. - self.local_names = {arg.arg for arg in node.args.args} - self.generic_visit(node) - - def visit_arguments(self, node): - # The purpose of this function is to visit everything in `arguments` - # just like `generic_visit`, except when we're visiting default values - # (i.e. the `foo` part of `def fn(x = foo)`), we set - # self.visiting_arg_default_value = True. This allows visit_Name to be - # aware that we're inside function default values, which have special - # semantics. - - # According to the AST docs, the arguments node has the following structure. - # - # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, - # expr* kw_defaults, arg? kwarg, expr* defaults) - def visit_defaults(defaults): - try: - assert not self.visiting_arg_default_value - self.visiting_arg_default_value = True - for expr in defaults: - if expr is not None: - self.visit(expr) - finally: - self.visiting_arg_default_value = False - - for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): - self.visit(arg) - - visit_defaults(node.kw_defaults) - - if node.kwarg is not None: - self.visit(node.kwarg) - - visit_defaults(node.defaults) - - def visitAssnTarget(self, node): - # Target is either a single string, or a list of strings (if the assn - # target is a tuple). - target = self.visit(node) - if isinstance(target, list): - self.local_names |= set(target) - else: - self.local_names.add(target) - - def visit_Assign(self, node): - if len(node.targets) != 1: - # TODO(jlebar): I don't actually know how to hit this. 
You don't - # get it from `a, b = ...` -- in that case, node.targets is a single - # Tuple, and in fact we *do* need to handle that case if we want - # existing code to work. - raise TypeError("Simultaneous multiple assignment is not supported.") - - self.visitAssnTarget(node.targets[0]) - - # This will re-visit the target, but that's OK. - self.generic_visit(node) - - def visit_AnnAssign(self, node): - self.visitAssnTarget(node.target) - - # This will re-visit the target, but that's OK. - self.generic_visit(node) - - def visit_For(self, node): - self.visitAssnTarget(node.target) - - # This will re-visit the target, but that's fine. - self.generic_visit(node) - - -# ----------------------------------------------------------------------------- -# JITFunction -# ----------------------------------------------------------------------------- - - -def _normalize_ty(ty) -> str: - if isinstance(ty, type): - return ty.__name__ - elif isinstance(ty, str): - return ty - return repr(ty) - - -class KernelParam: - """Represents a parameter (name plus metadata) to a @jit'ed function.""" - - def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, - do_not_specialize_on_alignment: bool): - self.num = num - self._param = param - self.do_not_specialize = do_not_specialize - self.do_not_specialize_on_alignment = do_not_specialize_on_alignment - - @cached_property - def name(self): - return self._param.name - - @cached_property - def annotation(self): - if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: - return "" - return _normalize_ty(self._param.annotation) - - @cached_property - def annotation_type(self): - annotation = self.annotation - for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: - width = annotation[annotation.find(ty1) + len(ty1):] - if width and ty1 in annotation: - return f"{ty2}{width}" - if annotation == "bool": - return "u1" - return "" - - @cached_property - def is_constexpr(self): - return "constexpr" in self.annotation - - @cached_property - def is_const(self): - return "const" in self.annotation and not self.is_constexpr - - @property - def default(self): - return self._param.default - - @property - def has_default(self): - return self._param.default != inspect.Parameter.empty - - -def compute_spec_key(v, align): - - if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): - return "D" - elif isinstance(v, int): - # bool is a subclass of int, so we don't check explicitly above. - if align and (v % 16 == 0): - return "D" - elif v == 1: - return "1" - return "N" - - -dtype2str = {} - - -def mangle_type(arg, is_const=False): - - if arg is None: - return "none" - elif isinstance(arg, bool): - return "i1" - elif isinstance(arg, int): - if -(2**31) <= arg and arg <= 2**31 - 1: - return "i32" - elif 2**63 <= arg and arg <= 2**64 - 1: - return "u64" - else: - return "i64" - elif isinstance(arg, float): - return "fp32" - elif hasattr(arg, "tma_desc_cpu_ptr"): - return "nvTmaDesc" - else: - # dtypes are hashable so we can memoize this mapping: - dsk = (arg.dtype, is_const) - res = dtype2str.get(dsk, None) - if res is None: - res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] - dtype2str[dsk] = res - return res - - -class KernelInterface(Generic[T]): - run: T - - def __getitem__(self, grid) -> T: - """ - A JIT function is launched with: fn[grid](*args, **kwargs). - Hence JITFunction.__getitem__ returns a callable proxy that - memorizes the grid. 
- """ - return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) - # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) - - -def serialize_specialization_data(name, signature, constants, attrs, options, key): - constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} - import json - obj = { - 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': - options.__dict__, 'key': key - } - serialized_obj = json.dumps(obj) - return serialized_obj - - -def create_function_from_signature(sig, kparams, backend): - """ - Equivalent to sig.bind followed by apply_defaults. This generates a - native Python function (using exec) which can be memoized on a per-kernel - basis to avoid having to run these expensive functions -- which constitute - much of the kernel launch overhead -- every time we run the kernel. - """ - - assert len(sig.parameters) == len(kparams) - - # Create the function argument list and the dict entries for the return statement - func_args = [] - dict_entries = [] - constexpr_vals = [] - non_constexpr_vals = [] - signature_types = [] - specialisations = [] - - for ((name, sp), kp) in zip(sig.parameters.items(), kparams): - if sp.default is inspect.Parameter.empty: - func_args.append(name) - dict_entries.append(f"'{name}': {name}") - else: - func_args.append(f"{name}=default_{name}") - dict_entries.append(f"'{name}': {name}") - if kp.is_constexpr: - constexpr_vals.append(name) - else: - non_constexpr_vals.append(name) - if not kp.do_not_specialize: - if not kp.do_not_specialize_on_alignment: - specialisations.append('compute_spec_key(%s, align=True)' % name) - else: - specialisations.append('compute_spec_key(%s, align=False)' % name) - if kp.annotation_type: - signature_types.append('"%s"' % kp.annotation_type) - else: - signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) - - cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) - constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) - non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) - - func_args.append('**excess_kwargs') - - # Join all arguments into a function definition string - args_str = ', '.join(func_args) - dict_str = ', '.join(dict_entries) - func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( - args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) - - # Prepare defaults to be inserted into function namespace - func_namespace = { - f"default_{name}": param.default - for name, param in sig.parameters.items() - if param.default is not inspect.Parameter.empty - } - - func_namespace['mangle_type'] = mangle_type - func_namespace['compute_spec_key'] = backend.compute_spec_key - - # Execute the function string in func_namespace to create the function - exec(func_body, func_namespace) - - # Extract the newly created function from the namespace - return func_namespace['dynamic_func'] - - -type_canonicalisation_dict = { - "bool": "i1", - "float8e4nv": "fp8e4nv", - "float8e5": "fp8e5", - "float8e4b15": "fp8e4b15", - "float8_e4m3fn": "fp8e4nv", - "float8e4b8": "fp8e4b8", - "float8_e4m3fnuz": "fp8e4b8", - "float8_e5m2": "fp8e5", - "float8e5b16": "fp8e5b16", - "float8_e5m2fnuz": "fp8e5b16", - "float16": "fp16", - "bfloat16": "bf16", - "float32": "fp32", - "float64": "fp64", - "int8": "i8", - "int16": "i16", - "int32": "i32", - "int64": "i64", - "uint8": "u8", - 
"uint16": "u16", - "uint32": "u32", - "uint64": "u64", -} - -for v in list(type_canonicalisation_dict.values()): - type_canonicalisation_dict[v] = v - - -class JITFunction(KernelInterface[T]): - # Hook for inspecting compiled functions and modules - cache_hook = None - # Hook to signal that a kernel is done compiling and inspect compiled function. - # cache_hook will always be called before compilation and compiled_hook after. - compiled_hook = None - - @staticmethod - def _key_of(arg): - if hasattr(arg, "dtype"): - return arg.dtype - elif isinstance(arg, bool): - return "i1" - elif isinstance(arg, int): - if -(2**31) <= arg and arg <= 2**31 - 1: - return "i32" - elif 2**63 <= arg and arg <= 2**64 - 1: - return "u64" - else: - return "i64" - elif isinstance(arg, float): - return "fp32" - elif arg is None: - return None - else: - raise TypeError(f"Unsupported type {type(arg)} for {arg}") - - @staticmethod - def _type_of(key, is_const=False): - # `None` is nullptr. Implicitly convert to *i8. - if key is None: - return "*i8" - elif isinstance(key, str): - return key - - dtype_str = str(key).split(".")[-1] - dtype_str = type_canonicalisation_dict[dtype_str] - const_str = "*k" if is_const else "*" - return const_str + dtype_str - - def _make_constants(self, constexpr_key): - constants = dict(zip(self.constexprs, constexpr_key)) - return constants - - def _call_hook( - self, - key, - signature, - device, - constants, - options, - configs, - is_warmup, - before, - ): - hook = JITFunction.cache_hook if before else JITFunction.compiled_hook - if hook is None: - return False - - name = self.fn.__name__ - module = self.fn.__module__ - arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) - repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" - - class JitFunctionInfo: - - def __init__(self, module, name, jit_function): - self.module = module - self.name = name - self.jit_function = jit_function - pass - - specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) - - kwargs = { - 'signature': signature, - 'device': device, - 'constants': constants, - 'num_warps': options.num_warps, - 'num_ctas': options.num_ctas, - 'num_stages': options.num_stages, - 'enable_fp_fusion': options.enable_fp_fusion, - 'extern_libs': options.extern_libs, - 'configs': configs, - 'specialization_data': specialization_data, - 'is_warmup': is_warmup, - } - - return hook( - key=key, - repr=repr, - fn=JitFunctionInfo(module, name, self), - compile={"key": key, **kwargs}, - is_manual_warmup=is_warmup, - already_compiled=False, - ) - - def add_pre_run_hook(self, hook): - ''' - Add a hook that will be executed prior to the execution of run - function with args and kwargs passed into the kernel - ''' - assert callable(hook) - self.pre_run_hooks.append(hook) - - def create_binder(self, backend): - """ - Precompute as much as possible. 
- """ - from ..compiler.compiler import CompiledKernel, compile, ASTSource, make_backend - self.CompiledKernel = CompiledKernel - self.compile = compile - self.ASTSource = ASTSource - self.make_backend = make_backend - self.binder = create_function_from_signature(self.signature, self.params, backend) - self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] - self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] - self.specialised_indices = [ - i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) - ] - - def run(self, *args, grid, warmup, **kwargs): - from triton.runtime.driver import driver - kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" - - # parse options - from ..compiler.compiler import make_backend - device = driver.active.get_current_device() - stream = driver.active.get_current_stream(device) - target = driver.active.get_current_target() - backend = make_backend(target) - - # Execute pre run hooks with args and kwargs - for hook in self.pre_run_hooks: - hook(*args, **kwargs) - - if self.binder is None: - self.create_binder(backend) - - bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) - - # compute cache key - key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) - kernel = self.cache[device].get(key, None) - - if kernel is None: - # Kernel is not cached; we have to compile. - options = backend.parse_options(kwargs) - - # deprecated arguments - assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" - assert "device" not in kwargs, "device option is deprecated; current device will be used" - assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" - for k in excess_kwargs: - if k not in options.__dict__: - raise KeyError("Keyword argument %s was specified but unrecognised" % k) - - bound_vals = tuple(bound_args.values()) - - # `None` is nullptr. Implicitly convert to *i8. This needs to be - # done here rather than when we build the signature as otherwise - # the kernel cache key could not distinguish between byte pointers - # and None arguments, resulting in a downstream mismatch: - sigkeys = [self.params[i].name for i in self.non_constexpr_indices] - sigvals = sig_and_spec[:len(sigkeys)] - signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - - configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) - constant_params = configs[0].get_constants() - constants = { - p.name: v - for (v, p) in zip(bound_vals, self.params) - if p.is_constexpr or (p.num in constant_params) or v is None - } - for i, arg in constants.items(): - if callable(arg): - raise TypeError(f"Callable constexpr at index {i} is not supported") - - if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): - return None - # compile the kernel - src = self.ASTSource(self, signature, constants, configs[0]) - kernel = self.compile( - src, - target=target, - options=options.__dict__, - ) - self.cache[device][key] = kernel - self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) - - # Check that used global values have not changed. 
- not_present = object() - for (name, _), (val, globals_dict) in self.used_global_vals.items(): - if (newVal := globals_dict.get(name, not_present)) != val: - raise RuntimeError( - f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") - - if not warmup: - # canonicalize grid - assert grid is not None - if callable(grid): - # Arguments are passed as a dict to `grid`, by contract. - # TODO(jlebar): In the new launch API, pass the compiler flags as a - # second parameter to `grid`. - grid = grid(bound_args) - grid_size = len(grid) - grid_0 = grid[0] - grid_1 = grid[1] if grid_size > 1 else 1 - grid_2 = grid[2] if grid_size > 2 else 1 - - # launch kernel - launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) - kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, - self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) - return kernel - - def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, - noinline=None, repr=None, launch_metadata=None): - do_not_specialize = do_not_specialize if do_not_specialize else [] - do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] - - self.fn = fn - self.module = fn.__module__ - self.version = version - self.signature = inspect.signature(fn) - self.do_not_specialize = do_not_specialize - self.do_not_specialize_on_alignment = do_not_specialize_on_alignment - self.starting_line_number = inspect.getsourcelines(fn)[1] - self.repr = lambda _: fn.__name__ if repr is None else repr(_) - self.launch_metadata = launch_metadata - - self.binder = None - - self.params = [] - for i, param in enumerate(self.signature.parameters.values()): - dns = i in do_not_specialize or param.name in do_not_specialize - dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment - self.params.append(KernelParam(i, param, dns, dns_oa)) - - # function source code (without decorators) - self.src = textwrap.dedent(inspect.getsource(fn)) - self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] - # cache of just-in-time compiled kernels - self.cache = defaultdict(dict) - self.hash = None - - # Map of global variables used by the function and any functions it - # transitively calls, plus their values. The values are collected when - # the function is first compiled. Then every time we run the function, - # we check that the values of the globals match what's expected, - # otherwise we raise an error. - # - # Different functions can have different __globals__ maps, so the map - # key is actually (var name, id(__globals__)), and the map value is - # (value, __globals__). - self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} - - # JITFunction can be instantiated as kernel - # when called with a grid using __getitem__ - self.kernel = None - self.noinline = noinline - - # TODO(jlebar): Remove uses of these fields outside this file, then - # remove the fields here. 
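# Illustrative sketch for the pre-run hooks initialized just below (kernel and
# hook names are hypothetical); each hook receives the args/kwargs that were
# passed to run():
#
#     def audit_launch(*args, **kwargs):
#         print(f"kernel invoked with {len(args)} positional args")
#
#     my_kernel.add_pre_run_hook(audit_launch)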
- self.arg_names = [p.name for p in self.params] - self.constexprs = [p.num for p in self.params if p.is_constexpr] - - # Hooks that will be called prior to executing "run" - self.pre_run_hooks = [] - - # reuse docs of wrapped function - self.__doc__ = fn.__doc__ - self.__name__ = fn.__name__ - self.__globals__ = fn.__globals__ - self.__module__ = fn.__module__ - - @property - def cache_key(self): - # TODO : hash should be attribute of `self` - if self.hash is None: - dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) - dependencies_finder.visit(self.parse()) - self.hash = dependencies_finder.ret + str(self.starting_line_number) - self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) - return self.hash - - def warmup(self, *args, grid, **kwargs): - return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) - - def preload(self, specialization_data): - from ..compiler.compiler import compile, ASTSource - from triton.backends.compiler import AttrsDescriptor - import json - import triton.language as tl - from triton.runtime.driver import driver - device = driver.active.get_current_device() - deserialized_obj = json.loads(specialization_data) - if deserialized_obj['name'] != self.fn.__name__: - raise RuntimeError( - f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") - constants = { - key: tl.dtype(value) if tl.dtype.is_dtype(value) else value - for key, value in deserialized_obj['constants'].items() - } - signature = dict(deserialized_obj['signature'].items()) - src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) - options = { - key: tuple(value) if isinstance(value, list) else value - for key, value in deserialized_obj['options'].items() - } - key = deserialized_obj['key'] - kernel = compile(src, None, options) - self.cache[device][key] = kernel - return kernel - - # we do not parse `src` in the constructor because - # the user might want to monkey-patch self.src dynamically. - # Our unit tests do this, for example. - def parse(self): - tree = ast.parse(self.src) - assert isinstance(tree, ast.Module) - assert len(tree.body) == 1 - assert isinstance(tree.body[0], ast.FunctionDef) - return tree - - def __call__(self, *args, **kwargs): - raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") - - def __setattr__(self, name, value): - super(JITFunction, self).__setattr__(name, value) - # - when `.src` attribute is set, cache path needs - # to be reinitialized - if name == "src": - self.hash = None - - def __repr__(self): - return f"JITFunction({self.module}:{self.fn.__name__})" - - -# ----------------------------------------------------------------------------- -# `jit` decorator -# ----------------------------------------------------------------------------- - - -@overload -def jit(fn: T) -> JITFunction[T]: - ... - - -@overload -def jit( - *, - version=None, - repr: Optional[Callable] = None, - launch_metadata: Optional[Callable] = None, - do_not_specialize: Optional[Iterable[int]] = None, - do_not_specialize_on_alignment: Optional[Iterable[int]] = None, - debug: Optional[bool] = None, - noinline: Optional[bool] = None, -) -> Callable[[T], JITFunction[T]]: - ... 
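# Illustrative sketch of the two call forms the @overload stubs above describe
# (kernel bodies are hypothetical):
#
#     @triton.jit                                          # bare form
#     def k1(x_ptr, BLOCK: tl.constexpr):
#         ...
#
#     @triton.jit(do_not_specialize=["n"], noinline=True)  # parameterized form
#     def k2(x_ptr, n, BLOCK: tl.constexpr):
#         ...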
- - -def jit( - fn: Optional[T] = None, - *, - version=None, - repr: Optional[Callable] = None, - launch_metadata: Optional[Callable] = None, - do_not_specialize: Optional[Iterable[int]] = None, - do_not_specialize_on_alignment: Optional[Iterable[int]] = None, - debug: Optional[bool] = None, - noinline: Optional[bool] = None, -) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: - """ - Decorator for JIT-compiling a function using the Triton compiler. - - :note: When a jit'd function is called, arguments are - implicitly converted to pointers if they have a :code:`.data_ptr()` method - and a `.dtype` attribute. - - :note: This function will be compiled and run on the GPU. It will only have access to: - - * python primitives, - * builtins within the triton package, - * arguments to this function, - * other jit'd functions - - :param fn: the function to be jit-compiled - :type fn: Callable - """ - - def decorator(fn: T) -> JITFunction[T]: - assert callable(fn) - if os.getenv("TRITON_INTERPRET", "0") == "1": - from triton.runtime.interpreter import InterpretedFunction - return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, - do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, - noinline=noinline, repr=repr, launch_metadata=launch_metadata) - else: - return JITFunction( - fn, - version=version, - do_not_specialize=do_not_specialize, - do_not_specialize_on_alignment=do_not_specialize_on_alignment, - debug=debug, - noinline=noinline, - repr=repr, - launch_metadata=launch_metadata, - ) - - if fn is not None: - return decorator(fn) - - else: - return decorator - - -# ----------------------------------------------------------------------------- -# Utilities for mocking tensors -# ----------------------------------------------------------------------------- - - -class MockTensor: - """ - Can be used in place of real tensors when calling: - kernel.warmup(MockTensor(torch.float32), ...) - """ - - @staticmethod - def wrap_dtype(arg): - if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": - return MockTensor(arg) - return arg - - def __init__(self, dtype): - self.dtype = dtype - - @staticmethod - def data_ptr(): - return 0 # optimistically assumes multiple of 16 - - @staticmethod - def ptr_range(): - return 0 # optimistically assumes 32 bit pointer range - - -class TensorWrapper: - - def __init__(self, base, dtype): - self.dtype = dtype - self.base = base - self.data = base.data - self.device = base.device - self.shape = self.base.shape - - def data_ptr(self): - return self.base.data_ptr() - - def stride(self, i): - return self.base.stride(i) - - def __str__(self) -> str: - return f"TensorWrapper[{self.dtype}]({self.base})" - - def element_size(self): - return self.base.element_size() - - def cpu(self): - return TensorWrapper(self.base.cpu(), self.dtype) - - def copy_(self, other): - self.base.copy_(other.base) - - def clone(self): - return TensorWrapper(self.base.clone(), self.dtype) - - def to(self, device): - return TensorWrapper(self.base.to(device), self.dtype) - - -def reinterpret(tensor, dtype): - if isinstance(tensor, TensorWrapper): - if dtype == tensor.base.dtype: - # Reinterpreting to the original interpretation; return the base. - return tensor.base - else: - # Reinterpreting a wrapped tensor to a different type. - return TensorWrapper(tensor.base, dtype) - elif hasattr(tensor, "data_ptr"): - # A new wrapper is needed around an unwrapped tensor. 
- return TensorWrapper(tensor, dtype) - else: - raise TypeError(f"Cannot reinterpret a {type(tensor)}.") - - -def get_jit_fn_file_line(fn): - base_fn = fn - while not isinstance(base_fn, JITFunction): - base_fn = base_fn.fn - file_name = base_fn.fn.__code__.co_filename - lines, begin_line = inspect.getsourcelines(base_fn.fn) - # Match the following pattern: - # @triton.autotune(...) <- foo.__code__.co_firstlineno - # @triton.heuristics(...) - # @triton.jit - # def foo(...): <- this line is the first line - for idx, line in enumerate(lines): - if line.strip().startswith("def "): - begin_line += idx - break - return file_name, begin_line diff --git a/third_party/ascend/triton_patch/python/triton_patch/testing.py b/third_party/ascend/triton_patch/python/triton_patch/testing.py deleted file mode 100644 index 9d68310d9..000000000 --- a/third_party/ascend/triton_patch/python/triton_patch/testing.py +++ /dev/null @@ -1,570 +0,0 @@ -import functools -import os -import subprocess -import sys -from contextlib import contextmanager -from typing import Any, Dict, List - - -def nvsmi(attrs): - attrs = ','.join(attrs) - cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] - out = subprocess.check_output(cmd) - ret = out.decode(sys.stdout.encoding).split(',') - ret = [int(x) for x in ret] - return ret - - -def _summarize_statistics(times, quantiles, return_mode): - import torch - if quantiles is not None: - ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() - if len(ret) == 1: - ret = ret[0] - return ret - if return_mode == "all": - return times.tolist() - return getattr(torch, return_mode)(times).item() - - -def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): - """ - Benchmark the runtime of the provided function. - - :param fn: Function to benchmark - :type fn: Callable - :param rep: Repetition time (in ms) - :type rep: int - :param grad_to_none: Reset the gradient of the provided tensor to None - :type grad_to_none: torch.tensor, optional - :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". - :type return_mode: str - """ - import torch - assert return_mode in ["min", "max", "mean", "median", "all"] - - with torch.cuda.stream(torch.cuda.Stream()): - # warmup - fn() - if grad_to_none is not None: - for x in grad_to_none: - x.detach_() - x.requires_grad_(True) - x.grad = None - # step 1 - we estimate the amount of time the kernel call takes - # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point - # but it is probably good enough - # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, - # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 - # cache flush). 
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        start_event.record()
-        for _ in range(5):
-            fn()
-        end_event.record()
-        torch.cuda.synchronize()
-        estimate_ms = start_event.elapsed_time(end_event) / 5
-        n_repeat = max(1, int(rep / estimate_ms))
-        # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
-        # host overhead
-        g = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(g):
-            for _ in range(n_repeat):
-                if grad_to_none is not None:
-                    for x in grad_to_none:
-                        x.grad = None
-                fn()
-        torch.cuda.synchronize()
-        # measure time and return
-        ret = []
-        n_retries = 10
-        for _ in range(n_retries):
-            start_event = torch.cuda.Event(enable_timing=True)
-            end_event = torch.cuda.Event(enable_timing=True)
-            start_event.record()
-            g.replay()
-            end_event.record()
-            torch.cuda.synchronize()
-            ret += [start_event.elapsed_time(end_event) / n_repeat]
-        return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
-
-
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
-    """
-    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
-    the 20-th and 80-th performance percentile.
-
-    :param fn: Function to benchmark
-    :type fn: Callable
-    :param warmup: Warmup time (in ms)
-    :type warmup: int
-    :param rep: Repetition time (in ms)
-    :type rep: int
-    :param grad_to_none: Reset the gradient of the provided tensor to None
-    :type grad_to_none: torch.tensor, optional
-    :param quantiles: Performance percentile to return in addition to the median.
-    :type quantiles: list[float], optional
-    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
-    :type return_mode: str
-    """
-    assert return_mode in ["min", "max", "mean", "median", "all"]
-    import torch
-    from triton import runtime
-
-    enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() in ('npu', )
-    if torch.npu.is_available() and enable_bench_npu:
-        return do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep))
-
-    di = runtime.driver.active.get_device_interface()
-
-    fn()
-    di.synchronize()
-
-    cache = runtime.driver.active.get_empty_cache_for_benchmark()
-
-    # Estimate the runtime of the function
-    start_event = di.Event(enable_timing=True)
-    end_event = di.Event(enable_timing=True)
-    start_event.record()
-    for _ in range(5):
-        cache.zero_()
-        fn()
-    end_event.record()
-    di.synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-
-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
-    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
-    # Warm-up
-    for _ in range(n_warmup):
-        fn()
-    # Benchmark
-    for i in range(n_repeat):
-        # we don't want `fn` to accumulate gradient values
-        # if it contains a backward pass. So we clear the
-        # provided gradients
-        if grad_to_none is not None:
-            for x in grad_to_none:
-                x.grad = None
-        # we clear the L2 cache before each run
-        cache.zero_()
-        # record time of `fn`
-        start_event[i].record()
-        fn()
-        end_event[i].record()
-    # Record clocks
-    di.synchronize()
-    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
-    return _summarize_statistics(times, quantiles, return_mode)
-
-
-def collect_files(base_dir):
-    import pandas as pd
-    for root, dirs, files in os.walk(base_dir):
-        for file in files:
-            if file != 'op_statistic.csv':
-                continue
-            target_file = os.path.join(root, file)
-            df = pd.read_csv(target_file)
-            triton_rows = df[df['OP Type'].str.startswith('triton', na=False)]
-            if not triton_rows.empty:
-                return triton_rows['Avg Time(us)'].values[0]
-            return float('inf')
-    return float('inf')
-
-
-def do_bench_npu(fn, warmup=5, active=30):
-    import torch
-    import torch_npu
-    import hashlib
-    from datetime import datetime
-
-    stream = torch.npu.current_stream()
-    experimental_config = torch_npu.profiler._ExperimentalConfig(
-        aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
-        profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False)
-    skip_first = 1
-    wait = 0
-    repeat = 1
-    total = skip_first + (wait + warmup + active) * repeat
-    md5_hash = hashlib.md5(datetime.now().strftime('%Y-%m-%d').encode('utf-8')).hexdigest()
-    torch_path = "./profile_result/" + md5_hash
-    with torch_npu.profiler.profile(
-            activities=[torch_npu.profiler.ProfilerActivity.NPU],
-            schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat,
-                                                 skip_first=skip_first),
-            on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), record_shapes=False,
-            profile_memory=False, with_stack=False, with_flops=False, with_modules=False,
-            experimental_config=experimental_config) as prof:
-        stream.synchronize()
-
-        for i in range(total):
-            fn()
-            prof.step()
-        stream.synchronize()
-
-    time = collect_files(torch_path)
-
-    import shutil
-    import os
-    if os.path.exists(torch_path):
-        shutil.rmtree(torch_path)
-    # TODO: use logging
-    # print("avg time = ", time, type(time))
-    return time
-
-
-def assert_close(x, y, atol=None, rtol=None, err_msg=''):
-    """
-    Asserts that two inputs are close within a certain tolerance.
-
-    :param x: The first input.
-    :type x: scalar, list, numpy.ndarray, or torch.Tensor
-    :param y: The second input.
-    :type y: scalar, list, numpy.ndarray, or torch.Tensor
-    :param atol: The absolute tolerance. Default value is 1e-2.
-    :type atol: float, optional
-    :param rtol: The relative tolerance. Default value is 0.
-    :type rtol: float, optional
-    :param err_msg: The error message to use if the assertion fails.
-    :type err_msg: str
-    """
-    import numpy as np
-    import torch
-
-    # canonicalize arguments to be tensors
-    if not isinstance(x, torch.Tensor):
-        x = torch.tensor(x)
-    if not isinstance(y, torch.Tensor):
-        y = torch.tensor(y)
-    # absolute tolerance
-    if atol is None:
-        atol = 1e-2
-    atol = atol(x.dtype) if callable(atol) else atol
-    # relative tolerance hook
-    if rtol is None:
-        rtol = 0.
- rtol = rtol(x.dtype) if callable(rtol) else rtol - # we use numpy instead of pytorch - # as it seems more memory efficient - # pytorch tends to oom on large tensors - if isinstance(x, torch.Tensor): - if x.dtype == torch.bfloat16: - x = x.float() - x = x.cpu().detach().numpy() - if isinstance(y, torch.Tensor): - if y.dtype == torch.bfloat16: - y = y.float() - y = y.cpu().detach().numpy() - # we handle size==1 case separately as we can - # provide better error message there - if x.size > 1 or y.size > 1: - np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) - return - if not np.allclose(x, y, atol=atol, rtol=rtol): - raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') - - -class Benchmark: - """ - This class is used by the :code:`perf_report` function to generate line plots with a concise API. - """ - - def __init__( - self, - x_names: List[str], - x_vals: List[Any], - line_arg: str, - line_vals: List[Any], - line_names: List[str], - plot_name: str, - args: Dict[str, Any], - xlabel: str = '', - ylabel: str = '', - x_log: bool = False, - y_log: bool = False, - styles=None, - ): - """ - Constructor. - x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list - of scalars and there are multiple x_names, all arguments will have the same value. - If x_vals is a list of tuples/lists, each element should have the same length as - x_names. - - :param x_names: Name of the arguments that should appear on the x axis of the plot. - :type x_names: List[str] - :param x_vals: List of values to use for the arguments in :code:`x_names`. - :type x_vals: List[Any] - :param line_arg: Argument name for which different values correspond to different lines in the plot. - :type line_arg: str - :param line_vals: List of values to use for the arguments in :code:`line_arg`. - :type line_vals: List[Any] - :param line_names: Label names for the different lines. - :type line_names: List[str] - :param plot_name: Name of the plot. - :type plot_name: str - :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. - :type args: Dict[str, Any] - :param xlabel: Label for the x axis of the plot. - :type xlabel: str, optional - :param ylabel: Label for the y axis of the plot. - :type ylabel: str, optional - :param x_log: Whether the x axis should be log scale. - :type x_log: bool, optional - :param y_log: Whether the y axis should be log scale. - :type y_log: bool, optional - :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. - :type styles: list[tuple[str, str]] - """ - self.x_names = x_names - self.x_vals = x_vals - self.x_log = x_log - self.line_arg = line_arg - self.line_vals = line_vals - self.line_names = line_names - self.y_log = y_log - self.styles = styles - # plot info - self.xlabel = xlabel - self.ylabel = ylabel - self.plot_name = plot_name - self.args = args - - -class Mark: - - def __init__(self, fn, benchmarks): - self.fn = fn - self.benchmarks = benchmarks - - def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, - save_precision=6, **kwrags): - import os - - import matplotlib.pyplot as plt - import pandas as pd - y_mean = bench.line_names - y_min = [f'{x}-min' for x in bench.line_names] - y_max = [f'{x}-max' for x in bench.line_names] - x_names = list(bench.x_names) - df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) - for x in bench.x_vals: - # x can be a single value or a sequence of values. 
-            if not isinstance(x, (list, tuple)):
-                x = [x for _ in x_names]
-
-            if len(x) != len(x_names):
-                raise ValueError(f"Expected {len(x_names)} values, got {x}")
-            x_args = dict(zip(x_names, x))
-
-            row_mean, row_min, row_max = [], [], []
-            for y in bench.line_vals:
-                ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
-                try:
-                    y_mean, y_min, y_max = ret
-                except TypeError:
-                    y_mean, y_min, y_max = ret, None, None
-                row_mean += [y_mean]
-                row_min += [y_min]
-                row_max += [y_max]
-            df.loc[len(df)] = list(x) + row_mean + row_min + row_max
-
-        if bench.plot_name:
-            plt.figure()
-            ax = plt.subplot()
-            # Plot first x value on x axis if there are multiple.
-            first_x = x_names[0]
-            for i, y in enumerate(bench.line_names):
-                y_min, y_max = df[y + '-min'], df[y + '-max']
-                col = bench.styles[i][0] if bench.styles else None
-                sty = bench.styles[i][1] if bench.styles else None
-                ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
-                if not y_min.isnull().all() and not y_max.isnull().all():
-                    y_min = y_min.astype(float)
-                    y_max = y_max.astype(float)
-                    ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
-            ax.legend()
-            ax.set_xlabel(bench.xlabel or first_x)
-            ax.set_ylabel(bench.ylabel)
-            # ax.set_title(bench.plot_name)
-            ax.set_xscale("log" if bench.x_log else "linear")
-            ax.set_yscale("log" if bench.y_log else "linear")
-            if show_plots:
-                plt.show()
-            if save_path:
-                plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
-        df = df[x_names + bench.line_names]
-        if diff_col and df.shape[1] == 2:
-            col0, col1 = df.columns.tolist()
-            df['Diff'] = df[col1] - df[col0]
-
-        if print_data:
-            print(bench.plot_name + ':')
-            print(df.to_string())
-        if save_path:
-            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
-                      index=False)
-        return df
-
-    def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
-        has_single_bench = isinstance(self.benchmarks, Benchmark)
-        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
-        result_dfs = []
-        if save_path:
-            # Create directory if it doesn't exist
-            os.makedirs(save_path, exist_ok=True)
-            html = open(os.path.join(save_path, "results.html"), "w")
-            html.write("<html><body>\n")
-        for bench in benchmarks:
-            result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
-            if save_path:
-                html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
-        if save_path:
-            html.write("</body></html>\n")
-            html.close()
-        if return_df:
-            if has_single_bench:
-                return result_dfs[0]
-            else:
-                return result_dfs
-        return None
-
-
-def perf_report(benchmarks):
-    """
-    Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
-
-    :param benchmarks: Benchmarking configurations.
- :type benchmarks: List of :class:`Benchmark` - """ - wrapper = lambda fn: Mark(fn, benchmarks) - return wrapper - - -def get_dram_gbps(device=None): - ''' return DRAM bandwidth in GB/s ''' - import torch - - from triton.runtime import driver - if not device: - device = torch.cuda.current_device() - mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz - bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] - bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s - return bw_gbps - - -def get_max_tensorcore_tflops(dtype, clock_rate, device=None): - import torch - import triton.language as tl - from triton.runtime import driver - if not device: - device = torch.cuda.current_device() - - num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 - capability = torch.cuda.get_device_capability(device) - if capability[0] < 8: - assert dtype == torch.float16 - ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores - else: - if dtype in [torch.float32, torch.int32]: - ops_per_sub_core = 256 - elif dtype in [torch.float16, torch.bfloat16, torch.int16]: - ops_per_sub_core = 512 - elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: - ops_per_sub_core = 1024 - else: - raise RuntimeError("dtype not supported") - tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 - return tflops - - -# create decorator that wraps test function into -# a cuda-memcheck system call - - -def cuda_memcheck(**target_kwargs): - - def decorator(test_fn): - - @functools.wraps(test_fn) - def wrapper(*args, **kwargs): - import psutil - ppid_name = psutil.Process(os.getppid()).name() - run_cuda_memcheck = target_kwargs.items() <= kwargs.items() - if run_cuda_memcheck and ppid_name != "cuda-memcheck": - path = os.path.realpath(test_fn.__globals__["__file__"]) - # get path of current file - env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} - assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" - test_id = kwargs['request'].node.callspec.id - cmd = f"{path}::{test_fn.__name__}[{test_id}]" - out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) - assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" - assert "ERROR SUMMARY: 0 errors" in str(out.stdout) - else: - test_fn(*args, **kwargs) - - return wrapper - - return decorator - - -@contextmanager -def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): - try: - subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) - subprocess.check_output([ - "nvidia-smi", - "-i", - "0", - f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", - ]) - subprocess.check_output([ - "nvidia-smi", - "-i", - "0", - f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", - ]) - cur_sm_clock = nvsmi(["clocks.current.sm"])[0] - cur_mem_clock = nvsmi(["clocks.current.memory"])[0] - assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" - assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" - tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock - gbps = 640 * 2 * ref_mem_clock * 1e-3 - yield tflops, gbps - finally: - subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) - subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) - subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) - - -def get_max_simd_tflops(dtype, clock_rate, device=None): - import torch - - 
from triton.runtime import driver - if not device: - device = torch.cuda.current_device() - - num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 - capability = torch.cuda.get_device_capability() - if capability[0] < 8: - if dtype == torch.float32: - ops_per_sub_core = 32 # 2*16 - elif dtype == torch.float16: - ops_per_sub_core = 64 - else: - raise RuntimeError("dtype not supported") - else: - if dtype == torch.float32: - ops_per_sub_core = 32 - elif dtype in [torch.float16, torch.bfloat16]: - ops_per_sub_core = 64 - else: - raise RuntimeError("dtype not supported") - tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 - return tflops diff --git a/third_party/ascend/python/tutorials/01-vector-add.py b/third_party/tests/ascend/vector-add.py similarity index 100% rename from third_party/ascend/python/tutorials/01-vector-add.py rename to third_party/tests/ascend/vector-add.py
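For context, the `Benchmark`/`Mark`/`perf_report` machinery deleted above keeps the same interface as upstream `triton.testing`, so existing benchmark scripts continue to work. A minimal usage sketch: the vector-add benchmark itself is hypothetical, `device="cuda"` assumes an NVIDIA build (the Ascend copy dispatched `do_bench` to `do_bench_npu` when `TRITON_BENCH_METHOD=npu`), and returning a `(median, min, max)` triple matches the 3-tuple unpacking in `Mark._run`:

```python
import torch
import triton
import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],                      # argument swept on the x axis
        x_vals=[2**i for i in range(12, 20)],  # problem sizes to benchmark
        line_arg="provider",                   # one plotted line per provider
        line_vals=["torch"],
        line_names=["Torch"],
        plot_name="vector-add-performance",
        args={},                               # keyword args held fixed
        ylabel="ms",
    ))
def bench_add(size, provider):
    x = torch.rand(size, device="cuda")
    y = torch.rand(size, device="cuda")
    # quantiles=[0.5, 0.2, 0.8] makes do_bench return a (median, p20, p80)
    # triple, which Mark._run uses to draw the shaded min/max band.
    return triton.testing.do_bench(lambda: x + y, quantiles=[0.5, 0.2, 0.8])


bench_add.run(print_data=True, show_plots=False)
```

Passing `save_path` to `.run` additionally writes the per-benchmark CSV, PNG, and the `results.html` index produced by the `html.write` calls in `Mark.run`.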