From ea64add03f9ca7e2f7dd20f86d61b252c7a7c67a Mon Sep 17 00:00:00 2001 From: i3wanna2 <2535184404@qq.com> Date: Fri, 23 May 2025 09:02:07 +0000 Subject: [PATCH 01/14] [init adaptation] init adapt ascend_backend for flagtree --- CMakeLists.txt | 416 +++- python/setup.py | 16 +- python/setup_helper.py | 55 +- third_party/ascend/.gitignore | 1 + third_party/ascend/CMakeLists.txt | 12 + third_party/ascend/backend/__init__.py | 2 + third_party/ascend/backend/compiler.py | 329 +++ third_party/ascend/backend/cpu_driver.py | 185 ++ third_party/ascend/backend/device_print.h | 274 +++ third_party/ascend/backend/driver.py | 504 +++++ third_party/ascend/backend/name.conf | 1 + third_party/ascend/backend/npu_utils.cpp | 136 ++ third_party/ascend/backend/utils.py | 203 ++ .../ascend/language/ascend/__init__.py | 3 + .../ascend/language/ascend/libdevice.py | 135 ++ .../ascend/triton-adapter/CMakeLists.txt | 14 + .../triton-adapter/include/CMakeLists.txt | 1 + .../TritonToLinalg/ArgMinMaxConverter.h | 317 +++ .../include/TritonToLinalg/BlockPtrAnalysis.h | 237 +++ .../include/TritonToLinalg/CMakeLists.txt | 3 + .../TritonToLinalg/ConversionPatterns.h | 111 ++ .../TritonToLinalg/FunctionConverter.h | 38 + .../TritonToLinalg/LoadStoreConverter.h | 196 ++ .../include/TritonToLinalg/MaskAnalysis.h | 133 ++ .../include/TritonToLinalg/Passes.h | 15 + .../include/TritonToLinalg/Passes.td | 19 + .../TritonToLinalg/TritonOpConverter.h | 378 ++++ .../TritonToLinalg/TritonToLinalgPass.h | 71 + .../include/TritonToLinalg/UseAnalysis.h | 128 ++ .../include/Utils/InterleaveOptimization.h | 71 + .../triton-adapter/include/Utils/Utils.h | 148 ++ .../ascend/triton-adapter/lib/CMakeLists.txt | 2 + .../lib/TritonToLinalg/ArgMinMaxConverter.cpp | 77 + .../lib/TritonToLinalg/BlockPtrAnalysis.cpp | 1404 +++++++++++++ .../lib/TritonToLinalg/CMakeLists.txt | 29 + .../lib/TritonToLinalg/FunctionConverter.cpp | 41 + .../lib/TritonToLinalg/LoadStoreConverter.cpp | 752 +++++++ .../lib/TritonToLinalg/MaskAnalysis.cpp | 543 +++++ .../lib/TritonToLinalg/TritonOpConverter.cpp | 1149 +++++++++++ .../lib/TritonToLinalg/TritonToLinalgPass.cpp | 544 +++++ .../lib/TritonToLinalg/UseAnalysis.cpp | 362 ++++ .../triton-adapter/lib/Utils/CMakeLists.txt | 8 + .../lib/Utils/InterleaveOptimization.cpp | 662 ++++++ .../ascend/triton-adapter/lib/Utils/Utils.cpp | 752 +++++++ .../triton-adapter/tools/CMakeLists.txt | 1 + .../ascend/triton-adapter/triton_adapter.cc | 6 + third_party/ascend/triton_ascend.cpp | 11 + .../triton_patch/include/CMakeLists.txt | 1 + .../include/triton/CMakeLists.txt | 1 + .../include/triton/Dialect/CMakeLists.txt | 1 + .../triton/Dialect/Triton/CMakeLists.txt | 1 + .../triton/Dialect/Triton/IR/CMakeLists.txt | 34 + .../Dialect/Triton/IR/TritonAttrDefs.td | 137 ++ .../triton/Dialect/Triton/IR/TritonOps.td | 1286 ++++++++++++ .../ascend/triton_patch/lib/CMakeLists.txt | 1 + .../triton_patch/lib/Dialect/CMakeLists.txt | 1 + .../lib/Dialect/Triton/CMakeLists.txt | 1 + .../lib/Dialect/Triton/IR/CMakeLists.txt | 15 + .../lib/Dialect/Triton/IR/Dialect.cpp | 139 ++ .../lib/Dialect/Triton/IR/Ops.cpp | 1092 ++++++++++ .../lib/Dialect/Triton/IR/Traits.cpp | 239 +++ .../lib/Dialect/Triton/IR/Types.cpp | 197 ++ .../ascend/triton_patch/python/src/ir.cc | 1771 +++++++++++++++++ .../python/triton_patch/__init__.py | 5 + .../triton_patch/compiler/code_generator.py | 1303 ++++++++++++ .../python/triton_patch/compiler/compiler.py | 447 +++++ .../python/triton_patch/compiler/errors.py | 72 + .../python/triton_patch/language/__init__.py | 0 
.../python/triton_patch/language/_utils.py | 15 + .../python/triton_patch/language/core.py | 229 +++ .../python/triton_patch/language/math.py | 140 ++ .../python/triton_patch/language/semantic.py | 270 +++ .../python/triton_patch/language/standard.py | 18 + .../triton_patch/python/triton_patch/patch.py | 11 + .../python/triton_patch/runtime/autotuner.py | 410 ++++ .../python/triton_patch/runtime/jit.py | 952 +++++++++ .../python/triton_patch/testing.py | 570 ++++++ third_party/ascend/utils.py | 152 ++ 78 files changed, 19988 insertions(+), 18 deletions(-) create mode 100644 third_party/ascend/.gitignore create mode 100644 third_party/ascend/CMakeLists.txt create mode 100644 third_party/ascend/backend/__init__.py create mode 100644 third_party/ascend/backend/compiler.py create mode 100644 third_party/ascend/backend/cpu_driver.py create mode 100644 third_party/ascend/backend/device_print.h create mode 100644 third_party/ascend/backend/driver.py create mode 100644 third_party/ascend/backend/name.conf create mode 100644 third_party/ascend/backend/npu_utils.cpp create mode 100644 third_party/ascend/backend/utils.py create mode 100644 third_party/ascend/language/ascend/__init__.py create mode 100644 third_party/ascend/language/ascend/libdevice.py create mode 100644 third_party/ascend/triton-adapter/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/include/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h create mode 100644 third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h create mode 100644 third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h create mode 100644 third_party/ascend/triton-adapter/include/Utils/Utils.h create mode 100644 third_party/ascend/triton-adapter/lib/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp create mode 100644 third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp create mode 100644 
third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp create mode 100644 third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp create mode 100644 third_party/ascend/triton-adapter/lib/Utils/Utils.cpp create mode 100644 third_party/ascend/triton-adapter/tools/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/triton_adapter.cc create mode 100644 third_party/ascend/triton_ascend.cpp create mode 100644 third_party/ascend/triton_patch/include/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/include/triton/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td create mode 100644 third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td create mode 100644 third_party/ascend/triton_patch/lib/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt create mode 100644 third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp create mode 100644 third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp create mode 100644 third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp create mode 100644 third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp create mode 100644 third_party/ascend/triton_patch/python/src/ir.cc create mode 100644 third_party/ascend/triton_patch/python/triton_patch/__init__.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/compiler/code_generator.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/compiler/compiler.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/compiler/errors.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/language/__init__.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/language/_utils.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/language/core.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/language/math.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/language/semantic.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/language/standard.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/patch.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/runtime/autotuner.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/runtime/jit.py create mode 100644 third_party/ascend/triton_patch/python/triton_patch/testing.py create mode 100644 third_party/ascend/utils.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ab99f733..031cfa206 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,6 +35,20 @@ endif() project(triton) include(CTest) + +if (FLAGTREE_BACKEND STREQUAL "ascend") + set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") + set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch") + 
set(PATCHED_TRITON_LIBRARIES + "TritonIR" + ) + set(PATCHED_TRITON_DEPENDS + "TritonTableGen" + ) + include_directories(${PATCHED_TRITON_ROOT_DIR}/include) + include_directories(${PROJECT_BINARY_DIR}/third_party/ascend/triton_patch/include) # Tablegen'd files +endif() + if(NOT WIN32) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") endif() @@ -76,7 +90,7 @@ if(TRITON_BUILD_UT) include(AddTritonUnitTest) endif() -# Compiler flags +#Compiler flags set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include) if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}") include_directories(${BACKEND_INCLUDE_DIR}) @@ -112,18 +126,55 @@ function(add_triton_object name) INTERFACE $ ) - - # add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) - if(ARG_DEPENDS) - add_dependencies(${name} ${ARG_DEPENDS}) - endif() - if(ARG_LINK_LIBS) - target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) + if (FLAGTREE_BACKEND STREQUAL "ascend") + set(patched_depends "") + foreach(dep ${ARG_DEPENDS}) + list(FIND PATCHED_TRITON_DEPENDS "${dep}" index) + if(index GREATER_EQUAL 0) + list(APPEND patched_depends "Patched_${dep}") + message(STATUS "Replace ${dep} by Patched_${dep} as a dependent of ${name}") + else() + list(APPEND patched_depends ${dep}) + endif() + endforeach() + if(patched_depends) + add_dependencies(${name} ${patched_depends}) + endif() + + set(patched_link_libs "") + foreach(lib ${ARG_LINK_LIBS}) + list(FIND PATCHED_TRITON_LIBRARIES "${lib}" index) + if(index GREATER_EQUAL 0) + list(APPEND patched_link_libs "Patched_${lib}") + message(STATUS "Replace ${lib} by Patched_${lib} to be linked by ${name}") + else() + list(APPEND patched_link_libs ${lib}) + endif() + endforeach() + if(patched_link_libs) + target_link_libraries(${name} PUBLIC ${patched_link_libs}) + endif() + else() + add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) + if(ARG_DEPENDS) + add_dependencies(${name} ${ARG_DEPENDS}) + endif() + if(ARG_LINK_LIBS) + target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) + endif() endif() + endfunction(add_triton_object) set_property(GLOBAL PROPERTY TRITON_LIBS "") function(add_triton_library name) + if (FLAGTREE_BACKEND STREQUAL "ascend") + list(FIND PATCHED_TRITON_LIBRARIES "${name}" index) + if(index GREATER_EQUAL 0) + message(STATUS "Adding Patched_${name} as a lib, instead of ${name}") + return() + endif() + endif() set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name}) add_triton_object(${name} ${ARGN}) llvm_update_compile_flags(${name}) @@ -148,6 +199,7 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") endif() + include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) @@ -162,7 +214,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party) include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files # link_directories(${LLVM_LIBRARY_DIR}) -if (FLAGTREE_BACKEND STREQUAL "cambricon") +if (FLAGTREE_BACKEND STREQUAL "cambricon" OR FLAGTREE_BACKEND STREQUAL "ascend") include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files add_subdirectory(include) @@ -172,6 +224,12 @@ elseif(NOT FLAGTREE_BACKEND) add_subdirectory(lib) endif() +if (FLAGTREE_BACKEND STREQUAL "ascend") + add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/include) + add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib) +endif() + + # find_package(PythonLibs REQUIRED) set(TRITON_SOURCE_DIR 
"${CMAKE_CURRENT_SOURCE_DIR}") set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") @@ -190,6 +248,10 @@ endif() if(TRITON_BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/python/src) + + set(PATCHED_PYTHON_SRC_PATH ${PATCHED_TRITON_ROOT_DIR}/python/src) + include_directories(${PYTHON_SRC_PATH}) + if(NOT (FLAGTREE_BACKEND AND EXISTS "${PYTHON_SRC_PATH}")) set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) endif() @@ -364,8 +426,13 @@ if(TRITON_BUILD_PYTHON_MODULE) if(FLAGTREE_BACKEND STREQUAL "cambricon") add_library(triton SHARED) else() + if(FLAGTREE_BACKEND STREQUAL "ascend") + set(IR_SRC ${PATCHED_PYTHON_SRC_PATH}/ir.cc) + else() + set(IR_SRC ${PYTHON_SRC_PATH}/ir.cc) + endif() add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc - ${PYTHON_SRC_PATH}/ir.cc + ${IR_SRC} ${PYTHON_SRC_PATH}/passes.cc ${PYTHON_SRC_PATH}/interpreter.cc ${PYTHON_SRC_PATH}/llvm.cc) @@ -412,3 +479,332 @@ endif() if(TRITON_BUILD_UT) add_subdirectory(unittest) endif() + + +# cmake_minimum_required(VERSION 3.18) + +# if(POLICY CMP0116) +# # Introduced in cmake 3.20 +# # https://cmake.org/cmake/help/latest/policy/CMP0116.html +# cmake_policy(SET CMP0116 OLD) +# endif() + +# include(ExternalProject) + +# set(CMAKE_CXX_STANDARD 17) + +# set(CMAKE_INCLUDE_CURRENT_DIR ON) + +# project(triton) +# include(CTest) + +# set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") +# set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch") +# set(PATCHED_TRITON_LIBRARIES +# "TritonIR" +# ) +# set(PATCHED_TRITON_DEPENDS +# "TritonTableGen" +# ) + +# if(NOT WIN32) +# list(APPEND CMAKE_MODULE_PATH "${TRITON_ROOT_DIR}/cmake") +# endif() + +# # Options +# option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" OFF) +# option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON) +# option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" OFF) +# option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" OFF) +# set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") + +# # Ensure Python3 vars are set correctly +# # used conditionally in this file and by lit tests + +# # Customized release build type with assertions: TritonRelBuildWithAsserts +# set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +# set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +# set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") +# set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") + +# # Default build type +# if(NOT CMAKE_BUILD_TYPE) +# message(STATUS "Default build type: Release") +# set(CMAKE_BUILD_TYPE "Release") +# endif() + +# if(NOT WIN32) +# find_library(TERMINFO_LIBRARY tinfo) +# endif() + +# # Compiler flags +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") + +# # ######### +# # LLVM +# # ######### +# if(NOT MLIR_DIR) +# set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir) +# endif() + +# # MLIR +# find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR}) + +# list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +# list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +# include(TableGen) # required by AddMLIR +# include(AddLLVM) +# include(AddMLIR) + +# # Utilities +# function(add_triton_object name) +# cmake_parse_arguments(ARG "" "" "DEPENDS;LINK_LIBS" ${ARGN}) +# add_library(${name} OBJECT) +# target_sources(${name} +# PRIVATE ${ARG_UNPARSED_ARGUMENTS} +# INTERFACE $ +# ) + +# set(patched_depends "") +# foreach(dep ${ARG_DEPENDS}) +# list(FIND PATCHED_TRITON_DEPENDS "${dep}" 
index) +# if(index GREATER_EQUAL 0) +# list(APPEND patched_depends "Patched_${dep}") +# message(STATUS "Replace ${dep} by Patched_${dep} as a dependent of ${name}") +# else() +# list(APPEND patched_depends ${dep}) +# endif() +# endforeach() +# if(patched_depends) +# add_dependencies(${name} ${patched_depends}) +# endif() + +# set(patched_link_libs "") +# foreach(lib ${ARG_LINK_LIBS}) +# list(FIND PATCHED_TRITON_LIBRARIES "${lib}" index) +# if(index GREATER_EQUAL 0) +# list(APPEND patched_link_libs "Patched_${lib}") +# message(STATUS "Replace ${lib} by Patched_${lib} to be linked by ${name}") +# else() +# list(APPEND patched_link_libs ${lib}) +# endif() +# endforeach() +# if(patched_link_libs) +# target_link_libraries(${name} PUBLIC ${patched_link_libs}) +# endif() + +# endfunction(add_triton_object) + +# set_property(GLOBAL PROPERTY TRITON_LIBS "") +# function(add_triton_library name) +# list(FIND PATCHED_TRITON_LIBRARIES "${name}" index) +# if(index GREATER_EQUAL 0) +# message(STATUS "Adding Patched_${name} as a lib, instead of ${name}") +# return() +# endif() +# set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name}) +# add_triton_object(${name} ${ARGN}) +# llvm_update_compile_flags(${name}) +# endfunction() + +# set_property(GLOBAL PROPERTY TRITON_PLUGINS "") +# function(add_triton_plugin name) +# set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name}) +# add_triton_object(${name} ${ARGN}) +# endfunction() + +# function(remove_component_from_property property_name component_to_remove) +# get_property(prop_value GLOBAL PROPERTY ${property_name}) +# string(REPLACE ";" ";" prop_list "${prop_value}") +# list(REMOVE_ITEM prop_list "${component_to_remove}") +# string(REPLACE ";" ";" modified_prop "${prop_list}") +# set_property(GLOBAL PROPERTY ${property_name} "${modified_prop}") +# message(STATUS "Removed '${component_to_remove}' from ${property_name}") +# endfunction() + +# # Disable warnings that show up in external code (gtest;pybind11) +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") + + +# include_directories(${TRITON_ROOT_DIR}) +# include_directories(${MLIR_INCLUDE_DIRS}) +# include_directories(${LLVM_INCLUDE_DIRS}) +# include_directories(${PATCHED_TRITON_ROOT_DIR}/include) +# include_directories(${PROJECT_BINARY_DIR}/third_party/ascend/triton_patch/include) # Tablegen'd files +# include_directories(${TRITON_ROOT_DIR}/include) +# include_directories(${PROJECT_BINARY_DIR}/triton/include) # Tablegen'd files +# include_directories(${PROJECT_SOURCE_DIR}/include) +# include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files +# # link_directories(${LLVM_LIBRARY_DIR}) +# add_subdirectory(${TRITON_ROOT_DIR}/include) +# add_subdirectory(${TRITON_ROOT_DIR}/lib) +# # remove_component_from_property(TRITON_LIBS "TritonIR") +# add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/include) +# add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib) + +# # find_package(PythonLibs REQUIRED) +# set(TRITON_SOURCE_DIR "${TRITON_ROOT_DIR}") +# set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") + +# # TODO: Figure out which target is sufficient to fix errors; triton is +# # apparently not enough. Currently set linking libstdc++fs for all targets +# # to support some old version GCC compilers like 8.3.0. 
+# if (NOT WIN32 AND NOT APPLE) +# link_libraries(stdc++fs) +# endif() + + +# # ----- + +# # ------ +# if(TRITON_BUILD_PYTHON_MODULE) +# message(STATUS "Adding Python module") +# set(PYTHON_SRC_PATH ${TRITON_ROOT_DIR}/python/src) +# set(PATCHED_PYTHON_SRC_PATH ${PATCHED_TRITON_ROOT_DIR}/python/src) +# include_directories(${PYTHON_SRC_PATH}) + +# if(PYTHON_INCLUDE_DIRS) +# # We have PYTHON_INCLUDE_DIRS set--this is what we expect when building +# # using pip install. +# include_directories(${PYTHON_INCLUDE_DIRS}) +# include_directories(${PYBIND11_INCLUDE_DIR}) +# else() +# # Otherwise, we might be building from top CMakeLists.txt directly. +# # Try to find Python and pybind11 packages. +# find_package(Python3 REQUIRED COMPONENTS Development Interpreter) +# find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") +# include_directories(${Python3_INCLUDE_DIRS}) +# include_directories(${pybind11_INCLUDE_DIR}) +# link_directories(${Python3_LIBRARY_DIRS}) +# link_libraries(${Python3_LIBRARIES}) +# add_link_options(${Python3_LINK_OPTIONS}) +# endif() + +# if (DEFINED TRITON_PLUGIN_DIRS) +# foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS}) +# # Read the plugin name under dir/backend/name.conf +# cmake_path(APPEND PLUGIN_DIR "backend" "name.conf" OUTPUT_VARIABLE PLUGIN_NAME_PATH) +# file(READ ${PLUGIN_NAME_PATH} PLUGIN_NAME) +# string(STRIP ${PLUGIN_NAME} PLUGIN_NAME) + +# list(APPEND TRITON_PLUGIN_NAMES ${PLUGIN_NAME}) + +# # Include the plugin as part of the build, placing the build output under +# # ${TRITON_BINARY_DIR}/third_party/${PLUGIN_NAME} +# # cmake_path(APPEND TRITON_BINARY_DIR "third_party" ${PLUGIN_NAME} OUTPUT_VARIABLE PLUGIN_DIR_BUILD_OUTPUT) +# message(STATUS "Building plugin '${PLUGIN_NAME}' from ${PLUGIN_DIR} with output ${PLUGIN_DIR_BUILD_OUTPUT}") +# add_subdirectory(${PLUGIN_DIR} ${PLUGIN_DIR_BUILD_OUTPUT}) +# endforeach() +# endif() + +# foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) +# add_subdirectory(third_party/${CODEGEN_BACKEND}) +# endforeach() + +# get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) +# get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) +# set(TRITON_LIBRARIES +# ${triton_libs} +# ${triton_plugins} + +# # mlir +# MLIRAMDGPUDialect +# MLIRNVVMDialect +# MLIRNVVMToLLVMIRTranslation +# MLIRGPUToNVVMTransforms +# MLIRGPUToGPURuntimeTransforms +# MLIRGPUTransforms +# MLIRIR +# MLIRControlFlowToLLVM +# MLIRBytecodeWriter +# MLIRPass +# MLIRTransforms +# MLIRLLVMDialect +# MLIRSupport +# MLIRTargetLLVMIRExport +# MLIRMathToLLVM +# MLIRROCDLToLLVMIRTranslation +# MLIRGPUDialect +# MLIRSCFToControlFlow +# MLIRIndexToLLVM +# MLIRGPUToROCDLTransforms +# MLIRUBToLLVM + +# # LLVM +# LLVMPasses +# LLVMNVPTXCodeGen +# # LLVMNVPTXAsmPrinter +# LLVMAMDGPUCodeGen +# LLVMAMDGPUAsmParser + +# ) +# if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 +# CMAKE_SYSTEM_PROCESSOR MATCHES "arm64" OR # macOS arm64 +# CMAKE_OSX_ARCHITECTURES MATCHES "arm64") # also macOS arm64 +# list(APPEND TRITON_LIBRARIES +# LLVMAArch64CodeGen +# LLVMAArch64AsmParser +# ) +# elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") +# list(APPEND TRITON_LIBRARIES +# LLVMX86CodeGen +# LLVMX86AsmParser +# ) +# elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le") +# list(APPEND TRITON_LIBRARIES +# LLVMPowerPCAsmParser +# LLVMPowerPCCodeGen +# ) +# else() +# message(FATAL_ERROR "LLVM codegen/ASM parser libs: This HW architecture (${CMAKE_SYSTEM_PROCESSOR}) is not configured in cmake lib dependencies.") +# endif() + +# # Define triton library +# string(JOIN "," TRITON_BACKENDS_TUPLE 
${TRITON_CODEGEN_BACKENDS}) + +# if (DEFINED TRITON_PLUGIN_NAMES) +# string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_BACKENDS_TUPLE} ${TRITON_PLUGIN_NAMES}) +# endif() + +# message(STATUS "Triton backends tuple: ${TRITON_BACKENDS_TUPLE}") + +# set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") +# add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) +# add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc +# ${PATCHED_PYTHON_SRC_PATH}/ir.cc +# ${PYTHON_SRC_PATH}/passes.cc +# ${PYTHON_SRC_PATH}/interpreter.cc +# ${PYTHON_SRC_PATH}/llvm.cc) +# # Link triton with its dependencies +# target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) +# if(WIN32) +# target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) +# else() +# target_link_libraries(triton PRIVATE z) +# endif() +# target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) +# endif() + +# if (UNIX AND NOT APPLE) +# set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") +# endif() + +# if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) +# set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") + +# # Check if the platform is MacOS +# if(APPLE) +# set(PYTHON_LDFLAGS "-undefined dynamic_lookup") +# endif() + +# target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS}) +# endif() + +# if(NOT TRITON_BUILD_PYTHON_MODULE) +# foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) +# add_subdirectory(third_party/${CODEGEN_BACKEND}) +# endforeach() +# endif() + +# add_subdirectory(${TRITON_ROOT_DIR}/third_party/f2reduce) diff --git a/python/setup.py b/python/setup.py index 263036893..7c5da7029 100644 --- a/python/setup.py +++ b/python/setup.py @@ -611,6 +611,7 @@ class plugin_install(install): def run(self): add_links() install.run(self) + helper.post_install(self) class plugin_develop(develop): @@ -618,6 +619,7 @@ class plugin_develop(develop): def run(self): add_links() develop.run(self) + helper.post_install(self) class plugin_bdist_wheel(bdist_wheel): @@ -625,6 +627,7 @@ class plugin_bdist_wheel(bdist_wheel): def run(self): add_links() bdist_wheel.run(self) + helper.post_install(self) class plugin_egginfo(egg_info): @@ -632,11 +635,10 @@ class plugin_egginfo(egg_info): def run(self): add_links() egg_info.run(self) + helper.post_install(self) -package_data_tools = ["compile.h", "compile.c"] -if helper.flagtree_backend == "xpu": - package_data_tools += ["compile_xpu.h", "compile_xpu.c"] +package_data_tools = helper.get_package_data_tools() package_data = { "triton/tools": package_data_tools, **{f"triton/backends/{b.name}": b.package_data for b in backends}, "triton/language/extra": sum( @@ -676,10 +678,10 @@ def get_packages(): "triton/backends", "triton/tools", ] - if helper.flagtree_backend == "xpu": - packages.append("triton/language/extra/xpu") - elif helper.flagtree_backend == "mthreads": - packages.append("triton/language/extra/musa") + if helper.flagtree_backend: + packages.append(f"triton/language/extra/{helper.get_device_name()}") + packages += helper.get_extra_packages() + packages += [f'triton/backends/{backend.name}' for backend in backends] packages += get_language_extra_packages() if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON diff --git a/python/setup_helper.py b/python/setup_helper.py index f7a52d1b2..106f0ce56 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -14,9 +14,11 @@ necessary_third_party = ["triton_shared"] default_backends = ["nvidia", "amd"] extend_backends = [] +plugin_backends = ["cambricon", "ascend"] ext_sourcedir = "triton/_C/" flagtree_backend = 
os.getenv("FLAGTREE_BACKEND", "").lower() flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() +device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"} @dataclass @@ -43,6 +45,51 @@ class FlagTreeBackend: }) +def get_device_name(): + return device_mapping[flagtree_backend] + + +def get_extra_packages(): + packages = [] + if flagtree_backend == 'ascend': + packages = [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + return packages + + +def get_package_data_tools(): + package_data = ["compile.h", "compile.c"] + if flagtree_backend == 'xpu': + package_data += ["compile_xpu.h", "compile_xpu.c"] + return package_data + + +def post_install(self): + + def get_module(module_path): + import importlib.util + import os + module_path = os.path.abspath(module_path) + spec = importlib.util.spec_from_file_location("module", module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def ascend(): + utils = get_module("../third_party/ascend/utils.py") + utils.post_install() + + code = f"{flagtree_backend}()" + try: + exec(code, globals(), locals()) + except: #noqa: E722 + pass + + class FlagTreeCache: def __init__(self): @@ -236,7 +283,7 @@ def skip_package_dir(package): @staticmethod def get_package_dir(packages): package_dict = {} - if flagtree_backend and flagtree_backend != 'cambricon': + if flagtree_backend and flagtree_backend not in plugin_backends: connection = [] backend_triton_path = f"../third_party/{flagtree_backend}/python/" for package in packages: @@ -245,6 +292,12 @@ def get_package_dir(packages): pair = (package, f"{backend_triton_path}{package}") connection.append(pair) package_dict.update(connection) + if flagtree_backend == "ascend": + triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch" + package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" return package_dict @staticmethod diff --git a/third_party/ascend/.gitignore b/third_party/ascend/.gitignore new file mode 100644 index 000000000..c557aabf6 --- /dev/null +++ b/third_party/ascend/.gitignore @@ -0,0 +1 @@ +triton-adapter-opt diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt new file mode 100644 index 000000000..3c1ff4337 --- /dev/null +++ b/third_party/ascend/CMakeLists.txt @@ -0,0 +1,12 @@ +add_subdirectory(triton-adapter triton-adapter) + +add_triton_plugin(TritonHUAWEI ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cpp) + +# Copy triton-adapter-opt to python files +add_custom_target(COPY_TRITON_ADAPTER_OPT) +add_custom_command(TARGET COPY_TRITON_ADAPTER_OPT POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + ${TRITON_ROOT_DIR}/python/triton/backends/huawei/triton-adapter-opt + DEPENDS triton-adapter-opt) +add_dependencies(TritonHUAWEI COPY_TRITON_ADAPTER_OPT) diff --git a/third_party/ascend/backend/__init__.py b/third_party/ascend/backend/__init__.py new file mode 100644 index 000000000..0eec99724 --- /dev/null +++ b/third_party/ascend/backend/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py new file mode 100644 index 000000000..226f0e78f --- /dev/null +++ b/third_party/ascend/backend/compiler.py @@ -0,0 +1,329 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes +from triton.runtime import driver +from triton.runtime.cache import get_dump_manager +from dataclasses import dataclass +import functools +from typing import Any, Union, Tuple, Dict +from types import ModuleType +from pathlib import Path +import tempfile +import os +import subprocess +import hashlib +import ctypes +from typing import Optional + +from triton.backends.huawei.utils import downgrade_llir, _get_llvm_path, _get_mlir_path, _get_triton_adapter_opt_path, \ + _get_kernel_target, _get_npucompiler_path, _is_ascend_sanitizer_enabled + + +# TODO: materialize the concrete min shape +def min_dot_size(target: GPUTarget): + # return lambda lhsType, rhsType: (16, 16, 16) + return lambda lhsType, rhsType: (1, 1, 1) + + +def make_ttir(mod, metadata, opt): + if 'hash' not in metadata: + metadata['hash'] = hashlib.md5(f"{mod}-{metadata}".encode()).hexdigest() + # the same optimize pass for triton-ir as all other backends + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + if opt.debug: + dump_manager = get_dump_manager(metadata['hash']) + print(f"Dumping intermediate results to {dump_manager.cache_dir}") + dump_manager.put(str(mod), "kernel.ttir.mlir", binary=False) + + return mod + + +def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): + # use triton_adapter to lower Triton-MLIR to linalg + # Get Triton-MLIR as string + ttir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ttir.mlir") + dst_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") + Path(src_path).write_text(ttir_code) + triton_adapter_opt_path = _get_triton_adapter_opt_path() + + cmd_list = [ + triton_adapter_opt_path, src_path, f'--triton-to-linalg=global-kernel=false named-ops={named_ops}', "-o", + dst_path + ] + if _is_ascend_sanitizer_enabled(): + cmd_list += ["--mlir-print-debuginfo"] # pass debug info + + ret = subprocess.run(cmd_list, capture_output=True, check=True) + if opt.debug: + dump_manager = get_dump_manager(metadata['hash']) + dump_manager.put(Path(dst_path).read_text(), "kernel.ttadapter.mlir", binary=False) + + return Path(dst_path).read_text() + + +def linalg_to_llir(linalg: str, metadata, opt): + with tempfile.TemporaryDirectory() as tmpdir: + ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") + llmlir_path = os.path.join(tmpdir, "kernel.llir.mlir") + llir_path = os.path.join(tmpdir, "kernel.ll") + Path(ttadapter_path).write_text(linalg) + mlir_opt_path = _get_mlir_path("bin", "mlir-opt") + # TritonAdapter-MLIR to LLVM-MLIR + subprocess.check_call([ + mlir_opt_path, ttadapter_path, "--convert-linalg-to-affine-loops", "--eliminate-empty-tensors", + "--empty-tensor-to-alloc-tensor", "--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--lower-affine", "--convert-linalg-to-loops", "--convert-scf-to-cf", "--convert-cf-to-llvm", + "--convert-arith-to-llvm", "--convert-math-to-llvm", "--convert-complex-to-llvm", + "--convert-vector-to-llvm", "--convert-index-to-llvm", 
"--memref-expand", "--expand-strided-metadata", + "--finalize-memref-to-llvm", "--convert-func-to-llvm", + # Lowering memrefs creates more affine.apply ops. + # Lowering these affine ops again creates further arith ops, + # so we have to run these two passes again here. + "--lower-affine", "--convert-arith-to-llvm", + # Remove all unrealized casts created + "--reconcile-unrealized-casts", "-o", llmlir_path + ]) + if opt.debug: + dump_manager = get_dump_manager(metadata['hash']) + dump_manager.put(Path(llmlir_path).read_text(), "kernel.llir.mlir", binary=False) + + # LLVM-MLIR to LLVM-IR + mlir_translate_path = _get_mlir_path("bin", "mlir-translate") + subprocess.check_call([mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path]) + if opt.debug: + dump_manager = get_dump_manager(metadata['hash']) + dump_manager.put(Path(llir_path).read_text(), "kernel.ll", binary=False) + + return Path(llir_path).read_text() + + +def llir_to_cpuasm(llir: str, metadata, opt): + # add metadata at final stage + # Note: Compiled Kernel requires to estimate size of shared memory to occupy + # Currently, CPU backend requires no limit on shared memory size + metadata['shared'] = 1 + # We can get a function name (C naming) from + # LLVM-IR by getting the first "define void @". + fn_name = llir.split("define void @")[1].split("(")[0].strip() + metadata['name'] = fn_name + " cpu" + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ll") + linked_path = os.path.join(tmpdir, "kernel_linked.ll") + dst_path = os.path.join(tmpdir, "kernel.s") + + llir = downgrade_llir(llir) + if opt.debug: + dump_manager = get_dump_manager(metadata['hash']) + dump_manager.put(llir, "kernel_downgrade.ll", binary=False) + + Path(src_path).write_text(llir) + + linker_path = _get_llvm_path("bin", "llvm-link") + libclc_path = _get_llvm_path("lib", "clc", "libspirv-aarch64--.bc") + subprocess.check_call([linker_path, src_path, libclc_path, "--only-needed", "-S", "-o", linked_path]) + if opt.debug: + dump_manager = get_dump_manager(metadata['hash']) + dump_manager.put(Path(linked_path).read_text(), "kernel_linked.ll", binary=False) + + llc_path = _get_llvm_path("bin", "llc") + subprocess.check_call([llc_path, linked_path, "-o", dst_path]) + if opt.debug: + dump_manager = get_dump_manager(metadata['hash']) + dump_manager.put(Path(dst_path).read_text(), "kernel.s", binary=False) + + # Actually it's text-format assembly. Use read_text(). + return Path(dst_path).read_text() + + +def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): + import re + # Note: Compiled Kernel requires to estimate size of shared memory to occupy + # Currently, NPU backend does not limit on shared memory + metadata['shared'] = 1 + # the mix mode is also encoded into metadata['name'] for runtime to distinguish + metadata['mix_mode'] = re.search(r'mix_mode\s*=\s*"([^"]+)"', linalg).group(1) + metadata['kernel_name'] = re.search(r'func\.func\s+@(\w+)', linalg).group(1) + # Use while space to split kernel_name and mix_mode. + # Check the function load_binary in npu_driver.py. 
+ metadata['name'] = metadata['kernel_name'] + " " + metadata['mix_mode'] + # remove the mix_mode attribute + linalg = re.sub(r', mix_mode\s*=\s*"[^"]*"', '', linalg) + with tempfile.TemporaryDirectory() as tmpdir: + ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") + Path(ttadapter_path).write_text(linalg) + bin_file = os.path.join(tmpdir, "kernel") + bin_path = os.path.join(tmpdir, "kernel_reloc.o") + callback_path = os.path.join(tmpdir, "libkernel.so") + multibuffer = metadata['multibuffer'] + _compile_option_list = [ + f"--enable-auto-multi-buffer={multibuffer}", + ] + + if _is_ascend_sanitizer_enabled(): + _compile_option_list += ["--enable-sanitizer=true"] + npu_compiler_path = _get_npucompiler_path() + if (npu_compiler_path.endswith("bishengir-compile")): + _compile_option_list += [ + "--enable-hfusion-compile=true", + "--enable-hivm-compile=true", + "--enable-triton-kernel-compile=true", + ] + cmd_list = [npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file] + ret = subprocess.run(cmd_list, capture_output=True, check=True) + if Path(callback_path).is_file(): + lib = ctypes.CDLL(callback_path) + callback_func = getattr(lib, metadata['kernel_name'] + "_infer_workspace_shape_function") + callback_func.restype = ctypes.c_int64 + callback_func.argtypes = [] + metadata['workspace_size'] = callback_func() + + return Path(bin_path).read_bytes() + + +@dataclass(frozen=True) +class NPUOptions: + debug: bool = False + sanitize_overflow: bool = True + llvm_version: int = 15 + kernel_name: str = "triton_" + + cluster_dims: tuple = (1, 1, 1) + num_warps: int = -1 + num_ctas: int = -1 + num_stages: int = 2 + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 + + enable_warp_specialization: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee", "hf32") + enable_npu_compile: bool = True + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + + multibuffer: bool = True + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + + +@dataclass(frozen=True) +class CPUOptions: + debug: bool = False + llvm_version: int = 15 + kernel_name: str = "triton_" + + cluster_dims: tuple = (1, 1, 1) + num_warps: int = -1 + num_ctas: int = -1 + num_stages: int = -1 + + enable_warp_specialization: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + + +class HuaweiBackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'cpu' or target.backend == 'npu' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + if (target.backend == "cpu"): + self.binary_ext = "cpuasm" + elif (target.backend == "npu"): + self.binary_ext = "npubin" + + def parse_options(self, opts) -> Any: + # TODO: get available targets when building options? 
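+        # The backend was chosen in __init__ from the probed GPUTarget: "npu" targets get
+        # NPUOptions (bisheng/NPU binary path), anything else gets CPUOptions (reference CPU path).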
+ if self.target.backend == 'npu': + args = {k: opts[k] for k in NPUOptions.__dataclass_fields__.keys() if k in opts} + options = NPUOptions(**args) + else: + args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} + options = CPUOptions(**args) + return options + + def pack_metadata(self, metadata): + from triton.backends.huawei.utils import TRITON_PROFILER_REGISTERED + # collect necessary metadata to launch kernels + # TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 could set unique name. + # Get this name as the kernel_name to CANN runtime. + # kernel_name is unique to Huawei backend and should not be public. + # CANN runtime limits the length of kernel name <= 50. + # Considering '\n' is appended, thus the real kernel name <= 49. + KERNEL_NAME_MAX_LEN = 49 + kernel_name_orig, mix_mode = metadata.name.split() + if (len(kernel_name_orig) > KERNEL_NAME_MAX_LEN): + kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:] + # import warnings + # # red = "\x1b[31;20m" + # # reset = "\x1b[0m" + # warnings.warn(kernel_name_orig + " is truncated to " + kernel_name) + # warnings.warn("because '" + kernel_name_orig + "' exceeds torchnpu profiler's length limit < 50") + else: + kernel_name = kernel_name_orig + return { + "kernel_name": kernel_name, + "hash": metadata.hash, + "debug": metadata.debug, + "profiler_registered": TRITON_PROFILER_REGISTERED, + } + + def get_codegen_implementation(self): + # Note: a dict of functions is required to generate vendor-specific code piecies + # e.g. convert custom types like fp8e4b15 + codegen_fns = {"min_dot_size": min_dot_size(self.target)} + return codegen_fns + + def load_dialects(self, ctx): + pass + + def add_stages(self, stages, options): + if self.target.backend == 'npu': + stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) + if options.enable_npu_compile: + stages["ttadapter"] = lambda src, metadata: ttir_to_linalg(src, metadata, options, named_ops=True) + stages["npubin"] = lambda src, metadata: linalg_to_bin_enable_npu_compile(src, metadata, options) + else: + pass + else: + stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) + stages["ttadapter"] = lambda src, metadata: ttir_to_linalg(src, metadata, options) + stages["llir"] = lambda src, metadata: linalg_to_llir(src, metadata, options) + stages["cpuasm"] = lambda src, metadata: llir_to_cpuasm(src, metadata, options) + + @functools.lru_cache() + def hash(self): + # TODO fetch compiler version + version_key = self.target + return str(version_key) + + def get_module_map(self) -> Dict[str, ModuleType]: + return {} diff --git a/third_party/ascend/backend/cpu_driver.py b/third_party/ascend/backend/cpu_driver.py new file mode 100644 index 000000000..6bee24282 --- /dev/null +++ b/third_party/ascend/backend/cpu_driver.py @@ -0,0 +1,185 @@ +from triton.runtime.cache import get_cache_manager, get_dump_manager +from pathlib import Path +import tempfile +import os +import sysconfig +import subprocess +import importlib +from triton.backends.huawei.utils import _get_llvm_path + + +# TODO: temporarily fake CPUUtils class +class CPUUtils(object): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(CPUUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + pass + + def get_device_properties(self, device): + # temperoarily added properties to avoid triton-compiler complain + # fetch available memory at runtime + return {"max_shared_mem": 1} + + def load_binary(self, name, kernel, shared, device): + # TODO (temperoarily 
fake function) load a binary from binary object to device + # return value are: (mod, funcptr/handle, n_regs, n_spills) + return None, kernel, 0, 0 + + +class CPULauncher(object): + + def __init__(self, src, metadata): + kernel_name = metadata.name.split()[0] + signature = src.signature + constants = src.constants + launcher_src = generate_cpu_wrapper_src(constants, signature, kernel_name) + self.launch = compile_module(launcher_src) + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class CPUDriver: + + def __init__(self): + self.utils = CPUUtils() + self.launcher_cls = CPULauncher + super().__init__() + + def get_current_target(self): + # TODO: do we rely on CPU arch? + return ("cpu", "arm-64") + + def get_current_device(self): + """ + Get current device + """ + # TODO: dummy device-getter for cpu backend + return 0 + + def set_current_device(self, device): + """ + Set current device as the given device + """ + # TODO: dummy device-setter for cpu backend + return + + def get_current_stream(self, device): + """ + Get stream for current device + """ + # TODO: dummy stream api for cpu backend. + return 0 + + +# the template is from triton-adapter HEAD. Wrapping the generated kernel assembly into a python module +def generate_cpu_wrapper_src(constants, signature, kernel_name): + + def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def _extracted_ty(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + def _generate_launcher(constants, signature, kernel_name): + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + format = "iiiOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) + # to be filled + return f""" + """ + + launcher_src = _generate_launcher(constants, signature, kernel_name) + return launcher_src + + +def compile_module(launcher_src): + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + def launch(gridX, gridY, gridZ, stream, cu_function, packed_metadata, launch_metadata, launch_enter_hook, + launch_exit_hook, *args): + # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries. + # Let's compile a kernel every time. 
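+        # The flow below: look up a per-kernel launcher .so in the cache (keyed by the kernel hash);
+        # on a miss, write the textual assembly (cu_function) and the generated C++ wrapper to a
+        # temp dir, build them into one shared object with clang++, cache it, then import it and
+        # delegate to its launch().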
+ kernel_name = packed_metadata["kernel_name"] + cache = get_cache_manager(packed_metadata["hash"]) + filename = f"{kernel_name}_cpu_launcher.so" + cache_path = cache.get_file(filename) + if cache_path is None: + asm_src = cu_function + with tempfile.TemporaryDirectory() as tmpdir: + asm_src_path = os.path.join(tmpdir, "kernel.s") + launcher_src_path = os.path.join(tmpdir, "main.cxx") + if packed_metadata["debug"]: + dump_manager = get_dump_manager(packed_metadata["hash"]) + dump_manager.put(launcher_src, "kernel_cpu_launcher.cxx", binary=False) + so_path = os.path.join(tmpdir, "kernel.so") + Path(asm_src_path).write_bytes(asm_src) + Path(launcher_src_path).write_text(launcher_src) + # Compile it together. + subprocess.check_call([ + _get_llvm_path("bin", "clang++"), launcher_src_path, asm_src_path, f"-I{py_include_dir}", + f"-I{Path(__file__).resolve().parent}", "-shared", "-fPIC", "-o", so_path + ]) + + with open(so_path, "rb") as f: + cache_path = cache.put(f.read(), filename, binary=True) + + # Load and launch the compiled kernel. + spec = importlib.util.spec_from_file_location("__triton_adapter_ref_cpu_kernel_launcher", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.launch(gridX, gridY, gridZ, launch_enter_hook, launch_exit_hook, packed_metadata, *args) + + return launch diff --git a/third_party/ascend/backend/device_print.h b/third_party/ascend/backend/device_print.h new file mode 100644 index 000000000..b56910f3a --- /dev/null +++ b/third_party/ascend/backend/device_print.h @@ -0,0 +1,274 @@ +#ifndef TRITON_DEVICE_PRINT_H +#define TRITON_DEVICE_PRINT_H + +#include "experiment/runtime/runtime/rt.h" +#include "stdio.h" + +#define LogBufferPaddingBytes 64 +#define BlockMaxSize 16 * 1024 +#define VerifyBorder(nextField, maxBuf) \ + if (nextField > maxBuf) { \ + printf("\nWARNING: out of bound! 
try best to print\n"); \ + return; \ + } +#define __gm__ + +namespace TTAscDebug { + +enum NodeTy { END, NORMAL, FLOAT, INT, CHAR, STRING, POINTER }; + +struct PrintPayloadData { + __gm__ char *LogWholeRegion; + unsigned BlockNum; + size_t LogBufferSize; + PrintPayloadData() + : LogWholeRegion((__gm__ char *)nullptr), LogBufferSize(0), BlockNum(0) {} +}; + +struct DebugTunnelData { + PrintPayloadData PrintData; + DebugTunnelData() {} +}; + +void PrintFormatString(int8_t *&buf, int8_t *maxbuf) { + VerifyBorder((buf + sizeof(short)), maxbuf); + short len = *(short *)buf; + buf += sizeof(len); + VerifyBorder((buf + len), maxbuf); + printf((const char *)buf); + buf += len; +} + +template +void PrintFormatString(int8_t *&buf, int8_t *maxbuf, T param) { + VerifyBorder((buf + sizeof(short)), maxbuf); + short len = *(short *)buf; + buf += sizeof(len); + VerifyBorder((buf + len), maxbuf); + printf((const char *)buf, param); + buf += len; +} + +void AnalyzeSerializedData(int8_t *buf, int logSize, int maxSize) { + int8_t *bufEndAddr = buf + logSize; + int8_t *maxbuf = buf + maxSize; + while (buf < bufEndAddr) { + VerifyBorder((buf + sizeof(int8_t)), maxbuf); + int8_t type = *(int8_t *)buf; + while (type != NodeTy::END) { + buf += sizeof(type); + switch (type) { + default: + break; + case NodeTy::NORMAL: { + PrintFormatString(buf, maxbuf); + break; + } + case NodeTy::FLOAT: { + VerifyBorder((buf + sizeof(float)), maxbuf); + float param = *(float *)buf; + buf += sizeof(param); + PrintFormatString(buf, maxbuf, param); + break; + } + case NodeTy::INT: { + VerifyBorder((buf + sizeof(long long int)), maxbuf); + long long int param = *(long long int *)buf; + buf += sizeof(param); + PrintFormatString(buf, maxbuf, param); + break; + } + case NodeTy::STRING: { + VerifyBorder((buf + sizeof(short)), maxbuf); + short strlen = *(short *)buf; + buf += sizeof(strlen); + VerifyBorder((buf + strlen), maxbuf); + char *param = reinterpret_cast(buf); + buf += strlen; + PrintFormatString(buf, maxbuf, param); + break; + } + case NodeTy::CHAR: { + VerifyBorder((buf + sizeof(char)), maxbuf); + char param = *(char *)buf; + buf += sizeof(param); + PrintFormatString(buf, maxbuf, param); + break; + } + case NodeTy::POINTER: { + VerifyBorder((buf + 8), maxbuf); + void *param = *(void **)buf; + buf += sizeof(param); + PrintFormatString(buf, maxbuf, param); + break; + } + } + VerifyBorder((buf + sizeof(int8_t)), maxbuf); + type = *(int8_t *)buf; + } + buf += 1; + } +} + +void OnHostInitialize(PrintPayloadData *PrintData, unsigned BlockNum) { + PrintData->LogBufferSize = BlockMaxSize; + PrintData->BlockNum = BlockNum; + int WholeSize = + (PrintData->LogBufferSize + LogBufferPaddingBytes) * PrintData->BlockNum; + + void *Hbm_PrintPayloadData_start_addr = NULL; + // Not sure how to use the module_id param of rtMalloc + uint16_t ModuleId = 0; + rtError_t error = + rtMalloc(reinterpret_cast(&Hbm_PrintPayloadData_start_addr), + WholeSize, RT_MEMORY_HBM, ModuleId); + if (error != RT_ERROR_NONE) { + printf("ERROR:The memory for the printing function on the device side " + "fails to be allocated."); + printf("As a result, the printing function fails!\n"); + return; + } + PrintData->LogWholeRegion = (__gm__ char *)Hbm_PrintPayloadData_start_addr; +} + +void OnHostFinish(PrintPayloadData *PrintData, rtStream_t Stream) { + if (!PrintData->LogWholeRegion) { + return; + } + std::size_t WholeSize = + (PrintData->LogBufferSize + LogBufferPaddingBytes) * PrintData->BlockNum; + char *hostMemOut2; + // Not sure how to use the module_id param of 
rtMalloc + uint16_t ModuleId = 0; + rtError_t error = rtMallocHost(reinterpret_cast(&hostMemOut2), + WholeSize, ModuleId); + if (error != RT_ERROR_NONE) { + printf("ERROR:The memory for the printing function on the device side " + "fails to be allocated."); + printf("As a result, the printing function fails!\n"); + return; + } + error = rtMemcpyAsync(hostMemOut2, WholeSize, PrintData->LogWholeRegion, + WholeSize, RT_MEMCPY_DEVICE_TO_HOST, Stream); + if (error != RT_ERROR_NONE) { + printf("ERROR: The memory copy of the device print on fails,"); + printf("and the printing function is invalid!\n"); + return; + } + error = rtStreamSynchronize(Stream); + if (error != RT_ERROR_NONE) { + printf("ERROR: Synchronous waiting for the device print failed.\n"); + printf("The printing function is invalid!\n"); + return; + } + char *outRaw2 = static_cast(hostMemOut2); + const char *Line = "-------------------------------------------------------"; + // Precheck if any print data is ready + for (int B = 0; B < PrintData->BlockNum; B++) { + char *Log = + (outRaw2 + (PrintData->LogBufferSize + LogBufferPaddingBytes) * B); + size_t LogSize = *reinterpret_cast(Log); + if (LogSize > 0 && LogSize <= PrintData->LogBufferSize) { + printf("LogBufferSize of each core is : %zu Bytes\n", + PrintData->LogBufferSize); + printf("%s\n", Line); + printf("----------------------HiIPU " + "Print----------------------\n"); + printf("%s\n", Line); + break; + } + } + + for (int B = 0; B < PrintData->BlockNum; B++) { + char *Log = + (outRaw2 + (PrintData->LogBufferSize + LogBufferPaddingBytes) * B); + size_t LogSize = *reinterpret_cast(Log); + if (LogSize < 0 || LogSize > PrintData->LogBufferSize) { + printf(" LOG SIZE ERROR !!! \n"); + printf(" log size needed = %zu ", LogSize); + printf(" , buf size = %zu\n", PrintData->LogBufferSize); + LogSize = PrintData->LogBufferSize; + continue; + } + if (LogSize == 0) { + continue; + } + printf("==> Block %d, LogSize = %zu Bytes\n", B, LogSize); + int8_t *Buf = reinterpret_cast(Log + LogBufferPaddingBytes); + AnalyzeSerializedData(Buf, LogSize, PrintData->LogBufferSize); + printf("\n"); + printf("%s\n", Line); + } + error = rtFree(PrintData->LogWholeRegion); + if (error != RT_ERROR_NONE) { + printf("ERROR: The memory free of the device print fails\n"); + return; + } + error = rtFreeHost(hostMemOut2); + if (error != RT_ERROR_NONE) { + printf("ERROR: The memory free of the device print fails\n"); + return; + } +} + +DebugTunnelData *Open(unsigned BlockNum) { + DebugTunnelData debugTunnelDataForHost; + OnHostInitialize(&(debugTunnelDataForHost.PrintData), BlockNum); + void *Hbm_PrintPayloadData_start_addr = NULL; + // Not sure how to use the module_id param of rtMalloc + uint16_t ModuleId = 0; + rtError_t error = + rtMalloc(reinterpret_cast(&Hbm_PrintPayloadData_start_addr), + sizeof(debugTunnelDataForHost), RT_MEMORY_HBM, ModuleId); + if (error != RT_ERROR_NONE) { + printf("ERROR: The memory for the printing function on the device side " + "fails to be allocated."); + printf("As a result, the printing function fails!\n"); + return nullptr; + } + if (Hbm_PrintPayloadData_start_addr == nullptr) { + printf("WARNING: failed to allocate DebugTunnelData memory\n"); + return nullptr; + } + error = rtMemcpy(Hbm_PrintPayloadData_start_addr, + sizeof(debugTunnelDataForHost), &debugTunnelDataForHost, + sizeof(debugTunnelDataForHost), RT_MEMCPY_HOST_TO_DEVICE); + if (error != RT_ERROR_NONE) { + printf("ERROR: The memory copy of the device print on fails, "); + printf("and the printing function is 
invalid!\n"); + return nullptr; + } + return reinterpret_cast(Hbm_PrintPayloadData_start_addr); +} + +void Close(DebugTunnelData *DTData, rtStream_t Stream) { + if (!DTData) { + return; + } + DebugTunnelData debugTunnelDataForHost; + rtError_t error = rtStreamSynchronize(Stream); + if (error != RT_ERROR_NONE) { + printf("ERROR: Synchronous waiting for the device print failed.\n"); + printf("The printing function is invalid!\n"); + } + error = + rtMemcpy(&debugTunnelDataForHost, sizeof(debugTunnelDataForHost), DTData, + sizeof(debugTunnelDataForHost), RT_MEMCPY_DEVICE_TO_HOST); + if (error != RT_ERROR_NONE) { + printf("ERROR: The memory copy of the device print on fails, "); + printf("and the printing function is invalid!\n"); + return; + } + OnHostFinish(&(debugTunnelDataForHost.PrintData), Stream); + + error = rtFree(DTData); + if (error != RT_ERROR_NONE) { + printf("ERROR: The memory free of the device print fails, "); + printf("and the device print is invalid!\n"); + return; + } +} + +} // namespace TTAscDebug + +#endif diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py new file mode 100644 index 000000000..5cb59d411 --- /dev/null +++ b/third_party/ascend/backend/driver.py @@ -0,0 +1,504 @@ +from pathlib import Path +import tempfile +import os +import subprocess +import sysconfig +from typing import Optional +import functools +import hashlib +from triton.runtime.cache import get_cache_manager, get_dump_manager +from triton.backends.driver import DriverBase +from triton.backends.compiler import GPUTarget +from triton.backends.huawei.utils import _build_npu_ext, _check_cxx11_abi + + +class NPUUtils(object): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(NPUUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + dirname = os.path.dirname(os.path.realpath(__file__)) + src = Path(os.path.join(dirname, "npu_utils.cpp")).read_text() + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + fname = "npu_utils.so" + cache_path = cache.get_file(fname) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "npu_utils.cpp") + with open(src_path, "w") as f: + f.write(src) + so = _build_npu_ext("npu_utils", src_path, tmpdir) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), fname, binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location("npu_utils", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.npu_utils_mod = mod + + def load_binary(self, name, kernel, shared, device): + fnname, mix_mode = name.split() + return self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode) + + @functools.lru_cache() + def get_device_properties(self, device): + # temperoarily added "max_shared_mem" properties to avoid triton-compiler complain + # fetch available memory at runtime + num_aic = self.get_aicore_num() + num_aiv = num_aic * 2 + return {"max_shared_mem": 1, "num_aicore": num_aic, "num_vectorcore": num_aiv} + + @functools.lru_cache() + def get_arch(self): + # temporarily return empty arch descriptor + return self.npu_utils_mod.get_arch() + + @functools.lru_cache() + def get_aicore_num(self): + # temporarily return empty arch descriptor + return self.npu_utils_mod.get_aicore_num() + + +class NPULauncher(object): + + def __init__(self, src, metadata): + debug_mode = metadata.debug + workspace_size = int(metadata.workspace_size) \ + if 
hasattr(metadata, 'workspace_size') else -1 + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + wrapper_src = generate_npu_wrapper_src(constants, signature, \ + workspace_size) + so_launcher_path = make_npu_launcher_stub(wrapper_src, debug_mode) + # initialize launcher + import importlib.util + spec = importlib.util.spec_from_file_location("__triton_launcher", so_launcher_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.launch = getattr(mod, "launch") + + def __call__(self, *args, **kwargs): + profiler_registered = self.launch(*args, **kwargs) + import triton + triton.backends.huawei.utils.TRITON_PROFILER_REGISTERED = True if profiler_registered == 1 else False + + +class NPUDriver(DriverBase): + + def __init__(self): + self.utils = NPUUtils() + self.launcher_cls = NPULauncher + super().__init__() + + @classmethod + def is_active(cls): + + def test_npucompiler(): + from triton.backends.huawei.utils import _get_bisheng_path + npucompiler = _get_bisheng_path() + targets = subprocess.check_output([npucompiler, "-print-targets"]).decode().strip().split() + return "hiipu64" in targets + + try: + return test_npucompiler() + except Exception as e_npucompiler: + import warnings + red = "\x1b[31;20m" + reset = "\x1b[0m" + warnings.warn(red + str(e_npucompiler) + reset) + return False + + def get_current_target(self): + backend = "npu" + arch = self.utils.get_arch() + warp_size = 0 + return GPUTarget(backend, arch, warp_size) + + def get_current_device(self): + """ + Get current device + """ + import torch + import torch_npu + return torch.npu.current_device() + + def set_current_device(self, device): + """ + Set current device as the given device + """ + import torch + import torch_npu + return torch.npu.set_device(device) + + def get_current_stream(self, device: Optional[int] = None) -> int: + """ + Get stream for current device + """ + # According to torch_npu, the content of a torch.npu.Stream is essentilly an rtStream_t + # TODO: use CANN API instead of torchnpu + import torch + import torch_npu + if device is None: + device = self.get_current_device() + return torch.npu.current_stream(device).npu_stream + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_device_interface(self): + import torch + return torch.npu + + def get_empty_cache_for_benchmark(self): + import torch + cache_size = 192 * 1024 * 1024 + return torch.empty(cache_size // 4, dtype=torch.int, device='npu') + + +def make_npu_launcher_stub(src, debug=False): + """ + Generate the launcher stub to launch the kernel + """ + # try to get cached file + so_cache_key = hashlib.sha256(src.encode("utf-8")).hexdigest() + so_cache_manager = get_cache_manager(so_cache_key) + # append the cxx11_abi value to the launcher name to avoid + # linking to a launcher with wrong cxx11_abi. 
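For illustration, the launcher stub built here is cached under a name that embeds both the Torch C++ ABI flag and the interpreter's extension suffix, so a stub compiled against a mismatched ABI can never be picked up by accident. A minimal sketch of the naming scheme, assuming `use_cxx11_abi` is the 0/1 value returned by `_check_cxx11_abi()` (the helper name here is hypothetical; the real logic continues inline below):

    import sysconfig

    def launcher_so_name(use_cxx11_abi: int) -> str:
        # e.g. "launcher_cxx11abi1.cpython-310-x86_64-linux-gnu.so"
        suffix = sysconfig.get_config_var("EXT_SUFFIX")
        return f"launcher_cxx11abi{use_cxx11_abi}{suffix}"
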
+ use_cxx11_abi = _check_cxx11_abi() + name = f"launcher_cxx11abi{use_cxx11_abi}" + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so_name = f"{name}{suffix}" + + if debug: + dump_manager = get_dump_manager(so_cache_key) + print(f"Dumping {name}.cxx to {dump_manager.cache_dir}") + dump_manager.put(src, f"{name}.cxx", binary=False) + + cache_path = so_cache_manager.get_file(so_name) + if cache_path is not None: + return cache_path + + with tempfile.TemporaryDirectory() as tmpdir: + if debug: + so_cache_manager.put(src, f"{name}.cxx", binary=False) + src_path = os.path.join(tmpdir, f"{name}.cxx") + with open(src_path, "w") as f: + f.write(src) + so = _build_npu_ext(name, src_path, tmpdir, kernel_launcher="torch") + if debug: + with open(so, "rb") as f: + return dump_manager.put(f.read(), so_name, binary=True) + with open(so, "rb") as f: + return so_cache_manager.put(f.read(), so_name, binary=True) + + +# the template is from triton-adapter HEAD. Wrapping the generated kernel binary into a python module +def generate_npu_wrapper_src(constants, signature, workspace_size): + import os + + def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def _extracted_ty(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + """ + args: + int gridX, gridY, gridZ; + rtStream_t stream; + const void *functon; + PyObject* packed_metadata, *launch_metadata; + PyObject* launch_enter_hook, *launch_exit_hook; + *args_expand + """ + format = "iiiKKOOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) + + grid_info = {'X': 'i32', 'Y': 'i32', 'Z': 'i32'} + + enable_device_print = os.getenv("TRITON_DEVICE_PRINT", 'false').lower() in ('true', '1') + + return f""" +#include +#include +#include +#include + +#define PY_SSIZE_T_CLEAN +#include +#include +#include "experiment/runtime/runtime/rt.h" +{'#include "device_print.h"' if enable_device_print else ''} + +extern "C" {{ + + typedef int (* callback)(unsigned int type, void* data, unsigned int len); + extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api); + extern unsigned long int MsprofSysCycleTime(); + extern int MsprofRegisterCallback(unsigned int moduleId, callback handle); + static unsigned int __MsprofFlagL0 = 0; + static unsigned int __MsprofFlagL1 = 0; + + int ProfCtrlHandle(unsigned int CtrlType, void* CtrlData, unsigned int DataLen) {{ + if ((CtrlData == nullptr) || (DataLen == 0U)) {{ + return 1; + }} + + if (CtrlType == 1) {{ + MsprofCommandHandle* handle = (MsprofCommandHandle *)(CtrlData); + if (handle->type >= 6) // 6 is not used here + return 1; + if (handle->type == 1) {{ // init - 0 , start - 1 + __MsprofFlagL0 = ((0x00000800ULL & handle->profSwitch) == 0x00000800ULL) ? 1 : 0; + __MsprofFlagL1 = ((0x00000002ULL & handle->profSwitch) == 0x00000002ULL) ? 
1 : 0; + }} + }} + return 0; + }} +}} + +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); + if(!ptr_info.dev_ptr) + return ptr_info; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static void _launch(const char* kernelName, const void* func, rtStream_t stream, int gridX, int gridY, int gridZ, int *profilerRegistered, {arg_decls}) {{ + // only 1D parallelization is supported for NPU + // Pointer type becomes flattend 1-D Memref tuple: base_ptr, data_ptr, offset, shape, stride + // base_ptr offset shape and stride are not used, arbitrarily set for now + std::string name = ""; + name.append(kernelName); + if (!(*profilerRegistered)) {{ + MsprofRegisterCallback(8, ProfCtrlHandle); // 8 - CCE defined in msprof headerfile slog.h + *profilerRegistered = 1; + }} + auto launch_call = [=]() {{ + uint32_t blockNum = gridX * gridY * gridZ; + {'TTAscDebug::DebugTunnelData *DTData = TTAscDebug::Open(blockNum);' if enable_device_print else ''} + rtError_t ret; + void *ffts_addr = NULL; + uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len); + if (ret != RT_ERROR_NONE) {{ + return ret; + }} + // stub argument for workspace + void *workspace_addr = NULL; + {f''' + uint16_t ModuleId = 0; + uint64_t totalWorkSpaceSize = {workspace_size} * blockNum; + ret = rtMalloc(reinterpret_cast(&workspace_addr), + totalWorkSpaceSize, RT_MEMORY_HBM, ModuleId); + if (ret != RT_ERROR_NONE) {{ + return ret; + }} + ''' if workspace_size > 0 else ''} + struct __attribute__((packed)) {{ + void* ffts_addr __attribute__((aligned(8))); + void* workspace_addr __attribute__((aligned(8))); + {' '.join(f'{_ty_to_cpp(ty)} arg{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8})));' for i, ty in signature.items() if i not in constants)} + {' '.join(f'{_ty_to_cpp(ty)} grid{mark} __attribute__((aligned(4)));' for mark, ty in grid_info.items())} + {'void* DTData __attribute__((aligned(8)));' if enable_device_print else ''} + }} args = {{ + static_cast(ffts_addr), + static_cast(workspace_addr), + {', '.join(f'static_cast<{_ty_to_cpp(ty)}>(arg{i})' for i, ty in signature.items() if i not in constants)}, + {', '.join(f'static_cast<{_ty_to_cpp(ty)}>(grid{mark})' for mark, ty in grid_info.items())} + {', static_cast(DTData)' if enable_device_print else ''} + }}; + unsigned long int beginTime = 0; + unsigned long int endTime = 0; + unsigned long int opName = 0; + unsigned int threadId = 0; + char* kernelName = const_cast(name.c_str()); + size_t length = name.length(); + // FIXME: to avoid bug in msprof, currently we disable these checks + // if 
(__MsprofFlagL0 || __MsprofFlagL1) {{ + {{ + beginTime = MsprofSysCycleTime(); + }} + ret = rtKernelLaunch(func, blockNum, static_cast(&args), sizeof(args), NULL, stream); + {'TTAscDebug::Close(DTData, stream);' if enable_device_print else ''} + // FIXME: to avoid bug in msprof, currently we disable these checks + // if (__MsprofFlagL0 || __MsprofFlagL1) {{ + {{ + endTime = MsprofSysCycleTime(); + opName = MsprofGetHashId(kernelName, length); + threadId = (unsigned int)(syscall(SYS_gettid)); + MsprofApi info; + info.magicNumber = 0x5a5a; //MSPROF_REPORT_DATA_MAGIC_NUM + info.level = 10000; //MSPROF_REPORT_NODE_LEVEL + info.type = 5; //MSPROF_REPORT_NODE_LAUNCH_TYPE + info.threadId = threadId; + info.reserve = 0; + info.beginTime = beginTime; + info.endTime = endTime; + info.itemId = opName; + MsprofReportApi(0, &info); + }} + // FIXME: to avoid bug in msprof, currently we disable these checks + // if (__MsprofFlagL1) {{ + {{ + MsprofCompactInfo nodeBasicInfo; + nodeBasicInfo.magicNumber = 0x5a5a; //MSPROF_REPORT_DATA_MAGIC_NUM + nodeBasicInfo.level = 10000; //MSPROF_REPORT_NODE_LEVEL + nodeBasicInfo.type = 0; //MSPROF_REPORT_NODE_BASIC_INFO_TYPE + nodeBasicInfo.threadId = threadId; + nodeBasicInfo.timeStamp = endTime; + nodeBasicInfo.data.nodeBasicInfo.opName = opName; + nodeBasicInfo.data.nodeBasicInfo.taskType = 0; //MSPROF_GE_TASK_TYPE_AI_CORE + nodeBasicInfo.data.nodeBasicInfo.opType = opName; + nodeBasicInfo.data.nodeBasicInfo.blockDim = gridX; + MsprofReportCompactInfo(0, &nodeBasicInfo, sizeof(MsprofCompactInfo)); + }} + return ret; + }}; + at_npu::native::OpCommand cmd; + cmd.Name(name.c_str()) + .SetCustomHandler(launch_call) + .Run(); +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + rtStream_t stream; + const void *function; + PyObject *packedMetadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + {' '.join([f"{_extracted_ty(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple( + args, \"{format}\", + &gridX, &gridY, &gridZ, &stream, &function, + &packedMetadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook + {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''} + ) + ) {{ + return NULL; + }} + + if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ + return NULL; + }} + + // get kernel_name + PyObject *kernelNameObj = PyDict_GetItemString(packedMetadata, "kernel_name"); + const char *kernelName = PyUnicode_AsUTF8(kernelNameObj); + PyObject *profilerRegisteredObj = PyDict_GetItemString(packedMetadata, "profiler_registered"); + int profilerRegistered = PyObject_IsTrue(profilerRegisteredObj); + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0]=="*" else "" for i, ty in signature.items()])}; + _launch(kernelName, function, stream, gridX, gridY, gridZ, &profilerRegistered, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())}); + if (PyErr_Occurred()) {{ + return NULL; + }} + if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ + return NULL; + }} + + return Py_BuildValue("I", profilerRegistered); +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef 
ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" diff --git a/third_party/ascend/backend/name.conf b/third_party/ascend/backend/name.conf new file mode 100644 index 000000000..3fd20dbae --- /dev/null +++ b/third_party/ascend/backend/name.conf @@ -0,0 +1 @@ +huawei diff --git a/third_party/ascend/backend/npu_utils.cpp b/third_party/ascend/backend/npu_utils.cpp new file mode 100644 index 000000000..bcfc61c50 --- /dev/null +++ b/third_party/ascend/backend/npu_utils.cpp @@ -0,0 +1,136 @@ +#define PY_SSIZE_T_CLEAN +#include + +#include +#include +#include +#include + +#include "experiment/runtime/runtime/rt.h" + +// Use map to differentiate same name functions from different binary +static std::unordered_map registered_names; +static std::unordered_map> func_stubs; + +static std::tuple +registerKernel(const char *name, const void *data, size_t data_size, int shared, + int device, const char *kernel_mode_str) { + rtError_t rtRet; + + rtDevBinary_t devbin; + devbin.data = data; + devbin.length = data_size; + const std::string kernel_mode{kernel_mode_str}; + if (kernel_mode == "aiv") + devbin.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; + else + devbin.magic = RT_DEV_BINARY_MAGIC_ELF; + devbin.version = 0; + + rtRet = rtSetDevice(device); + if (rtRet != RT_ERROR_NONE) { + printf("rtSetDevice failed, 0x%x\n", rtRet); + return {NULL, NULL}; + } + + void *devbinHandle = NULL; + rtRet = rtDevBinaryRegister(&devbin, &devbinHandle); + if (rtRet != RT_ERROR_NONE) { + printf("rtDevBinaryRegister failed, 0x%x\n", rtRet); + return {NULL, NULL}; + } + + std::string stubName = name; + stubName += "_" + std::to_string(registered_names[name]); + registered_names[name]++; + auto registered = func_stubs.emplace(stubName, std::make_unique(0)); + void *func_stub_handle = registered.first->second.get(); + rtRet = rtFunctionRegister(devbinHandle, func_stub_handle, stubName.c_str(), + (void *)name, 0); + if (rtRet != RT_ERROR_NONE) { + printf("rtFunctionRegister failed(stubName = %s), 0x%x\n", stubName.c_str(), + rtRet); + exit(1); + return {NULL, NULL}; + } + + return std::make_tuple(devbinHandle, func_stub_handle); +} + +static PyObject *loadKernelBinary(PyObject *self, PyObject *args) { + const char *name; // kernel name + const char *data; // binary pointer + Py_ssize_t data_size; // binary size + int shared; // shared_memory(meaningless now) + int device; // device ID + const char *kernel_mode; // kernel mode + + if (!PyArg_ParseTuple(args, "ss#iis", &name, &data, &data_size, &shared, + &device, &kernel_mode)) { + return NULL; + } + + auto [module_handle, func_handle] = + registerKernel(name, data, data_size, shared, device, kernel_mode); + + uint64_t mod = reinterpret_cast(module_handle); + uint64_t func = reinterpret_cast(func_handle); + if (PyErr_Occurred()) { + return NULL; + } + + return Py_BuildValue("(KKii)", mod, func, 0, 0); +} + +static PyObject *getArch(PyObject *self, PyObject *args) { + char name[64] = {'\0'}; + + rtError_t rtRet = rtGetSocVersion(name, 64); + + if (rtRet != RT_ERROR_NONE) { + printf("rtGetSocVersion failed, 0x%x", rtRet); + return NULL; + } + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("s", name); +} + +static PyObject *getAiCoreNum(PyObject *self, PyObject *args) { + uint32_t aiCoreCnt; + + 
rtError_t rtRet = rtGetAiCoreCount(&aiCoreCnt); + + if (rtRet != RT_ERROR_NONE) { + printf("rtGetAiCoreCount failed, 0x%x", rtRet); + return NULL; + } + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("I", aiCoreCnt); +} + +static PyMethodDef NpuUtilsMethods[] = { + {"load_kernel_binary", loadKernelBinary, METH_VARARGS, + "Load NPU kernel binary into NPU driver"}, + {"get_arch", getArch, METH_VARARGS, "Get soc version of NPU"}, + // sentinel + {"get_aicore_num", getAiCoreNum, METH_VARARGS, "Get the number of AI core"}, + {NULL, NULL, 0, NULL}}; + +static PyModuleDef ModuleDef = { + PyModuleDef_HEAD_INIT, "npu_utils", + "Utilities for fetching NPU device info and preparing kernel binary", -1, + NpuUtilsMethods}; + +PyMODINIT_FUNC PyInit_npu_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, NpuUtilsMethods); + return m; +} diff --git a/third_party/ascend/backend/utils.py b/third_party/ascend/backend/utils.py new file mode 100644 index 000000000..826ae3a8c --- /dev/null +++ b/third_party/ascend/backend/utils.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import re +import os +from pathlib import Path +import functools +import sysconfig +import shutil +import subprocess + +TRITON_PROFILER_REGISTERED = False + + +def downgrade_llir(llir): + llir = _downgrade_mem_attrs(llir) + llir = _downgrade_stacksaverestore_intrinsics(llir) + return llir + + +def _downgrade_mem_attrs(llir: str): + memory_pattern = r"memory\([^()]*\)" + + def replace_mem_attr(m): + attrs = m[0][7:-1].split(",") + if len(attrs) == 0: + return "readnone" + loc_map = {"argmem": 1, "inaccessiblemem": 2, "other": 4} + loc_attr = 0 + rw_map = {"readwrite": 3, "write": 2, "read": 1, "none": 0} + rw_attr = 0 + for attr_pair in attrs: + pair = attr_pair.split(":") + assert len(pair) <= 2 + if len(pair) == 1: + rw = rw_map[pair[0].strip()] + loc = loc_map["other"] # all location + else: + rw = rw_map[pair[1].strip()] + loc_str = pair[0].strip() + if loc_str == "argmem" or loc_str == "inaccessiblemem": + loc = loc_map[loc_str] + else: + loc = loc_map["other"] + if rw > 0: + loc_attr = loc_attr | loc + rw_attr = rw_attr | rw + rev_rw_map = {0: "readnone", 1: "readonly", 2: "writeonly"} + if rw_attr in rev_rw_map: + rw_attr_str = rev_rw_map[rw_attr] + else: + rw_attr_str = "" + rev_loc_map = {1: "argmemonly", 2: "inaccessiblememonly", 3: "inaccessiblemem_or_argmemonly"} + if loc_attr in rev_loc_map: + loc_attr_str = rev_loc_map[loc_attr] + else: + loc_attr_str = "" + return rw_attr_str + " " + loc_attr_str + + return re.sub(memory_pattern, replace_mem_attr, llir) + + +def _downgrade_stacksaverestore_intrinsics(llir: str): + llir = re.sub(r"llvm\.stacksave\.\w+", "llvm.stacksave", llir) + llir = re.sub(r"llvm\.stackrestore\.\w+", "llvm.stackrestore", llir) + return llir + + +def _get_triton_adapter_opt_path() -> str: + path = os.path.dirname(__file__) + path = os.path.join(path, "triton-adapter-opt") + return path + + +def _get_mlir_path(path: str, *paths) -> str: + root_path = os.getenv("MLIR_ROOT", "") + if root_path == "": + raise EnvironmentError("MLIR_ROOT is not set.") + return os.path.join(root_path, path, *paths) + + +def _get_llvm_path(path: str, *paths) -> str: + root_path = os.getenv("LLVM_ROOT", "") + if root_path == "": + raise EnvironmentError("LLVM_ROOT is not set.") + return os.path.join(root_path, path, *paths) + + +def _get_npucompiler_path() -> str: + 
npu_compiler_path = shutil.which("bishengir-compile") + if npu_compiler_path is None: + npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "") + if npu_compiler_root is None: + raise EnvironmentError("Couldn't find executable bishengir-compile or TRITON_NPU_COMPILER_PATH.") + npu_compiler_path = os.path.join(npu_compiler_root, "npuc") + return npu_compiler_path + + +def _get_bisheng_path() -> str: + bisheng_path = shutil.which("bisheng") + if bisheng_path is None: + npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "") + if npu_compiler_root is None: + raise EnvironmentError("Couldn't find executable bisheng or TRITON_NPU_COMPILER_PATH") + bisheng_path = os.path.join(npu_compiler_root, "ccec") + return bisheng_path + + +@functools.lru_cache(None) +def _get_ascend_path() -> str: + path = os.getenv("ASCEND_HOME_PATH", "") + if path == "": + raise EnvironmentError("ASCEND_HOME_PATH is not set, source /set_env.sh first") + return Path(path) + + +def _is_ascend_sanitizer_enabled() -> bool: + return os.getenv("TRITON_ENABLE_SANITIZER", 'false').lower() in ('true', '1') + + +def _build_npu_ext(obj_name: str, src_path, src_dir, *, kernel_launcher=None) -> str: + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so_path = os.path.join(src_dir, f"{obj_name}{suffix}") + + cxx = os.environ.get("CC") + if cxx is None: + clangxx = shutil.which("clang++") + gxx = shutil.which("g++") + cxx = clangxx if clangxx is not None else gxx + if cxx is None: + raise RuntimeError("Failed to find C++ compiler") + cc_cmd = [cxx, src_path] + # disable all warnings + cc_cmd += [f"-w"] + # find the python library + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. 
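The scheme fix-up matters because the Python include directory passed to the compiler is resolved through `sysconfig`; on Debian-based installs the default `posix_local` scheme can resolve to a local prefix that may not contain the CPython headers. A quick, illustrative check of what the lookup below resolves to (the printed path varies by installation):

    import sysconfig

    scheme = sysconfig.get_default_scheme() if hasattr(sysconfig, "get_default_scheme") \
        else sysconfig._get_default_scheme()
    if scheme == "posix_local":  # Debian-specific scheme name
        scheme = "posix_prefix"
    print(sysconfig.get_paths(scheme=scheme)["include"])
    # e.g. /usr/include/python3.10
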
+ if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + cc_cmd += [f"-I{py_include_dir}"] + # device_print.h + cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"] + # find the ascend library + asc_path = _get_ascend_path() + cc_cmd += [ + f"-I{os.path.join(asc_path, 'include')}", + f"-I{os.path.join(asc_path, 'include/experiment')}", + f"-I{os.path.join(asc_path, 'include/experiment/msprof')}", + f"-L{os.path.join(asc_path, 'lib64')}", + "-lruntime", + "-lascendcl", + ] + + if kernel_launcher == "torch": + import torch + import torch_npu + torch_path = os.path.dirname(os.path.realpath(torch.__file__)) + torch_npu_path = os.path.dirname(os.path.realpath(torch_npu.__file__)) + use_cxx11_abi = _check_cxx11_abi() + cc_cmd += [ + f"-I{os.path.join(torch_path, 'include')}", + f"-I{os.path.join(torch_npu_path, 'include')}", + f"-L{os.path.join(torch_npu_path, 'lib')}", + "-ltorch_npu", + f"-D_GLIBCXX_USE_CXX11_ABI={use_cxx11_abi}", + ] + + cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", so_path] + + ret = subprocess.check_call(cc_cmd) + + if ret == 0: + return so_path + else: + raise RuntimeError("Failed to compile " + src_path) + + +def _get_kernel_target(metadata: dict): + if "target" not in metadata: + raise Exception("No target provided!") + sub_target = metadata["target"].arch + assert isinstance(sub_target, str) + if sub_target.startswith('Ascend910B'): + mix_mode = metadata["mix_mode"] + if mix_mode.lower().strip("_").startswith("aiv"): + return "ascend_910b_vec", "c220-vec", "aiv" + elif mix_mode.lower().strip("_").startswith("aic"): + return "ascend_910b_cube", "c220-cube", "aic" + else: + return "ascend_910b", "c220", "mix" + elif sub_target.startswith('Ascend910'): + return "ascend_910", "c100", "mix" + else: + raise NotImplementedError(f"NPU subtarget {sub_target} not supported yet") + + +def _check_cxx11_abi(): + import torch + return 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 diff --git a/third_party/ascend/language/ascend/__init__.py b/third_party/ascend/language/ascend/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/ascend/language/ascend/__init__.py @@ -0,0 +1,3 @@ +from . 
import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/ascend/language/ascend/libdevice.py b/third_party/ascend/language/ascend/libdevice.py new file mode 100644 index 000000000..db22bf7cc --- /dev/null +++ b/third_party/ascend/language/ascend/libdevice.py @@ -0,0 +1,135 @@ +from triton.language import core + + +@core.extern +def reciprocal(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_recipf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_recipDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_log1pf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_log1pDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def relu(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_reluf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_reluDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_isinf", core.dtype("int1")), + (core.dtype("fp16"), ): ("__hmf_isinf", core.dtype("int1")), + (core.dtype("bf16"), ): ("__hmf_isinf", core.dtype("int1")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_tanf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_tanDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_atanf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_atanDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_tanhf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_ilogbf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_ilogbDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_ldexpf", core.dtype("fp32")), + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_ldexpDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_powf", core.dtype("fp32")), + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_powf", core.dtype("fp16")), + (core.dtype("bf16"), core.dtype("bf16")): ("__hmf_powf", core.dtype("bf16")), + (core.dtype("int64"), core.dtype("int64")): ("__hmf_powi", core.dtype("int64")), + (core.dtype("int32"), core.dtype("int32")): ("__hmf_powi", core.dtype("int32")), + (core.dtype("int16"), core.dtype("int16")): ("__hmf_powi", core.dtype("int16")), + (core.dtype("int8"), core.dtype("int8")): ("__hmf_powi", core.dtype("int8")), + 
}, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_isnan", core.dtype("int1")), + (core.dtype("fp16"), ): ("__hmf_isnan", core.dtype("int1")), + (core.dtype("bf16"), ): ("__hmf_isnan", core.dtype("int1")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def flip(arg0, arg1=None, _builder=None): + if arg1 == None: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("bf16"), ): ("__hmf_flipDhb", core.dtype("bf16")), + (core.dtype("fp16"), ): ("__hmf_flipDh", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_flipf", core.dtype("fp32")), + (core.dtype("int8"), ): ("__hmf_flipi8", core.dtype("int8")), + (core.dtype("int16"), ): ("__hmf_flipi16", core.dtype("int16")), + (core.dtype("int32"), ): ("__hmf_flipi32", core.dtype("int32")), + (core.dtype("uint32"), ): ("__hmf_flipui32", core.dtype("uint32")), + (core.dtype("int64"), ): ("__hmf_flipi64", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("bf16"), core.dtype("int32")): ("__hmf_flipDhb", core.dtype("bf16")), + (core.dtype("fp16"), core.dtype("int32")): ("__hmf_flipDh", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("int32")): ("__hmf_flipf", core.dtype("fp32")), + (core.dtype("int8"), core.dtype("int32")): ("__hmf_flipi8", core.dtype("int8")), + (core.dtype("int16"), core.dtype("int32")): ("__hmf_flipi16", core.dtype("int16")), + (core.dtype("int32"), core.dtype("int32")): ("__hmf_flipi32", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("int32")): ("__hmf_flipui32", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int32")): ("__hmf_flipi64", core.dtype("int64")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/ascend/triton-adapter/CMakeLists.txt b/third_party/ascend/triton-adapter/CMakeLists.txt new file mode 100644 index 000000000..fcc2348b7 --- /dev/null +++ b/third_party/ascend/triton-adapter/CMakeLists.txt @@ -0,0 +1,14 @@ +option(TRITON_ADAPTER_BUILD_CPU_BACKEND "Build triton-adapter CPU backend" ON) + +set(TRITON_ADAPTER_SOURCE_DIR ".") +set(TRITON_ADAPTER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") + +include_directories(./include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(tools) + +if (TRITON_ADAPTER_BUILD_CPU_BACKEND) + add_triton_plugin(TritonAdapter triton_adapter.cc LINK_LIBS TritonToLinalg) +endif() diff --git a/third_party/ascend/triton-adapter/include/CMakeLists.txt b/third_party/ascend/triton-adapter/include/CMakeLists.txt new file mode 100644 index 000000000..64ac15761 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonToLinalg) diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h new file mode 100644 index 000000000..0bf121049 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h @@ -0,0 +1,317 @@ +#ifndef TRITON_ADAPTER_ARGMINMAXCONVERTER_H +#define TRITON_ADAPTER_ARGMINMAXCONVERTER_H + +#include "Utils/Utils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "ConversionPatterns.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include 
"mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" + +#define DEBUG_TYPE "triton-to-linalg" + +#include "llvm/Support/Debug.h" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +template +class ArgMinMaxBaseConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchTieBreakResult(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &tileBreakValue) const { + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto eqCmpOp = dyn_cast(*it); + if (eqCmpOp) { + if (eqCmpOp.getPredicate() != arith::CmpFPredicate::OEQ || + currValue != eqCmpOp.getLhs() || reduceValue != eqCmpOp.getRhs()) { + return failure(); + } + } + + auto eqCmpIOp = dyn_cast(*it++); + if (eqCmpIOp) { + if (eqCmpIOp.getPredicate() != arith::CmpIPredicate::eq || + currValue != eqCmpIOp.getLhs() || reduceValue != eqCmpIOp.getRhs()) { + return failure(); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto sltCmpOp = dyn_cast(*it++); + if (!sltCmpOp || sltCmpOp.getPredicate() != arith::CmpIPredicate::slt || + currIndex != sltCmpOp.getLhs() || reduceIndex != sltCmpOp.getRhs()) { + return failure(); + } + + // matching: %13 = arith.andi %11, %12 : i1 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto andOp = dyn_cast(*it++); + + Value cmpOp; + if (eqCmpOp) + cmpOp = eqCmpOp; + else + cmpOp = eqCmpIOp; + + if (!andOp || andOp.getLhs() != cmpOp || andOp.getRhs() != sltCmpOp) { + return failure(); + } + + tileBreakValue = andOp; + return success(); + } + + LogicalResult matchShouldUpdateValue(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &shouldUpdate) const { + Value tieResult; + if (failed(matchTieBreakResult(currValue, currIndex, reduceValue, + reduceIndex, it, tieResult))) { + LLVM_DEBUG(llvm::dbgs() << "Tie break result match failed\n"); + return failure(); + } + + Value comparisonResult; + if (failed(T::matchComparisonResult(currValue, currIndex, reduceValue, + reduceIndex, it, comparisonResult))) { + LLVM_DEBUG(llvm::dbgs() << "Comparison result match failed\n"); + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto orOp = dyn_cast(*it++); + if (!orOp || orOp.getLhs() != comparisonResult || + orOp.getRhs() != tieResult) { + return failure(); + } + + shouldUpdate = orOp; + return success(); + } + + Value getInitTensor(ConversionPatternRewriter &rewriter, + ArrayRef shape, Value fillValue, + Location loc) const { + Value initTensor = + rewriter.create(loc, shape, fillValue.getType()); + return rewriter + .create(loc, ValueRange{fillValue}, + ValueRange{initTensor}) + .result(); + } + +public: + ArgMinMaxBaseConverter(MLIRContext *context) : OpConversionPattern(context) {} + + LogicalResult match(triton::ReduceOp op) const override final { + if (op.getBody()->getNumArguments() != 4) { + return failure(); + } + + auto block = op.getBody(); + auto ops = block->without_terminator(); + + Value currValue = block->getArgument(0); + Value currIndex = block->getArgument(1); + Value reduceValue = block->getArgument(2); + Value reduceIndex = block->getArgument(3); + + auto opsIt = ops.begin(); + Value shouldUpdate; + if (failed(matchShouldUpdateValue(currValue, currIndex, reduceValue, + reduceIndex, opsIt, shouldUpdate))) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto 
valueSelectOp = dyn_cast(*opsIt++); + if (!valueSelectOp || valueSelectOp.getCondition() != shouldUpdate || + currValue != valueSelectOp.getTrueValue() || + reduceValue != valueSelectOp.getFalseValue()) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto indexSelectOp = dyn_cast(*opsIt++); + if (indexSelectOp) { + if (indexSelectOp.getCondition() != shouldUpdate || + currIndex != indexSelectOp.getTrueValue() || + reduceIndex != indexSelectOp.getFalseValue()) { + return failure(); + } + } else { + return failure(); + } + if (!indexSelectOp || indexSelectOp.getCondition() != shouldUpdate || + currIndex != indexSelectOp.getTrueValue() || + reduceIndex != indexSelectOp.getFalseValue()) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto termOp = dyn_cast(*opsIt++); + if (!(termOp && termOp == block->getTerminator() && + termOp.getOperands() == + ArrayRef{valueSelectOp, indexSelectOp})) { + return failure(); + } + return success(); + } + + void rewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override final { + auto loc = op.getLoc(); + auto elemTypes = op.getElementTypes(); + + auto valueType = elemTypes[0]; + // tl.argmin reorder + auto block = op.getBody(); + if (isa(valueType)) { + arith::CmpFOp cmpFOp; + block->walk([&](arith::CmpFOp cmpOp) { + auto pred = cmpOp.getPredicate(); + if (pred == arith::CmpFPredicate::OEQ || + pred == arith::CmpFPredicate::ONE || + pred == arith::CmpFPredicate::UEQ || + pred == arith::CmpFPredicate::UNE) { + return WalkResult::advance(); + } else if (pred == arith::CmpFPredicate::OGT || + pred == arith::CmpFPredicate::OLT || + pred == arith::CmpFPredicate::UGT || + pred == arith::CmpFPredicate::ULT) { + cmpFOp = cmpOp; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + cmpFOp->moveBefore(block, block->getOperations().begin()); + } else if (isa(valueType)) { + arith::CmpIOp cmpIOp; + block->walk([&](arith::CmpIOp cmpOp) { + auto pred = cmpOp.getPredicate(); + if (pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) { + return WalkResult::advance(); + } else if (pred == arith::CmpIPredicate::sgt || + pred == arith::CmpIPredicate::slt || + pred == arith::CmpIPredicate::ugt || + pred == arith::CmpIPredicate::ult) { + if (cmpOp.getLhs() == block->getArgument(0) && + cmpOp.getRhs() == block->getArgument(2)) { + cmpIOp = cmpOp; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + cmpIOp->moveBefore(block, block->getOperations().begin()); + } + + TypedAttr valueAttr; + if (isa(valueType)) { + valueAttr = rewriter.getFloatAttr(valueType, T::getBaseReductionValue()); + } else if (isa(valueType)) { + // TODO: support other type of int + valueAttr = + rewriter.getIntegerAttr(valueType, T::getBaseReductionIntValue()); + } + + auto valuesAccBaseVal = + rewriter.create(loc, valueType, valueAttr); + + auto indexType = elemTypes[1]; + auto indicesAccBaseVal = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, -1)); + + auto valueResultType = dyn_cast(op.getType(0)); + const auto isScalarReduce = valueResultType == nullptr; + SmallVector reductionResultShape{ + isScalarReduce ? 
SmallVector{} + : SmallVector(valueResultType.getShape())}; + + SmallVector outputs{ + getInitTensor(rewriter, reductionResultShape, valuesAccBaseVal, loc), + getInitTensor(rewriter, reductionResultShape, indicesAccBaseVal, loc)}; + + auto linalgOp = rewriter.create( + loc, adaptor.getOperands(), outputs, + SmallVector{adaptor.getAxis()}, + [&](OpBuilder &b, Location loc, ValueRange inputs) { + assert(inputs.size() == 4); + + auto tritonReduceBlock = op.getBody(); + IRMapping mapping; + mapping.map(tritonReduceBlock->getArguments(), inputs); + + for (auto &op : tritonReduceBlock->without_terminator()) { + b.clone(op, mapping); + } + + auto tritonYield = tritonReduceBlock->getTerminator(); + auto results = + llvm::map_to_vector(tritonYield->getOperands(), [&](Value val) { + return mapping.lookup(val); + }); + b.create(loc, results); + }); + + // before we rewrite the argmax reduce op, we know it has return value + // so addReduceWithIndexAttrIfNeeded won't fail + // but ignoring it will lead to compiling failure + auto logicalResult = addReduceWithIndexAttrIfNeeded(rewriter, linalgOp); + + if (isScalarReduce) { + SmallVector reduceResults{ + rewriter.create( + loc, valueType, linalgOp.getResults()[0], ValueRange{}), + rewriter.create( + loc, indexType, linalgOp.getResults()[1], ValueRange{})}; + rewriter.replaceOp(op, reduceResults); + } else { + rewriter.replaceOp(op, linalgOp); + } + } +}; + +class ArgMinConverter : public ArgMinMaxBaseConverter { +public: + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult); + + static float getBaseReductionValue(); + + static int8_t getBaseReductionIntValue(); + + ArgMinConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +class ArgMaxConverter : public ArgMinMaxBaseConverter { +public: + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult); + + static float getBaseReductionValue(); + + static int8_t getBaseReductionIntValue(); + + ArgMaxConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +} // namespace TTOpConverters + +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h new file mode 100644 index 000000000..c3ac76a2e --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h @@ -0,0 +1,237 @@ +#ifndef TRITON_ANALYSIS_BLOCKPTRANALYSIS_H +#define TRITON_ANALYSIS_BLOCKPTRANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +#include +namespace mlir { + +class ConversionPatternRewriter; + +namespace triton { + +enum class MemAccVal { Undefined = 0, StrucMemAcc = 1, UnstrucMemAcc = 2 }; + +struct MemAccType { + + MemAccVal value; + + explicit constexpr MemAccType(MemAccVal v = MemAccVal::Undefined) + : value(v) {} + + constexpr operator MemAccVal() const { return value; } + explicit operator bool() = delete; + + constexpr bool isUndefined() const { return value == MemAccVal::Undefined; } + constexpr bool isStructured() const { + return value == 
MemAccVal::StrucMemAcc; + } + constexpr bool isUnstructured() const { + return value == MemAccVal::UnstrucMemAcc; + } + + void merge(MemAccType &other) { + this->value = (this->value > other.value) ? this->value : other.value; + } + + std::string_view toString() const { + static constexpr std::string_view names[] = {"Undefined", "StrucMemAcc", + "UnstrucMemAcc"}; + return names[static_cast(value)]; + } +}; + +class BlockData { +public: + SmallVector &getOffsetsRef(); + SmallVector &getSizesRef(); + SmallVector &getStridesRef(); + Value &getSourceRef(); + Value &getScalarRef(); + Type &getResElemTyRef(); + MemAccType &getMemAccTypeRef(); + + SmallVector getOffsets() const; + SmallVector getSizes() const; + SmallVector getStrides() const; + Type getResElemTy() const; + OpFoldResult getOffset(int) const; + OpFoldResult getSize(int) const; + OpFoldResult getStride(int) const; + Value getScalar() const; + Value getSource() const; + MemAccType getMemAccType() const; + + bool isScalar() const; + bool isEmpty() const; + bool hasSource() const; + bool hasResElemTy() const; + void removeSource(); + + int64_t getRank() const; + MemRefType getResultMemrefType(int64_t offset, ArrayRef resultShape, + bool DynamicStrides = false) const; + + void addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter); + void mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter); + void divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter); + + memref::ReinterpretCastOp createCastOp(ArrayRef resultShape, + const Location &loc, + OpBuilder &builder) const; + + void setResElemTy(const Type &); + void setSource(const Value &); + void setScalar(const Value &); + void setOffsets(const SmallVector &); + void setStrides(const SmallVector &); + void setSizes(const SmallVector &); + void setMemAccTy(const MemAccType &); + void setMemAccVal(const MemAccVal); + + void dump() const; + +private: + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + Value source; + Value scalar; + Type resElemTy; + MemAccType memAccTy; + + OpFoldResult inferBlockOffset(const Location &loc, OpBuilder &builder) const; +}; + +class BlockDataParser { +public: + using IndexMapSet = std::map>; + + static Value getScalarMemRef(Value ptr, Value memref, const Location &loc, + ConversionPatternRewriter &rewriter); + + static void parse(Value operand, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseAdd(arith::AddIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseMul(arith::MulIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseDiv(arith::DivSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseRem(arith::RemSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseUnrealizedCast(UnrealizedConversionCastOp op, BlockData &data, + const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseMakeRange(triton::MakeRangeOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const 
llvm::SmallDenseMap &known); + + static void + parseExpandDims(triton::ExpandDimsOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseBitcast(triton::BitcastOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseExtSI(arith::ExtSIOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseBroadcast(triton::BroadcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseSplat(triton::SplatOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseConstSplat(arith::ConstantOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseMakeTensorPtr(triton::MakeTensorPtrOp op, BlockData &data, + const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseAddPtr(triton::AddPtrOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseReinterpretCast(memref::ReinterpretCastOp op, BlockData &data, + const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseReduce(triton::ReduceOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void rewriteAddPtr(triton::AddPtrOp op, + triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known); + + static void rewriteAdvanceOp(triton::AdvanceOp op, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known); + + static void + rewriteYieldOp(scf::YieldOp op, ConversionPatternRewriter &rewriter, + const IndexMapSet &levelToBlockArgIndex, const int level, + const llvm::SmallDenseMap &known); + + static void rewriteForOp(scf::ForOp op, ConversionPatternRewriter &rewriter, + IndexMapSet &levelToBlockArgIndex, const int level, + llvm::SmallDenseMap &known); + + static void rewriteAddPtrToUnstrucMemAcc(triton::AddPtrOp op, + triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, + BlockData &data); +}; + +template +void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt b/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt new file mode 100644 index 000000000..d62f670bb --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg) +add_public_tablegen_target(TritonToLinalgConversionPassIncGen) diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h new file mode 100644 index 000000000..b1c4c7601 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h @@ 
-0,0 +1,111 @@ +#ifndef CONVERSIONPATTERNS_H +#define CONVERSIONPATTERNS_H + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" + +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include +#include +#include + +using namespace mlir; +using namespace triton; + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +static Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + rewriter, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +static SmallVector getNParallelLoopsAttrs(unsigned n) { + return SmallVector(n, utils::IteratorType::parallel); +} + +// for IntLike and FloatLike types +static std::optional getBitWidth(Type a) { + if (auto type = dyn_cast(a)) { + auto elementType = type.getElementType(); + if (elementType.isIntOrFloat()) { + return type.getElementType().getIntOrFloatBitWidth(); + } + return std::nullopt; + } + + if (a.isIntOrFloat()) { + return a.getIntOrFloatBitWidth(); + } + return std::nullopt; +} +#endif // CONVERSIONPATTERNS_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h new file mode 100644 index 000000000..33166ea0e --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h @@ -0,0 +1,38 @@ 
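The header added below declares converters for Triton's launch-grid intrinsics: GetProgramIDConverter lowers tt.get_program_id and GetNumProgramsConverter lowers tt.get_num_programs. For orientation, this is the kind of kernel-side construct they ultimately serve; a generic Triton example for context, not code from this patch:

    import triton
    import triton.language as tl

    @triton.jit
    def add_one(x_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)      # becomes tt.get_program_id
        nprog = tl.num_programs(axis=0)  # becomes tt.get_num_programs
        # grid-stride loop so each program instance covers several blocks
        for start in range(pid * BLOCK, n, nprog * BLOCK):
            offs = start + tl.arange(0, BLOCK)
            mask = offs < n
            x = tl.load(x_ptr + offs, mask=mask)
            tl.store(x_ptr + offs, x + 1, mask=mask)
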
+#ifndef TRITON_ADAPTER_FUNCTIONCONVERTER_H +#define TRITON_ADAPTER_FUNCTIONCONVERTER_H + +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace FunctionConverter { +using namespace mlir; +using namespace triton; + +class GetProgramIDConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class GetNumProgramsConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace FunctionConverter +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h new file mode 100644 index 000000000..3a2701af0 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h @@ -0,0 +1,196 @@ +#ifndef TRITON_ADAPTER_LOADSTORECONVERTER_H +#define TRITON_ADAPTER_LOADSTORECONVERTER_H + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Arith/Utils/Utils.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace LoadStoreConverter { + +using namespace mlir; +using namespace triton; + +class AddPtrConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class LoadConverter : public OpConversionPattern { +private: + LogicalResult toTensorAndReplace(triton::LoadOp &op, + RankedTensorType &tensorType, + memref::AllocOp &allocOp, + const Location &loc, + ConversionPatternRewriter &rewriter) const; + + LogicalResult checkModifiedByAddPtrConverter(triton::LoadOp &op) const; + + LogicalResult + continueModifyFromAddPtrConverter(triton::LoadOp &op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + +public: + explicit LoadConverter(MLIRContext *context); + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +// tempate class's impl must in header file +template +class LoadStoreCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Value ptrVal = op.getPtr(); + Type ptrTy = ptrVal.getType(); + auto ptrDefOp = ptrVal.getDefiningOp(); + if (isa(ptrVal)) + return failure(); + + if (!isTensorPointerType(ptrTy) && + !isa_and_nonnull(ptrDefOp)) { + if (isa(ptrDefOp)) { + auto castOp = cast(ptrDefOp); + auto castSrc = castOp.getSrc(); + if 
(!isa(castSrc)) { + auto castSrcDefOp = castSrc.getDefiningOp(); + if (isa(castSrcDefOp)) { + return rewriter.notifyMatchFailure( + op, "BitcastCanonicalizer handles addptr->bitcast->load!"); + } + } + } + + Type zeroTy = getI32SameShape(ptrTy); + Value zeroVal = + createScalarOrSplatConstant(rewriter, op.getLoc(), zeroTy, 0); + Value addptrVal = rewriter.create(op.getLoc(), ptrTy, + ptrVal, zeroVal); + rewriter.modifyOpInPlace( + op, [&]() { op->replaceUsesOfWith(ptrVal, addptrVal); }); + return success(); + } + return failure(); + } +}; + +class ScalarStoreCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const override; +}; + +class StoreConverter : public OpConversionPattern { +public: + explicit StoreConverter(MLIRContext *context); + + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ScalarAtomicRMWCanonicalizer + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override; +}; + +class AtomicRMWConverter : public OpConversionPattern { +private: + Value createAtomicBinaryOps(OpBuilder &builder, Location loc, + triton::AtomicRMWOp op, Type elementType, + Value lhs, Value rhs) const { + auto rmwOp = op.getAtomicRmwOp(); + + // it has been confirmed in AtomicRMWConverter::matchAndRewrite + // that the ptr of op is of MemRefType + Value binaryOp; + if (rmwOp == triton::RMWOp::FADD) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::ADD) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::XOR) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::OR) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::AND) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::MAX) { + // Max/Min only support f32/i32 for now + // Other type is not supported because of semantic.py + if (isa(elementType)) { + binaryOp = builder.create(loc, lhs, rhs); + } else { + binaryOp = builder.create(loc, lhs, rhs); + } + } else if (rmwOp == triton::RMWOp::MIN) { + if (isa(elementType)) { + binaryOp = builder.create(loc, lhs, rhs); + } else { + binaryOp = builder.create(loc, lhs, rhs); + } + } else { + op.emitOpError("unsupported atomic RMW operation: "); + llvm_unreachable( + "Not implemented. 
Support fadd, add, max, min for now !"); + } + return binaryOp; + } + + // used when handling scalar + // to verify whether we need to handle this scalar + bool isConstantMaskTrue(Value mask) const { + if (auto denseAttr = + mask.getDefiningOp()->getAttrOfType("value")) { + auto eleType = denseAttr.getType().getElementType(); + if (isa(eleType) && + cast(eleType).getWidth() == 1) { + auto values = denseAttr.getValues(); + return values[0]; + } + } + return false; + } + + DenseSet softwareAtomicKinds = { + triton::RMWOp::AND, triton::RMWOp::OR, triton::RMWOp::XOR}; + +public: + explicit AtomicRMWConverter(MLIRContext *context); + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class AtomicMaxMinCanonicalizer : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override; +}; + +} // namespace LoadStoreConverter +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h new file mode 100644 index 000000000..5df57dbbc --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h @@ -0,0 +1,133 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSIS_MASKANALYSIS_H +#define TRITON_ANALYSIS_MASKANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +namespace mlir { + +// this class helps build Operations +class OpBuilder; + +namespace triton { +// use to decode the pattern in a mask used for load and store + +class MaskState { +public: + OpFoldResult start; + OpFoldResult end; + SmallVector dims; + SmallVector offsets; + OpFoldResult scalar; + + int64_t getRank() const { + assert(dims.size() == offsets.size() && "dims and offsets rank mismatch!"); + return dims.size(); + } + + bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; } + + bool isMask() const { + return !start && !end && !scalar && dims.size() != 0 && offsets.size() != 0; + } + + // parse value recursively + LogicalResult parse(Value operand, const Location &loc, OpBuilder &builder); + + tensor::ExtractSliceOp getExtractSlice(Value source, const Location &loc, + OpBuilder &builder) const; + + tensor::InsertSliceOp getInsertSlice(Value source, Value dest, + const Location &loc, + OpBuilder &builder) const; + + memref::SubViewOp getSubview(Value source, const Location &loc, + OpBuilder &builder) const; + + std::pair + getSideBySideSubviews(Value block1, Value block2, const Location &loc, + OpBuilder &builder) const; + + std::pair + getStackedSubviews(Value block1, Value block2, const Location &loc, + OpBuilder &builder) const; + + void eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter); + +private: + // Utility functions + LogicalResult addStateScalar(const MaskState &state, + const OpFoldResult scalar, const Location &loc, + OpBuilder &builder); + + LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, + const 
Location &loc, OpBuilder &builder); + + LogicalResult divStateScalar(const MaskState &state, + const OpFoldResult scalar, const Location &loc, + OpBuilder &builder); + + LogicalResult divStates(const MaskState &lhsState, const MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + // Helper functions to parse values to populate MaskState + + LogicalResult parseConstant(arith::ConstantOp constOp, const Location &loc, + OpBuilder &builder); + + // Operand is an integer scalar + LogicalResult parseIntScalar(Value scalar, const Location &loc, + OpBuilder &builder); + + // TODO + LogicalResult parseAdd(arith::AddIOp addOp, const Location &loc, + OpBuilder &builder); + + // operand is the result of divsi + LogicalResult parseDiv(arith::DivSIOp divOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of andi + LogicalResult parseAnd(arith::AndIOp andOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of cmpi + LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of make_range + LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of broadcast + LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp, + const Location &loc, OpBuilder &builder); + + // Operand is the result of splat + LogicalResult parseSplat(triton::SplatOp splatOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of expand_dims + LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location &loc, OpBuilder &builder); +}; + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h new file mode 100644 index 000000000..d3623a5cb --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H +#define TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H + +#include "TritonToLinalgPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "TritonToLinalg/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td new file mode 100644 index 000000000..6ae983b57 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td @@ -0,0 +1,19 @@ +#ifndef TRITON_TO_LINALG_CONVERSION_PASSES +#define TRITON_TO_LINALG_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> { + let summary = "Convert Triton to Linalg dialect"; + let constructor = "triton::createTritonToLinalgPass()"; + let options = [ + Option<"globalKernel", "global-kernel", + "bool", /*default*/"true", + "generate a global kernel">, + Option<"namedOps", "named-ops", + "bool", /*default*/"false", + "use linalg named ops instead of linalg.generic"> + ]; +} + +#endif // TRITON_TO_LINALG_CONVERSION_PASSES diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h 
b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h new file mode 100644 index 000000000..b078e4ff4 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h @@ -0,0 +1,378 @@ +#ifndef TRITON_ADAPTER_TRITONOPCONVERTER_H +#define TRITON_ADAPTER_TRITONOPCONVERTER_H + +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "triton-to-linalg" + +#include "llvm/Support/Debug.h" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +/* +Convert `tt.precise_div` operation to `arith.divf` operation. +tensor_x / tensor_y + +```ttir + %11 = tt.precise_divf %7, %10 : tensor<100xf32> +``` + +converts to: + +```mlir + %11 = arith.divf %7, %10 : tensor<100xf32> +``` +*/ +struct PreciseDivConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/* + * Rewrite arith.select with contiguouse mask to + * tensor.extract_slice/insert_slice. + */ +class SelectConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::SelectOp op, + PatternRewriter &rewriter) const override; +}; + +/* + * Move tt.bitcast to a previous location if tt.bitcast is not directly applied + * on function arguments + */ +class BitcastCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::BitcastOp bitcastOp, + PatternRewriter &rewriter) const override; +}; + +template +class ScalarMathCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MathOp op, + PatternRewriter &rewriter) const override { + if (op->getNumResults() != 1) { + return rewriter.notifyMatchFailure( + op, "ScalarMathCanonicalizer expects single scalar output."); + } + if (!op->getResult(0).getType().isIntOrIndexOrFloat()) { + return rewriter.notifyMatchFailure( + op, "ScalarMathCanonicalizer handles scalar load scene."); + } + if (auto linalgOp = op->template getParentOfType()) { + return rewriter.notifyMatchFailure( + op, "ScalarMathCanonicalizer handles op not within tt.reduce."); + } + auto loc = op.getLoc(); + llvm::SmallVector inputs; + for (auto input : op->getOperands()) { + auto blkTy = RankedTensorType::get({(int64_t)1}, input.getType()); + auto inputSplat = rewriter.create(loc, blkTy, input); + inputs.push_back(inputSplat.getResult()); + } + auto blkOp = rewriter.create(loc, inputs); + Value offset = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto extractOp = + rewriter.create(loc, blkOp.getResult(), offset); + rewriter.replaceOp(op, extractOp); + return success(); + } +}; + +class DenseConstantConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class MakeRangeConverter : public OpConversionPattern { +public: + using 
OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class SplatConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ReshapeConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ExpandDimsConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ClampFConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class BroadcastConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ReduceConverter : public OpConversionPattern { +public: + explicit ReduceConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + +private: + llvm::SmallVector getRedOps(triton::ReduceOp redOp) const; + + bool isReductionOpSupported(Operation *redOp) const; + + arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, + Operation *redOp, + Type constantType) const; + + bool requiresF32Conversion(const Type elemType, Operation *redOp) const; + + Value getRedElement(Value lhs, Value rhs, const Location loc, + Operation *redOp, OpBuilder &b, + const bool convertLhsToF32Precision) const; + + LogicalResult + convertToLinalgReduce(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + LogicalResult + convertToLinalgReduceExtended(ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + +public: + LogicalResult + matchAndRewrite(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ExternElementwiseClOpConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class UnrealizedCastConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class JoinConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class SplitConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + 
ConversionPatternRewriter &rewriter) const override; +}; + +class CatConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class GatherConverter : public OpConversionPattern { +private: + static constexpr llvm::StringRef gatherFuncNameBase = "triton_gather"; + +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class YieldConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class LoopConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class AdvanceConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class MakeTensorPtrConverter + : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + void populateVectorAsIndex(SmallVector &vec, + Operation::operand_range ops, + ConversionPatternRewriter &rewriter, + Location loc) const; + + memref::ReinterpretCastOp + createRedundantOp(triton::MakeTensorPtrOp op, + ConversionPatternRewriter &rewriter, BlockData &data) const; + + OpFoldResult + accumulatePotentialOffsetOnBase(triton::MakeTensorPtrOp op, Value base, + OpFoldResult offset, + ConversionPatternRewriter &rewriter) const; + +public: + explicit MakeTensorPtrConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class TransposeConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class BitcastConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class TritonMulhiuiConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class TritonPreciseSqrtConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class AssertConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::AssertOp op, + PatternRewriter &rewriter) const override; +}; + +class DevicePrintConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static 
constexpr llvm::StringRef printFuncNameBase = "triton_print"; + static constexpr llvm::StringRef prefixAttrName = "prefix"; + static constexpr llvm::StringRef hexAttrName = "hex"; + +public: + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct MatmulConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} // end of namespace TTOpConverters + +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h new file mode 100644 index 000000000..6fc57aaad --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H +#define TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#define GEN_PASS_CLASSES +#include "../../include/TritonToLinalg/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToLinalgPass(); + +} // namespace triton +} // namespace mlir + +namespace { + +using namespace mlir; +using namespace triton; +const std::string globalKernelAttr = "global_kernel"; +const std::string kernelMixModeName = "mix_mode"; + +class TritonTypeConverter : public mlir::TypeConverter { +public: + explicit TritonTypeConverter(); +}; + +class TritonToLinalgPass : public TritonToLinalgBase { + + static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + +private: + // grid构造 num_programs 3维, program_id 3维 + // remember 'xxxOp' is usually a Pointer, so that we can change target memory + // without giving a reference argument + void addProgramInfo(triton::FuncOp func, bool globalKernel); + + void convertTTFunc(triton::FuncOp func, const bool existDot); + + void addDynamicLegal(ConversionTarget &target, + TritonTypeConverter &tritonTypeConverter); + + void + populateTritonToLinalgCanonicalizationPatterns(RewritePatternSet &patterns); + + void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + unsigned int launchGridRank); + +public: + void getDependentDialects(DialectRegistry ®istry) const override; + + void runOnOperation() override; +}; +} // namespace + +#endif // TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h new file mode 100644 index 000000000..e2727fa4c --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h @@ -0,0 +1,128 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSIS_USEANALYSIS_H +#define TRITON_ANALYSIS_USEANALYSIS_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +enum class UseType { + Undefined, // Initial state + DataUse, // value used for tensor computation only + MetaUse, // value used for metadata only + MixUse // value used for both tensor computation and metadata +}; + +struct UseInfo : public dataflow::AbstractSparseLattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UseInfo) + using AbstractSparseLattice::AbstractSparseLattice; + + // Lattice state transfer function + ChangeResult meetUseType(const UseType &other) { + if (other == UseType::Undefined) { + return ChangeResult::NoChange; + } + + switch (type) { + case UseType::Undefined: + type = other; + return ChangeResult::Change; + case UseType::DataUse: + case UseType::MetaUse: + if (type == other) { + return ChangeResult::NoChange; + } else { + type = UseType::MixUse; + return ChangeResult::Change; + } + case UseType::MixUse: + return ChangeResult::NoChange; + default: + llvm_unreachable("bad type"); + } + } + + ChangeResult meet(const AbstractSparseLattice &other) override { + auto rhs = reinterpret_cast(&other); + return meetUseType(rhs->type); + } + + void print(raw_ostream &os) const override { + switch (type) { + case UseType::DataUse: + os << "DataUse"; + break; + case UseType::MetaUse: + os << "MetaUse"; + break; + case UseType::MixUse: + os << "MixUse"; + break; + default: + os << "Undefined"; + } + } + + UseType type = UseType::Undefined; +}; + +class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis { +public: + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + +#if LLVM_VERSION_MAJOR >= 20 + LogicalResult visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; +#else + void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; +#endif + + void visitBranchOperand(OpOperand &operand) override { return; } + + void visitCallOperand(OpOperand &operand) override { return; } + + void setToExitState(UseInfo *lattice) override { + lattice->type = UseType::Undefined; + } + +private: + void propagateUse(UseInfo *lattice, const UseType &type) { + auto changed = lattice->meetUseType(type); + propagateIfChanged(lattice, changed); + } + + void propagateResults(UseInfo *lattice, ArrayRef results) { + auto changed = ChangeResult::NoChange; + for (auto result : results) { + changed |= lattice->meet(*result); + } + propagateIfChanged(lattice, changed); + } +}; + +class MetaUseEraser : public RewritePattern { +public: + MetaUseEraser(MLIRContext *context); + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final; +}; + +LogicalResult runUseAnalysis(triton::FuncOp &funcOp); + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONTOAFFINE_TRITONUSEANALYSIS_H diff --git a/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h b/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h new file mode 100644 index 000000000..8e445d7ee --- /dev/null +++ b/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h @@ -0,0 +1,71 @@ +#pragma once + +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/MaskAnalysis.h" +#include 
"TritonToLinalg/UseAnalysis.h" +#include "Utils/Utils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include +#include +#include +#include +#include + +namespace mlir { +namespace triton { + +enum class IndexMode : int { EVEN_MODE = 0, ODD_MODE = 1 }; + +MemRefType expandInterleaveMemRefType(MemRefType originType); + +std::pair +recountReinterpretCastOffset(OpFoldResult originOffset, Builder &builder); + +LogicalResult +DeinterleaveStatusOptimization(triton::LoadOp op, + triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter); + +LogicalResult DeinterleaveStatusWithMaskOptimization( + triton::LoadOp op, triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter, MaskState &mstate, + memref::AllocOp originAllocOp); + +LogicalResult +InterleaveStatusOptimization(SmallVector materializeVec); + +LogicalResult +InterleaveStatusWithMaskOptimization(SmallVector materializeVec); + +} // namespace triton +} // namespace mlir diff --git a/third_party/ascend/triton-adapter/include/Utils/Utils.h b/third_party/ascend/triton-adapter/include/Utils/Utils.h new file mode 100644 index 000000000..fb713d45f --- /dev/null +++ b/third_party/ascend/triton-adapter/include/Utils/Utils.h @@ -0,0 +1,148 @@ +#ifndef TRITONNPU_UTILS_UTILS_H +#define TRITONNPU_UTILS_UTILS_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" + +#include +#include + +namespace mlir { + +namespace ConverterUtils { + +Value getTransposedValue(Value source, const Location loc, + ConversionPatternRewriter &rewriter, + llvm::ArrayRef order); + +SmallVector getNParallelLoopsAttrs(unsigned n); + +Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter); + +memref::SubViewOp makeSubViewOp(Value src, + llvm::SmallVectorImpl &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter); + +void getShapeInfo(Value val, llvm::SmallVectorImpl &shapes, + ConversionPatternRewriter &rewriter); + +SmallVector +getBoundarySizes(llvm::ArrayRef boundaryCheck, Value ptr, + Value adaptorPtr, const Location &loc, + ConversionPatternRewriter &rewriter); + +SmallVector getBroadcastDims(RankedTensorType src, + RankedTensorType dst); + +SmallVector 
getUnbroadcastDims(RankedTensorType src, + RankedTensorType dst); + +} // namespace ConverterUtils + +class ConversionPatternRewriter; + +namespace triton { + +enum class IndirectLoadInterfaceOpType { Undefined = 0, Load = 1, Calc = 2 }; + +// Traceback from rootOp to find the targetOp with the specified condition +mlir::Operation * +findFirstMatchingOperandDef(mlir::Operation *rootOp, + const std::function &condFn); + +void traverseBackwardUpdateOperandChainIf( + Operation *op, std::function conditionFn, + std::function actionFn, OpBuilder &builder); + +void traverseBackwardUpdateOperandChainIf( + Operation *rootOp, std::function conditionFn, + std::function actionFn); + +void traverseForwardUpdateUserChainIf( + Operation *op, std::function conditionFn, + std::function stopFn, + std::function actionFn, OpBuilder &builder, + llvm::SmallPtrSet &stopOps); + +void traverseForwardUpdateUserChainIf( + Operation *rootOp, std::function conditionFn, + std::function stopFn, + std::function actionFn, + llvm::SmallPtrSet &stopOps); + +// UseAnalysis will tag operations whose results are used only as meta-data +// with "MetaUse" tag. +bool isMetaUse(Operation *op); + +bool isMixUse(Operation *op); + +IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op); + +bool opIsIndirectLoad(Operation *op); + +bool opIsIndirectCalc(Operation *op); + +scf::ForOp createNestedLoops( + OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims, + ValueRange LBs, ValueRange UBs, ValueRange steps, SmallVector &ivs, + ValueRange initArgs, + function_ref &, ValueRange)> + bodyBuilder); + +ModuleOp getModuleOpFromOperation(Operation *op); + +} // namespace triton + +class OpBuilder; + +std::optional makeIntAttr(const OpFoldResult &ofr); + +bool hasConstantZero(const OpFoldResult &ofr); + +Value opFoldResultToIndex(const OpFoldResult &ofr, const Location &loc, + OpBuilder &b); + +SmallVector opFoldResultToIndex(ArrayRef ofrs, + const Location &loc, OpBuilder &b); + +Value createConstIntOp(const Location &loc, OpBuilder &b, int64_t value); + +OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const Value &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +LogicalResult +addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, + linalg::ReduceOp reduceOp); + +} // namespace mlir + +#endif // TRITONNPU_UTILS_UTILS_H diff --git a/third_party/ascend/triton-adapter/lib/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/CMakeLists.txt new file mode 100644 index 000000000..cbf0d9d7e --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonToLinalg) +add_subdirectory(Utils) diff --git 
a/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp new file mode 100644 index 000000000..9bf7f7ca2 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp @@ -0,0 +1,77 @@ +#include "TritonToLinalg/ArgMinMaxConverter.h" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +// ArgMinConverter functions +LogicalResult ArgMinConverter::matchComparisonResult( + Value currValue, Value currIndex, Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, Value &comparisonResult) { + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + + auto cmpOp = dyn_cast(*it); + auto cmpIOp = dyn_cast(*it++); + if (!cmpOp && !cmpIOp) + return failure(); + + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OLT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + comparisonResult = cmpOp; + } + + if (cmpIOp) { + if (cmpIOp.getPredicate() != arith::CmpIPredicate::slt || + currValue != cmpIOp.getLhs() || reduceValue != cmpIOp.getRhs()) { + return failure(); + } + comparisonResult = cmpIOp; + } + + return success(); +} + +float ArgMinConverter::getBaseReductionValue() { + return std::numeric_limits::infinity(); +} + +int8_t ArgMinConverter::getBaseReductionIntValue() { return 127; } + +// ArgMaxConverter functions +LogicalResult ArgMaxConverter::matchComparisonResult( + Value currValue, Value currIndex, Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, Value &comparisonResult) { + auto cmpOp = dyn_cast(*it); + auto cmpIOp = dyn_cast(*it++); + if (!cmpOp && !cmpIOp) + return failure(); + + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OGT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + comparisonResult = cmpOp; + } + + if (cmpIOp) { + if (cmpIOp.getPredicate() != arith::CmpIPredicate::sgt || + currValue != cmpIOp.getLhs() || reduceValue != cmpIOp.getRhs()) { + return failure(); + } + comparisonResult = cmpIOp; + } + + return success(); +} + +float ArgMaxConverter::getBaseReductionValue() { + return -std::numeric_limits::infinity(); +} + +int8_t ArgMaxConverter::getBaseReductionIntValue() { return -128; } + +} // namespace TTOpConverters diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp new file mode 100644 index 000000000..26b8658e3 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp @@ -0,0 +1,1404 @@ +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "Utils/Utils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include + +#define DEBUG_TYPE "triton-block-ptr-analysis" + +namespace mlir { +namespace triton { + +// MemAccType selectMaxMemAccTy(const MemAccType &v1, const MemAccType 
&v2) { +// return (v1 > v2) ? v1 : v2; +// } + +namespace { +void assertLegalUnrealizedCast(UnrealizedConversionCastOp op) { + assert(op && op.getInputs().size() == 3 && + op.getInputs()[0].getDefiningOp() && + op.getInputs()[1].getDefiningOp() && + op.getInputs()[1].getDefiningOp()); +} +} // namespace + +SmallVector &BlockData::getOffsetsRef() { return this->offsets; } + +SmallVector &BlockData::getSizesRef() { return this->sizes; } + +SmallVector &BlockData::getStridesRef() { return this->strides; } + +Value &BlockData::getSourceRef() { return this->source; } + +Value &BlockData::getScalarRef() { return this->scalar; } + +SmallVector BlockData::getOffsets() const { + return this->offsets; +} + +SmallVector BlockData::getSizes() const { return this->sizes; } + +SmallVector BlockData::getStrides() const { + return this->strides; +} + +OpFoldResult BlockData::getOffset(int index) const { + return this->offsets[index]; +} + +OpFoldResult BlockData::getSize(int index) const { return this->sizes[index]; } + +OpFoldResult BlockData::getStride(int index) const { + return this->strides[index]; +} + +Value BlockData::getScalar() const { return this->scalar; } + +Value BlockData::getSource() const { return this->source; } + +MemAccType BlockData::getMemAccType() const { return this->memAccTy; }; + +MemAccType &BlockData::getMemAccTypeRef() { return this->memAccTy; }; + +bool BlockData::isScalar() const { return this->scalar != nullptr; } + +bool BlockData::isEmpty() const { + return !(this->getRank() || this->source || this->scalar); +} + +bool BlockData::hasSource() const { return this->source != nullptr; } + +void BlockData::removeSource() { this->source = nullptr; }; + +bool BlockData::hasResElemTy() const { return this->resElemTy != nullptr; } + +Type &BlockData::getResElemTyRef() { return this->resElemTy; } + +Type BlockData::getResElemTy() const { return this->resElemTy; } + +int64_t BlockData::getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); + return this->offsets.size(); +} + +void BlockData::setResElemTy(const Type &Ty) { this->resElemTy = Ty; } + +void BlockData::setScalar(const Value &scalar) { this->scalar = scalar; } + +void BlockData::setSource(const Value &src) { this->source = src; } + +void BlockData::setOffsets(const SmallVector &offsets) { + this->offsets = offsets; +} + +void BlockData::setStrides(const SmallVector &strides) { + this->strides = strides; +} + +void BlockData::setSizes(const SmallVector &szs) { + this->sizes = szs; +} + +void BlockData::setMemAccTy(const MemAccType &v) { this->memAccTy = v; } + +void BlockData::setMemAccVal(const MemAccVal v) { this->memAccTy.value = v; } + +OpFoldResult BlockData::inferBlockOffset(const Location &loc, + OpBuilder &builder) const { + OpFoldResult retOffset = builder.getIndexAttr(0); + for (auto ofr : offsets) { + retOffset = addOpFoldResult(retOffset, ofr, loc, builder); + } + return retOffset; +} + +MemRefType BlockData::getResultMemrefType(int64_t offset, + ArrayRef resultShape, + bool DynamicStrides) const { + SmallVector staticStrides; + if (DynamicStrides) { + staticStrides.append(this->strides.size(), ShapedType::kDynamic); + } else { + SmallVector dynamicStrides; + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + } + + auto elementType = + dyn_cast(this->source.getType()).getElementType(); + auto layout = + StridedLayoutAttr::get(this->source.getContext(), offset, staticStrides); + return MemRefType::get(resultShape, elementType, layout); +} + +void 
BlockData::addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter) { + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); + // When both left block and right block have source, it is indirect load. + assert(!(lBlock.hasSource() && rBlock.hasSource())); + this->source = + lBlock.hasSource() ? lBlock.getSourceRef() : rBlock.getSourceRef(); + + assert(!rBlock.hasResElemTy()); + if (lBlock.hasResElemTy()) { + this->resElemTy = lBlock.getResElemTyRef(); + } + + if (lBlock.isScalar() && rBlock.isScalar()) { + auto addOp = rewriter.create(loc, lBlock.getScalarRef(), + rBlock.getScalarRef()); + this->scalar = addOp.getResult(); + } else if (lBlock.getRank() == 0) { + this->scalar = + lBlock.isScalar() ? lBlock.getScalarRef() : rBlock.getScalarRef(); + } + + for (const auto &[lOffset, rOffset] : + llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { + this->offsets.push_back(addOpFoldResult(lOffset, rOffset, loc, rewriter)); + } + + for (const auto &[lStride, rStride] : + llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { + this->strides.push_back(addOpFoldResult(lStride, rStride, loc, rewriter)); + } + + this->sizes = lBlock.getSizesRef(); + + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); +} + +void BlockData::mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter) { + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); + + assert(!(lBlock.hasSource() && rBlock.hasSource())); + + assert( + (lBlock.isScalar() ^ rBlock.isScalar()) && + "Currently only support one and only one scalar in function mulBlock()"); + BlockData *lb = &lBlock; + BlockData *rb = &rBlock; + if (lb->isScalar()) { + std::swap(lb, rb); + } + + Value rScalar = rb->getScalarRef(); + for (const auto &lOffset : lb->getOffsetsRef()) { + this->offsets.push_back(mulOpFoldResult(lOffset, rScalar, loc, rewriter)); + } + + for (const auto &lStride : lb->getStridesRef()) { + this->strides.push_back(mulOpFoldResult(lStride, rScalar, loc, rewriter)); + } + + this->sizes = lb->getSizesRef(); + + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); +} + +void BlockData::divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter) { + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); + + assert(!(lBlock.hasSource() && rBlock.hasSource())); + + for (const auto &[lOffset, rOffset] : + llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { + this->offsets.push_back(divOpFoldResult(lOffset, rOffset, loc, rewriter)); + } + + for (const auto &[lStride, rStride] : + llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { + this->strides.push_back(divOpFoldResult(lStride, rStride, loc, rewriter)); + } + + this->sizes = lBlock.getSizesRef(); + + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); +} + +memref::ReinterpretCastOp BlockData::createCastOp(ArrayRef resultShape, + const Location &loc, + OpBuilder &builder) const { + OpFoldResult resultOffset = 
this->inferBlockOffset(loc, builder); + SmallVector staticOffset; + SmallVector dynamicOffset; + dispatchIndexOpFoldResult(resultOffset, dynamicOffset, staticOffset); + + auto resultType = this->getResultMemrefType(staticOffset[0], resultShape); + + return builder.create( + loc, resultType, this->source, resultOffset, this->sizes, this->strides); +} + +void BlockData::dump() const { + llvm::outs() << "[INFO][BEG] BlockData info\n"; + llvm::outs() << "offsets has " << offsets.size() << " items\n"; + int cnt = 0; + for (auto it = offsets.begin(); it != offsets.end(); it++) { + llvm::outs() << "offsets[" << cnt++ << "] = " << *it << "\n"; + } + llvm::outs() << "sizes has " << sizes.size() << " items\n"; + cnt = 0; + for (auto it = sizes.begin(); it != sizes.end(); it++) { + llvm::outs() << "sizes[" << cnt++ << "] = " << *it << "\n"; + } + llvm::outs() << "strides has " << strides.size() << " items\n"; + cnt = 0; + for (auto it = strides.begin(); it != strides.end(); it++) { + llvm::outs() << "strides[" << cnt++ << "] = " << *it << "\n"; + } + llvm::outs() << "source = " << source << "\n"; + llvm::outs() << "scalar = " << scalar << "\n"; + llvm::outs() << "resElemTy = " << resElemTy << "\n"; + llvm::outs() << "memAccTy = " << memAccTy.toString() << "\n"; + llvm::outs() << "[INFO][END] BlockData info\n"; +} + +Value BlockDataParser::getScalarMemRef(Value ptr, Value memref, + const Location &loc, + ConversionPatternRewriter &rewriter) { + assert(isa(ptr.getType()) && "expect a scalar pointer"); + if (ptr.getDefiningOp()) { + if (auto castOp = memref.getDefiningOp()) { + return castOp.getResult(); + } else { + llvm_unreachable("pointer value is defined by an unexpected op"); + } + } + + assert(isa(ptr) && + "pointer should be produced by addptr or block argument"); + BlockData data; + data.setSource(memref); + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(1)); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + auto castOp = data.createCastOp(SmallVector(1, 1), loc, rewriter); + return castOp.getResult(); +} + +void BlockDataParser::parse( + Value operand, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + if (known.find(operand) != known.end()) { + return data = known.lookup(operand), void(); + } + + if (isa(operand.getType())) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), operand); + return data.setScalar(castOp.getResult()), void(); + } + + if (isa(operand.getType())) { + auto remappedPtr = rewriter.getRemappedValue(operand); + assert(remappedPtr); + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + parseAddPtr(addPtrOp, data, loc, rewriter, known); + } else if (auto makeTensorPtrOp = dyn_cast(op)) { + parseMakeTensorPtr(makeTensorPtrOp, data, loc, rewriter, known); + } else if (auto bitcastOp = dyn_cast(op)) { + parseBitcast(bitcastOp, data, loc, rewriter, known); + } else { + llvm_unreachable("Unexpected operand defining operation,A scalar " + "pointer can only be produced by AddPtrOp or a block"); + } + } else { + data.setSource(remappedPtr); + } + return; + } + + // not a scalar pointer + if (auto addOp = operand.getDefiningOp()) { + parseAdd(addOp, data, loc, rewriter, known); + } else if (auto mulOp = operand.getDefiningOp()) { + parseMul(mulOp, data, loc, rewriter, known); + } else if (auto addPtrOp = operand.getDefiningOp()) { + parseAddPtr(addPtrOp, data, loc, rewriter, known); + } else if (auto 
constOp = operand.getDefiningOp()) { + parseConstSplat(constOp, data, loc, rewriter, known); + } else if (auto broadcastOp = operand.getDefiningOp()) { + parseBroadcast(broadcastOp, data, loc, rewriter, known); + } else if (auto splatOp = operand.getDefiningOp()) { + parseSplat(splatOp, data, loc, rewriter, known); + } else if (auto expandDimsOp = + operand.getDefiningOp()) { + parseExpandDims(expandDimsOp, data, loc, rewriter, known); + } else if (auto remOp = operand.getDefiningOp()) { + parseRem(remOp, data, loc, rewriter, known); + } else if (auto bitcastOp = operand.getDefiningOp()) { + parseBitcast(bitcastOp, data, loc, rewriter, known); + } else if (auto extsiOp = operand.getDefiningOp()) { + parseExtSI(extsiOp, data, loc, rewriter, known); + } else if (auto divOp = operand.getDefiningOp()) { + parseDiv(divOp, data, loc, rewriter, known); + } else if (auto makeRangeOp = operand.getDefiningOp()) { + parseMakeRange(makeRangeOp, data, loc, rewriter, known); + } else if (auto reduceOp = operand.getDefiningOp()) { + parseReduce(reduceOp, data, loc, rewriter, known); + } else if (auto loadOp = operand.getDefiningOp()) { + parseIndirectLoad(loadOp, data, loc, rewriter, known); + } else if (auto castOp = operand.getDefiningOp()) { + parseIndirectLoad(castOp, data, loc, rewriter, known); + } else { + operand.dump(); + llvm_unreachable("encountered AddPtrOp produced by unsupported operation"); + } +} + +void BlockDataParser::parseAdd( + arith::AddIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + data.addBlock(lBlock, rBlock, loc, rewriter); +} + +void BlockDataParser::parseMul( + arith::MulIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + data.mulBlock(lBlock, rBlock, loc, rewriter); +} + +void BlockDataParser::parseDiv( + arith::DivSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + data.divBlock(lBlock, rBlock, loc, rewriter); +} + +// TODO : support modulos +void BlockDataParser::parseRem( + arith::RemSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(false && "Address expression with modulo is not supported yet, it " + "shall be analysis at linearize."); +} + +void BlockDataParser::parseUnrealizedCast( + UnrealizedConversionCastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assertLegalUnrealizedCast(op); + + auto originBlock = op.getInputs()[2]; + if (known.contains(originBlock)) { + data = known.at(originBlock); + } else { + parseAddPtr(originBlock.getDefiningOp(), data, loc, + rewriter, known); + } +} + +void BlockDataParser::parseMakeRange( + triton::MakeRangeOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + auto shape = dyn_cast(op.getType()).getShape(); + + auto start = op.getStart(); + auto end = op.getEnd(); + auto 
stride = (end >= start) && (end - start <= shape[0]); + assert(stride == 1 && + "make_range op should always return a tensor of stride 1"); + + data.getOffsetsRef().push_back(rewriter.getIndexAttr(start)); + data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); + data.getStridesRef().push_back(rewriter.getIndexAttr(stride)); +} + +void BlockDataParser::parseExpandDims( + triton::ExpandDimsOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + parse(op.getSrcMutable().get(), data, loc, rewriter, known); + auto resShape = dyn_cast(op.getResult().getType()).getShape(); + auto axis = op.getAxis(); + + assert(resShape[axis] == 1 && + "The destiny shape of changed dimension should be 1"); + + data.getOffsetsRef().insert(data.getOffsetsRef().begin() + axis, + rewriter.getIndexAttr(0)); + data.getSizesRef().insert(data.getSizesRef().begin() + axis, + rewriter.getIndexAttr(1)); + data.getStridesRef().insert(data.getStridesRef().begin() + axis, + rewriter.getIndexAttr(0)); +} + +void BlockDataParser::parseBitcast( + triton::BitcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + parse(op.getSrc(), data, loc, rewriter, known); + + auto resType = op.getResult().getType(); + mlir::Type resElemPointeeTy; + if (auto resShapedTy = dyn_cast(resType)) { + auto resElemTy = resShapedTy.getElementType(); + resElemPointeeTy = + dyn_cast(resElemTy).getPointeeType(); + } else { + resElemPointeeTy = dyn_cast(resType).getPointeeType(); + } + data.setResElemTy(resElemPointeeTy); +} + +void BlockDataParser::parseExtSI( + arith::ExtSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + parse(op.getIn(), data, loc, rewriter, known); +} + +void BlockDataParser::parseBroadcast( + triton::BroadcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + auto src = op.getSrcMutable().get(); + auto dst = op.getResult(); + assert(isa(src.getType()) && + "tt.broadcast's input should be a tensor"); + + auto srcShape = dyn_cast(src.getType()).getShape(); + auto dstShape = dyn_cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source shoule be equal to destnation"); + + parse(src, data, loc, rewriter, known); + + auto &blockSizes = data.getSizesRef(); + for (const auto &[idx, src_dst] : + llvm::enumerate(llvm::zip(srcShape, dstShape))) { + const auto &[srcAxis, dstAxis] = src_dst; + if (srcAxis == dstAxis) { + continue; + } + assert(srcAxis < dstAxis && + "srcShape of broadcastOp must be less than dstShape."); + blockSizes[idx] = rewriter.getIndexAttr(dstAxis); + } +} + +void BlockDataParser::parseSplat( + triton::SplatOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + auto src = op.getSrc(); + auto dst = op.getResult(); + auto dstShape = dyn_cast(dst.getType()).getShape(); + + parse(src, data, loc, rewriter, known); + + if (isa(src.getType()) || + isa(src.getType())) { + if (!data.isEmpty()) { + data.getOffsetsRef().clear(); + data.getSizesRef().clear(); + data.getStridesRef().clear(); + } + for (auto dstAxis : dstShape) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + 
data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } + } else { + auto srcType = dyn_cast(src.getType()); + assert(srcType.getRank() == 1 && data.getRank() == 1 && + "splat MemRef source should have rank 1"); + assert(srcType.getShape()[0] == 1 && + makeIntAttr(data.getSizesRef()[0]).value() == 1 && + "splat MemRef source shoule have size 1"); + data.getStridesRef()[0] = rewriter.getIndexAttr(0); + + for (const auto &[idx, dstAxis] : llvm::enumerate(dstShape)) { + if (idx == 0) { + data.getSizesRef()[idx] = rewriter.getIndexAttr(dstAxis); + continue; + } + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } + } + if (data.isScalar()) { + data.getOffsetsRef()[0] = data.getScalarRef(); + } +} + +void BlockDataParser::parseConstSplat( + arith::ConstantOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + auto attr = dyn_cast(op.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType)); + + auto val = attr.getValues()[0].getValue(); + auto constAttr = rewriter.getIndexAttr(val.getSExtValue()); + auto constOp = arith::ConstantOp::materialize(rewriter, constAttr, + rewriter.getIndexType(), loc); + data.setScalar(constOp); + + auto resType = dyn_cast(op.getResult().getType()); + size_t loopLimit = resType.getShape().size(); + for (auto i = 0; i < loopLimit; i++) { + if (i == 0) { + data.getOffsetsRef().push_back(constOp.getResult()); + } else { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + } + data.getSizesRef().push_back(rewriter.getIndexAttr(resType.getShape()[i])); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } +} + +void BlockDataParser::parseMakeTensorPtr( + triton::MakeTensorPtrOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + auto remappedValue = rewriter.getRemappedValue(op); + if (auto castOp = remappedValue.getDefiningOp()) { + parseReinterpretCast(castOp, data, loc, rewriter, known); + } else { + llvm_unreachable("the value should be mapped to memref.reinterpret_cast"); + } +} + +void BlockDataParser::parseAddPtr( + triton::AddPtrOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + BlockData ptrBlock, offsetBlock; + parse(op.getPtr(), ptrBlock, op.getLoc(), rewriter, known); + parse(op.getOffset(), offsetBlock, op.getLoc(), rewriter, known); + + assert(ptrBlock.hasSource() && + "Ptr field should provide source/base pointer"); + // offset has source means offset is from tl.load and other ops(TODO) + if (offsetBlock.hasSource()) { + ptrBlock.setMemAccTy(offsetBlock.getMemAccType()); + offsetBlock.removeSource(); + } + + // handle for loop & scalar + if (ptrBlock.getRank() == 1 && offsetBlock.getRank() == 0) { + offsetBlock.getSizesRef().push_back(rewriter.getIndexAttr(1)); + offsetBlock.getOffsetsRef().push_back(offsetBlock.getScalarRef()); + offsetBlock.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } + + assert(ptrBlock.getRank() == offsetBlock.getRank() && + "ptr and offset should have same rank"); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseAddPtr][BEG] =========================\n"; + os << 
"[parseAddPtr] op is " << op << "\n"; + for (int i = 0; i < ptrBlock.getRank(); i++) { + os << "ptrBlock.getOffsetsRef()[" << i + << "] = " << ptrBlock.getOffsetsRef()[i] << "\n"; + os << "ptrBlock.getSizesRef()[" << i + << "] = " << ptrBlock.getSizesRef()[i] << "\n"; + os << "ptrBlock.getStridesRef()[" << i + << "] = " << ptrBlock.getStridesRef()[i] << "\n"; + os << "offsetBlock.getOffsetsRef()[" << i + << "] = " << offsetBlock.getOffsetsRef()[i] << "\n"; + os << "offsetBlock.getSizesRef()[" << i + << "] = " << offsetBlock.getSizesRef()[i] << "\n"; + os << "offsetBlock.getStridesRef()[" << i + << "] = " << offsetBlock.getStridesRef()[i] << "\n"; + } + os << "[parseAddPtr][END] -------------------------\n"; + }); + data.addBlock(ptrBlock, offsetBlock, op.getLoc(), rewriter); +} + +void BlockDataParser::parseReinterpretCast( + memref::ReinterpretCastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + data.setOffsets(op.getMixedOffsets()); + data.setSizes(op.getMixedSizes()); + data.setStrides(op.getMixedStrides()); + data.setSource(op.getSource()); + + assert(data.getOffsetsRef().size() == 1); + size_t loopLimit = data.getSizesRef().size(); + for (size_t i = 1; i < loopLimit; i++) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + } + + loopLimit = data.getStridesRef().size(); + for (size_t i = 0; i < loopLimit; i++) { + auto strideIntAttr = makeIntAttr(data.getStridesRef()[i]); + auto sizeIntAttr = makeIntAttr(data.getSizesRef()[i]); + assert(sizeIntAttr); + if (sizeIntAttr.value() == 1 && strideIntAttr) { + data.getStridesRef()[i] = rewriter.getIndexAttr(0); + } + } +} + +void BlockDataParser::parseReduce( + triton::ReduceOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + + const std::string scenarioMessages = + "PtsAnalysis supports indirectly block load in the following scenario\n" + "B = tl.load(Aptr + Aoffset) # B is 1D tensor\n" + "s = tl.min(B) # s is a scalar\n" + "D = tl.load(Cptr + s + Coffset) # s is used as the scalar offset\n"; + + auto reduce_src = op->getOperand(0); + BlockData srcBlock; + parse(reduce_src, srcBlock, loc, rewriter, known); + if (!srcBlock.hasSource()) { + llvm_unreachable(scenarioMessages.c_str()); + } + if (!isa(srcBlock.getSource().getDefiningOp())) { + llvm_unreachable(scenarioMessages.c_str()); + } + + auto reduce_result = op->getResult(0); + auto shaped_ty = dyn_cast(reduce_result.getType()); + auto shape = shaped_ty.getShape(); + auto ops = llvm::map_to_vector(op.getBody()->without_terminator(), + [](Operation &op) { return &op; }); + // Support only the case: scalar = tl.load(1D tensor) + if (shape.size() != 1 || op.getAxis() != 0 || ops.size() != 1 || + !isa(ops.front())) { + llvm_unreachable(scenarioMessages.c_str()); + } + + auto castOp = rewriter.create( + loc, RankedTensorType::get(shape, rewriter.getIndexType()), + reduce_result); + auto offset = castOp.getResult(); + if (data.isEmpty()) { + data.getOffsetsRef().push_back(offset); + data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + } else { + llvm_unreachable("parseReduce with offset already setup not yet supported"); + } +} + +template +void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + // FIXME: assume single result of operation + auto opRes 
= op->getResult(0); + auto opResTy = opRes.getType(); + std::vector resShape; + if (auto shapedResTy = dyn_cast(opResTy)) { + // For now, we consider this is UnstrucMemAcc because we have no other info. + // Visiting other ops may change the type due to more info. + data.setMemAccVal(MemAccVal::UnstrucMemAcc); + resShape = shapedResTy.getShape().vec(); + } else { + // scalar load means this is used as offset. It is StrucMemAcc. + data.setMemAccVal(MemAccVal::StrucMemAcc); + resShape.push_back(1); + } + for (auto &s : resShape) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(s)); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + } + // set the source in BlockData so that we know an indirect-load op exists in + // the chain. + data.setSource(opRes); +} + +void BlockDataParser::rewriteAddPtr( + triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) { + auto insertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + BlockData data; + parseAddPtr(op, data, op.getLoc(), rewriter, known); + + if (data.getMemAccTypeRef().isUnstructured()) { + // TODO: Based on more info, try to create a performant IR + rewriteAddPtrToUnstrucMemAcc(op, adaptor, rewriter, data); + LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); + return; + } + + if (data.getSizesRef().size() == 0) { + data.getSizesRef().push_back(rewriter.getIndexAttr(1)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + data.getOffsetsRef().push_back(data.getScalarRef()); + } + + ArrayRef resultShape; + SmallVector staticShape1(1, 1); // sz 1 value 1 + if (auto shapedType = dyn_cast(op.getResult().getType())) { + resultShape = shapedType.getShape(); + } else { + assert(data.getRank() == 1); + resultShape = staticShape1; + } + + known[op.getResult()] = data; + + auto infered_size = 1; + for (int i = data.getSizesRef().size() - 1; i >= 0; i--) { + auto strideInt = makeIntAttr(data.getStridesRef()[i]); + auto sizeInt = makeIntAttr(data.getSizesRef()[i]); + assert(sizeInt); + if (sizeInt.value() == 1 && strideInt && strideInt.value() == 0) { + data.getStridesRef()[i] = rewriter.getIndexAttr(infered_size); + } + infered_size *= sizeInt.value(); + } + + if (data.hasResElemTy()) { + auto memrefType = dyn_cast(data.getSourceRef().getType()) + .cloneWith(std::nullopt, data.getResElemTyRef()); + UnrealizedConversionCastOp castOp = + rewriter.create( + op.getLoc(), memrefType, data.getSourceRef()); + data.setSource(castOp.getOutputs()[0]); + } + + // no module handle + memref::ReinterpretCastOp castOp = + data.createCastOp(resultShape, op.getLoc(), rewriter); + Value src = castOp.getResult(); + LLVM_DEBUG({ + llvm::dbgs() << "cast MemRefType:\n"; + castOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + data.setSource(src); + rewriter.replaceOp(op, src); + rewriter.restoreInsertionPoint(insertPoint); +} + +void BlockDataParser::rewriteAdvanceOp( + triton::AdvanceOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) { + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + auto loc = op.getLoc(); + + BlockData blockData; + parse(op.getOperand(0), blockData, loc, rewriter, known); + + auto incrementOffsets = op.getOffsets(); + + SmallVector newOffsets; + for (const auto [increment, offset, stride] : + llvm::zip(incrementOffsets, 
blockData.getOffsetsRef(), + blockData.getStridesRef())) { + Value offsetValue; + if (auto offsetIntAttr = makeIntAttr(offset)) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + offsetValue = constOp.getResult(); + } else { + offsetValue = offset.get(); + } + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), increment); + auto mulOp = rewriter.create(loc, castOp.getResult(), + stride.get()); + auto addOp = + rewriter.create(loc, mulOp.getResult(), offsetValue); + newOffsets.push_back(addOp.getResult()); + } + + blockData.getOffsetsRef().clear(); + + for (auto offset : newOffsets) { + blockData.getOffsetsRef().push_back(offset); + } + + SmallVector scalarShape(1, 1); + ArrayRef resultShape; + auto pointerType = cast(op.getResult().getType()); + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + } else { + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(blockData.getRank() == 1); + } + + auto newOp = blockData.createCastOp(resultShape, loc, rewriter); + + rewriter.replaceOp(op, newOp.getResult()); + + known[newOp.getResult()] = blockData; +} + +void BlockDataParser::rewriteYieldOp( + scf::YieldOp op, ConversionPatternRewriter &rewriter, + const IndexMapSet &levelToBlockArgIndex, const int level, + const llvm::SmallDenseMap &known) { + // any inserted instruction should be before this yield + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + + auto adaptor = scf::YieldOp::Adaptor(op); + + SmallVector initArgState; + SmallVector operands(adaptor.getOperands()); + // Track the second chunks of modulo pointers so that we can append them to + // the yield results + SmallVector moduloSecondChunks; + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below gathers + // BlockData for those values. + for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { + if (auto mappedV = rewriter.getRemappedValue(v)) { + // If this value is a tensor of pointers produced by AddPtrOp, + // we should have already converted to a ReinterpretCastOp without + // layout information for the normal cases, or to an + // UnrealizedConversionCastOp for the split pointer case. + if (v.getDefiningOp() || + v.getDefiningOp() || + v.getDefiningOp()) { + if (auto castOp = mappedV.getDefiningOp()) { + assertLegalUnrealizedCast(castOp); + auto castInputs = castOp.getInputs(); + v = castOp.getResult(0); + operands[i] = castInputs[0]; + moduloSecondChunks.push_back(castInputs[1]); + } else if (auto castOp = + mappedV.getDefiningOp()) { + v = castOp; + } else { + llvm_unreachable("mapped value defined by an unexpected op"); + } + } else { + // If this value is not a tensor of pointers, we will use the + // mapped value, and rely on the conversion will happen later + // automatically when we legalize loop body. 
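+        // (Illustrative: a loop-carried accumulator such as
+        //    acc = acc + tl.load(a_ptrs)
+        //  flows through this branch unchanged; only pointer-producing values
+        //  need the remapping above. `acc` and `a_ptrs` are hypothetical names.)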
+ + // TODO: + // The scenario where a value is a tensor of pointers but not + // produced by AddPtrOp is not supported + if (isa(mappedV.getType()) && + isa( + dyn_cast(mappedV.getType()).getElementType())) + llvm_unreachable("unsupported scenario where a value is a tensor of " + "pointers but not produced by AddPtrOp"); + v = mappedV; + } + } + + if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) + continue; + auto thisSet = levelToBlockArgIndex.find(level)->second; + if (thisSet.find(i) == thisSet.end()) + continue; + + auto reintCastOp = v.getDefiningOp(); + auto unrealizedCastOp = v.getDefiningOp(); + + // assert condition deleted: (unrealizedCastOp && + // unrealizedCastOp->hasAttr(ModuloState::WraparoundAttr)) + assert( + reintCastOp || + (isa(v.getType()) && + isa(dyn_cast(v.getType()).getElementType()))); + + BlockData state; + if (reintCastOp) { + parseReinterpretCast(reintCastOp, state, op.getLoc(), rewriter, known); + } else if (unrealizedCastOp) { + assertLegalUnrealizedCast(unrealizedCastOp); + parseUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, + known); + } else { + parse(v, state, op.getLoc(), rewriter, known); + } + initArgState.push_back(state); + } + + // For each of the BlockData recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. + for (auto state : initArgState) { + for (auto s : state.getOffsetsRef()) { + // offsets can be IntAttr zeroes, since reinterpret_cast collapses + // them for the input memref, and the for loop may not update + // offsets other than offsets[0]. Create constants Values for those + // zeroes. + if (auto sIntAttr = makeIntAttr(s)) { + assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + operands.push_back(constOp.getResult()); + } else { + operands.push_back(s.get()); + } + } + + for (auto s : state.getStridesRef()) { + assert(!makeIntAttr(s) && "BlockData strides for yield within for " + "loop not expected to be " + "attribute."); + operands.push_back(s.get()); + } + } + + for (auto chunk : moduloSecondChunks) { + operands.push_back(chunk); + } + + // Yield is a terminator op that must be at the end of the function + rewriter.setInsertionPointAfter(op); + auto newOp = rewriter.replaceOpWithNewOp(op, operands); + assert(op->getNumResults() == 0); + + LLVM_DEBUG({ + llvm::dbgs() << "new yield:"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +namespace { + +struct ModuloChunkInitArg { + Value reinterpretCast = nullptr; + // where in the init args is the first chunk placed + size_t initArgIndex = -1; +}; + +} // namespace + +void BlockDataParser::rewriteForOp( + scf::ForOp op, ConversionPatternRewriter &rewriter, + IndexMapSet &levelToBlockArgIndex, const int level, + llvm::SmallDenseMap &known) { + SmallVector newInitArgs; + + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; + + // If we have a load op that uses a modulo pointer, we need to insert both of + // the memref chunks to the init args. We reuse the sizes from the original + // memrefs. This data structure keeps track of where these additional init + // args should be inserted. + // + // As an example, if we have a 2D memrefs being split, we first put the first + // chunk in the order as it appears. 
Then, once all of the original init args + // are processed, we insert their offsets and strides, and finally the second + // chunk. + SmallVector, BlockData>, + 6> + moduloStates; + + // Amongst the init args, track the indices that map to the first chunk of a + // modulo pair. This is used to distinguish between the normal + // reinterpret_casts whose return types need to be rewritten to match what the + // for loop is yielding. + DenseSet moduloInitArgIndices; + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto mappedV = rewriter.getRemappedValue(arg); + memref::ReinterpretCastOp reintCastOp; + UnrealizedConversionCastOp unrealizedCastOp; + + // If this init arg is supposed to be remapped, use the remapped + // value instead. In addition, if this init arg is a memref created + // by a reinterpret_cast or a tensor of index, there is a chance that + // it will be used in addptr. Create BlockData for each such init arg. + if (mappedV) { + // TODO: + // Passing a block argument pointer directly into a for loop not + // supported. + assert(!(dyn_cast(mappedV) && + isa(mappedV.getType())) && + "cannot take pointer block argument as init arg for for loop"); + if (auto op = mappedV.getDefiningOp()) { + reintCastOp = op; + newInitArgs.push_back(mappedV); + } else if (auto op = + mappedV.getDefiningOp()) { + assertLegalUnrealizedCast(op); + unrealizedCastOp = op; + auto inputs = unrealizedCastOp.getInputs(); + + SmallVector initArgData{ + ModuloChunkInitArg{inputs[0], i}, + ModuloChunkInitArg{inputs[1]}, + }; + + moduloInitArgIndices.insert(i); + moduloStates.push_back( + std::make_tuple(unrealizedCastOp, initArgData, BlockData{})); + + newInitArgs.push_back(inputs[0]); + } else { + newInitArgs.push_back(mappedV); + } + + } else { + newInitArgs.push_back(arg); + } + + auto indexTensor = + isa(arg.getType()) && + isa(dyn_cast(arg.getType()).getElementType()); + + if (!unrealizedCastOp && !reintCastOp && !indexTensor) + continue; + + BlockData data; + if (reintCastOp) { + parseReinterpretCast(reintCastOp, data, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } else if (unrealizedCastOp) { + parseUnrealizedCast(unrealizedCastOp, data, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + std::get<2>(moduloStates.back()) = data; + } else { + parse(arg, data, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } + + // Record the BlockData for later processing + initArgIndexState.push_back(std::make_pair(i, data)); + } + + // Set insertion point to be before the for loop for new variables passed + // into the new loop. + auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + // For each of the BlockData recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto [i, data] : initArgIndexState) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list. 
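+    // (Illustrative sketch, not tied to a specific kernel: for a loop such as
+    //
+    //    for k in range(0, K, BLOCK):
+    //        a = tl.load(a_ptrs)
+    //        a_ptrs += BLOCK
+    //
+    //  each offset and each stride of the reinterpret_cast backing `a_ptrs`
+    //  becomes one extra index-typed init arg below, so the loop body and the
+    //  yield can thread them through as plain SSA values. `a_ptrs`, `K` and
+    //  `BLOCK` are hypothetical names.)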
+ for (auto [j, s] : llvm::enumerate(data.getOffsetsRef())) { + auto sIntAttr = makeIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + data.getOffsetsRef()[j] = constOp.getResult(); + } else { + newInitArgs.push_back(s.get()); + } + } + + for (auto [j, s] : llvm::enumerate(data.getStridesRef())) { + auto sIntAttr = makeIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + data.getStridesRef()[j] = constOp.getResult(); + } else { + newInitArgs.push_back(s.get()); + } + } + + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the blockdata we record is the init + // arg, but want to to use newly created block arg. These block args + // are not created yet. We will translate this mapping later. + knownPtrsTmp.push_back(std::make_pair(i, data)); + levelToBlockArgIndex[level].insert(i); + + // If the original init arg is a memref produced by reinterpret_cast, + // create a new memref using new strides and offsets created above. + // This produces a canonicalized memref, which will match what the + // for loop generates if it modifies the memref. E.g., original + // reinterpret_cast can produce a memref with const stride: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + + // s0 + d1 + // * s1)>> + // The new reinterpret_cast will always have dynamic stride and + // offset: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + // + s0 + d1 * s2)>> + // + // For init args that are the first chunk of a modulo pair, there is + // no need for the type to be rewritten because the strides and + // offsets are already dynamic. + if (!moduloInitArgIndices.contains(i) && + newInitArgs[i].getDefiningOp()) { + SmallVector resultShape; + for (auto s : data.getSizesRef()) { + auto sIntAttr = makeIntAttr(s); + assert(sIntAttr && "expected constant size"); + resultShape.push_back(sIntAttr.value()); + } + auto castOp = data.createCastOp(resultShape, op.getLoc(), rewriter); + + LLVM_DEBUG({ + llvm::dbgs() << "new reinterpret_cast with dynamic sizes " + "and offsets:"; + castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + newInitArgs[i] = castOp.getResult(); + } + } + + // Pass in the second chunk of each modulo pair + for (auto &[unrealizedCastOp, chunkData, data] : moduloStates) { + chunkData[1].initArgIndex = newInitArgs.size(); + newInitArgs.push_back(chunkData[1].reinterpretCast); + } + + rewriter.restoreInsertionPoint(origIp); + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = rewriter.create( + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + IRMapping mapping; + mapping.map(op.getInductionVar(), iv); + mapping.map(op.getInitArgs(), newInitArgs); + mapping.map(op.getRegionIterArgs(), args); + + for (auto &bodyOp : op.getRegion().getOps()) { + b.clone(bodyOp, mapping); + } + + // Load op is lowered independent of the pointer, if we have a split + // pointer due to modulo, we need to "logically combine" these two + // memrefs into a single one using unrealized_cast_op. This way, when + // lowering the load, the pattern can detect if additional copies are + // inserted. 
When we are in a loop, it is more complicated because we + // have to insert a new unrealized_cast_op that combines the two memrefs + // in the init arg list. In addition, because init args hold no offset + // and size information, we have to manually insert two additional + // reinterpret_cast ops as input to this unrealized_cast_op so that the + // load have enough information to generate the corresponding copy. + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(b.getBlock()); + + Value zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + for (auto &[unrealizedCastOp, chunkData, data] : moduloStates) { + SmallVector newReinterpretCasts; + for (auto &chunk : chunkData) { + newReinterpretCasts.push_back(args[chunk.initArgIndex]); + } + + auto combinedCast = b.create( + loc, unrealizedCastOp.getResult(0).getType(), newReinterpretCasts, + unrealizedCastOp->getAttrs()); + + args[chunkData[0].initArgIndex].replaceUsesWithIf( + combinedCast.getResult(0), [](OpOperand &operand) { + assert(!isa(operand.getOwner()) && + "Storing to split pointers not supported"); + return isa(operand.getOwner()); + }); + } + }); + + // Convert the book-keeping data structure to use the correct key and value. + // Key is converted from init arg index to newly created block arg, and + // Value's BlockData fields are converted from init arg to newly created block + // arg + int cnt = op.getRegionIterArgs().size(); + for (auto [i, data] : knownPtrsTmp) { + for (auto it = data.getOffsetsRef().begin(); + it != data.getOffsetsRef().end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = data.getStridesRef().begin(); + it != data.getStridesRef().end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + auto key = newOp.getRegionIterArgs()[i]; + known.insert(std::make_pair(key, data)); + } + assert(static_cast(cnt + moduloStates.size()) == + newOp.getRegionIterArgs().size() && + "expect to remap all new block args"); + + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange( + newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + rewriter.replaceOp(op, resultsToReplaceWith); + + // Update the loop body. Manually invoke the rewrite logic on addptr and yield + // in the loop body, so we can take advantage of the states we built up + for (auto &bodyOp : newOp.getRegion().getOps()) { + if (auto addptrOp = dyn_cast(bodyOp)) { + // FIXME: Constructed adaptor here does not hold the transformed op info. + auto adaptor = triton::AddPtrOp::Adaptor(addptrOp); + rewriteAddPtr(addptrOp, adaptor, rewriter, known); + } else if (auto advanceOp = dyn_cast(bodyOp)) { + rewriteAdvanceOp(advanceOp, rewriter, known); + } else if (auto forOp = dyn_cast(bodyOp)) { + // TODO: + // Nested for loops are not supported at the moment + assert(0 && "nested loops currently not supported"); + // rewriteForOp(forOp, rewriter, levelToBlockArgIndex, level+1, + // knownPtrs); levelToBlockArgIndex.erase(level+1); + } + } + + if (op.getNumRegionIterArgs()) { + auto yieldOp = cast(newOp.getBody()->getTerminator()); + rewriteYieldOp(yieldOp, rewriter, levelToBlockArgIndex, level, known); + } + + LLVM_DEBUG({ + llvm::dbgs() << "new for\n"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +/// @brief Rewrite the triton::AddPtrOp to handle unstructured memory access. +/// @param op The triton::AddPtrOp to be rewritten. 
+/// @param adaptor The adaptor of the triton::AddPtrOp, used to get operands. +/// @param rewriter The pattern rewriter used to modify the IR. +/// @param data The BlockData containing information about the memory access. +void BlockDataParser::rewriteAddPtrToUnstrucMemAcc( + triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, BlockData &data) { + auto loc = op.getLoc(); + auto &offsets = data.getOffsetsRef(); + auto &blockSizes = data.getSizesRef(); + auto &strides = data.getStridesRef(); + Value ptrOffset = adaptor.getOffset(); + Value zeroIdx = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value oneIdx = + rewriter.create(loc, rewriter.getIndexAttr(1)); + auto addptrRes = op.getResult(); + assert(addptrRes.hasOneUse() && "Invalid: tt.addptr has multiple users"); + auto loadOp = *(addptrRes.user_begin()); + + // Prepare empty tensor for loop based scalar load + // FIXME: We use cast here because addptr must return tensor>. + // True? + auto resTy = cast(addptrRes.getType()); + auto resEPtrTy = resTy.getElementType(); + auto resETy = cast(resEPtrTy).getPointeeType(); + Value loaded = rewriter.create(loc, blockSizes, resETy); + SmallVector initArgs; + initArgs.push_back(loaded); + + SmallVector forLBs; + SmallVector forUBs; + SmallVector forSteps; + for (auto &s : offsets) { + forLBs.push_back(zeroIdx); + } + for (auto &s : blockSizes) { + forUBs.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, s)); + } + for (auto &s : strides) { + forSteps.push_back(oneIdx); + } + SmallVector ivs; + OpBuilder builder(op); + auto loop = createNestedLoops( + builder, loc, 0, blockSizes.size(), forLBs, forUBs, forSteps, ivs, + initArgs, + [&](OpBuilder &bB, Location bLoc, SmallVector &allIVs, + ValueRange iterArgs) { + OpBuilder::InsertionGuard g(bB); + bB.setInsertionPointToStart(bB.getBlock()); + + Value scalarOffsetRaw = + bB.create(bLoc, ptrOffset, allIVs); + Value scalarOffset = bB.create( + bLoc, bB.getIndexType(), scalarOffsetRaw); + // Replace offset & size. Only single element. 
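+          // (Roughly, with illustrative SSA names, the loop nest built here is:
+          //
+          //    %buf = tensor.empty(...)
+          //    scf.for %iv = %c0 to %size step %c1 iter_args(%acc = %buf) {
+          //      %off  = arith.index_cast(tensor.extract %offsets[%iv])
+          //      %view = memref.reinterpret_cast %base to
+          //              offset: [%off], sizes: [1], strides: [1]
+          //      ... the original tt.load is moved after %view and tagged
+          //      "IndirectLoad" so LoadConverter can finish the lowering ...
+          //    }
+          //
+          //  The single-element offset/size/stride set up just below feeds
+          //  that reinterpret_cast.)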
+ data.getOffsetsRef().clear(); + data.getOffsetsRef().push_back(scalarOffset); + data.getSizesRef().clear(); + data.getSizesRef().push_back(bB.getIndexAttr(1)); + data.getStridesRef().clear(); + data.getStridesRef().push_back(bB.getIndexAttr(1)); + memref::ReinterpretCastOp castOp = data.createCastOp({1}, bLoc, bB); + rewriter.replaceOp(op, castOp); + // Move tt.load using this tt.addptr into this block + loadOp->moveAfter(castOp); + loadOp->setAttr("IndirectLoad", UnitAttr::get(op.getContext())); + bB.create(bLoc, iterArgs); + }); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt new file mode 100644 index 000000000..75f5ad897 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt @@ -0,0 +1,29 @@ +add_triton_library(TritonToLinalg + TritonToLinalgPass.cpp + LoadStoreConverter.cpp + FunctionConverter.cpp + ArgMinMaxConverter.cpp + TritonOpConverter.cpp + BlockPtrAnalysis.cpp + MaskAnalysis.cpp + UseAnalysis.cpp + + DEPENDS + TritonToLinalgConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms + TritonAnalysis + MLIRTritonNPUUtils + MLIRSCFTransforms + MLIRLinalgTransforms +) diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp new file mode 100644 index 000000000..af58b6dbe --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "TritonToLinalg/FunctionConverter.h" + +namespace FunctionConverter { +using namespace mlir; +using namespace triton; + +LogicalResult GetProgramIDConverter::matchAndRewrite( + triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto axis = (uint32_t)op.getAxis(); + assert(axis < GetProgramIDConverter::LAUNCH_GRID_RANK && + "Invalid axis for GetProgramIdOp"); + auto func = op->getParentOfType(); + auto numArgs = func.getNumArguments(); + auto id = func.getArgument(numArgs - GetProgramIDConverter::LAUNCH_GRID_RANK + + axis); + rewriter.replaceOp(op, id); + return success(); +} + +LogicalResult GetNumProgramsConverter::matchAndRewrite( + triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto axis = (uint32_t)op.getAxis(); + assert(axis < GetNumProgramsConverter::LAUNCH_GRID_RANK && + "Invalid axis for GetNumProgramsOp"); + auto func = op->getParentOfType(); + auto numArgs = func.getNumArguments(); + auto id = func.getArgument( + numArgs - GetNumProgramsConverter::LAUNCH_GRID_RANK * 2 + axis); + rewriter.replaceOp(op, id); + return success(); +} +} // namespace FunctionConverter diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp new file mode 100644 index 000000000..2d49a6ab8 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp @@ -0,0 +1,752 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "TritonToLinalg/LoadStoreConverter.h" +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/MaskAnalysis.h" +#include "Utils/InterleaveOptimization.h" +#include "Utils/Utils.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include "llvm/Support/Debug.h" + +#include +#include +#include + +#define DEBUG_TYPE "triton-load-store-converter" + +namespace LoadStoreConverter { +using namespace mlir; +using namespace triton; + +LogicalResult +AddPtrConverter::matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm::SmallDenseMap known; + BlockDataParser::rewriteAddPtr(op, adaptor, rewriter, known); + return success(); +} + +LogicalResult LoadConverter::toTensorAndReplace( + triton::LoadOp &op, RankedTensorType &tensorType, memref::AllocOp &allocOp, 
+ const Location &loc, ConversionPatternRewriter &rewriter) const { + Value loadedTensor = rewriter.create( + loc, tensorType, allocOp, true, true); + rewriter.replaceOp(op, loadedTensor); + return success(); +} + +/// @brief Check whether the triton::LoadOp has been modified to the specified +/// state by the AddPtrConverter. +/// @param op The triton::LoadOp operation to be checked. +/// @return Return success if the operation conforms to the specified state; +/// otherwise, return failure. +LogicalResult +LoadConverter::checkModifiedByAddPtrConverter(triton::LoadOp &op) const { + if (!isa(op->getParentOp())) { + return failure(); + } + if (!op->hasAttr("IndirectLoad")) { + return failure(); + } + auto ptrOp = op.getPtr().getDefiningOp(); + auto ptrBlock = ptrOp->getBlock(); + auto opBlock = op->getBlock(); + if (ptrBlock == opBlock) { + return failure(); + } + + return success(); +} + +/// @brief Continue to modify the triton::LoadOp from the state modified by the +/// AddPtrConverter. +/// @param op The triton::LoadOp operation to be processed. +/// @param adaptor The adaptor for the operation, used to obtain operands. +/// @param rewriter The pattern rewriter used to rewrite the operation. +/// @return Return success if the operation is successful; otherwise, return +/// failure. +LogicalResult LoadConverter::continueModifyFromAddPtrConverter( + triton::LoadOp &op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto forOp = op->getParentOfType(); + Operation *firstOp = &forOp.getBody()->front(); + auto extractOp = cast(firstOp); + auto ivs = extractOp.getIndices(); + // Single iterArg which is inserted by AddPtrConverter. + auto iterArg = forOp.getRegionIterArg(0); + auto ptr = adaptor.getPtr(); + + rewriter.setInsertionPointAfter(op); + Value castVal = ptr.getDefiningOp(); + Value idxZero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value loadVal = + rewriter.create(loc, castVal, ValueRange{idxZero}); + Value insertedVal = + rewriter.create(loc, loadVal, iterArg, ValueRange{ivs}); + // a yield op is already created by AddPtrConverter. + // so we need to replace it with a new yield op. + Operation *terminator = forOp.getBody()->getTerminator(); + scf::YieldOp oldYieldOp = cast(terminator); + auto yieldOp = rewriter.create(loc, ValueRange{insertedVal}); + rewriter.replaceOp(oldYieldOp, yieldOp); + // Now the scf.for is complete, we can replace tt.load with it. + auto rank = cast(op.getResult().getType()).getShape().size(); + Operation *rootForOp = op; + while (rank != 0) { + rank--; + rootForOp = rootForOp->getParentOfType(); + } + rewriter.replaceOp(op, rootForOp); + LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(rootForOp) << "\n"; }); + return success(); +} + +LoadConverter::LoadConverter(MLIRContext *context) + : OpConversionPattern(context) {} + +LogicalResult +LoadConverter::matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Check if tt.load is modified by AddPtrConverter to a specified state. 
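+  // That state, set up by BlockDataParser::rewriteAddPtrToUnstrucMemAcc, is:
+  // the load sits inside the generated scf.for nest, carries the
+  // "IndirectLoad" attribute, and its pointer is defined in an enclosing
+  // block. In that case we only finish the lowering by emitting the scalar
+  // memref.load / tensor.insert and rewiring the loop's yield.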
+ if (checkModifiedByAddPtrConverter(op).succeeded()) { + return continueModifyFromAddPtrConverter(op, adaptor, rewriter); + } + + auto ptr = adaptor.getPtr(); + auto mask = op.getMask(); + auto other = op.getOther(); + auto loc = op.getLoc(); + + // handling scalar + if (!isa(op.getResult().getType())) { + auto scalarMemref = + BlockDataParser::getScalarMemRef(op.getPtr(), ptr, loc, rewriter); + auto resTy = op.getResult().getType(); + auto idxZero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto loadOp = rewriter.create(loc, resTy, scalarMemref, + idxZero.getResult()); + rewriter.replaceOp(op, loadOp.getResult()); + return success(); + } + + // handling no mask + auto memRefType = dyn_cast(ptr.getType()); + if (!memRefType) { + return rewriter.notifyMatchFailure( + op, "LoadOp expects a memref, not a memref of pointers"); + } + auto memRefShape = memRefType.getShape(); + auto memRefElementType = memRefType.getElementType(); + + auto allocOp = rewriter.create( + loc, MemRefType::get(memRefShape, memRefElementType)); + + auto tensorType = RankedTensorType::get(memRefShape, memRefElementType); + // boundary check + auto boundaryCheck = op.getBoundaryCheck(); + if (!boundaryCheck.empty()) { + auto boundarySizes = mlir::ConverterUtils::getBoundarySizes( + boundaryCheck, op.getPtr(), ptr, loc, rewriter); + // handle the padding + auto padding = op.getPadding(); + if (padding.has_value()) { + TypedAttr padAttr = rewriter.getZeroAttr(memRefElementType); + // triton already ensure only NAN and ZERO are passed in + if (padding.value() == triton::PaddingOption::PAD_NAN) { + // FIXME: Why NaN requires elemTy to be non-int or non-index? + assert(!memRefElementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(padAttr).getValue().getSemantics()); + padAttr = rewriter.getFloatAttr(memRefElementType, apNaN); + } + auto padVal = rewriter.create(loc, padAttr); + + auto shape = memRefType.getShape(); + auto accBase = + rewriter.create(loc, rewriter.getBoolAttr(false)) + .getResult(); + for (size_t i = 0; i < boundarySizes.size(); i++) { + auto dim = boundaryCheck[i]; + auto shapei = rewriter.create( + loc, rewriter.getIndexAttr(shape[dim])); + Value bndSizei = dyn_cast(boundarySizes[i]); + if (!bndSizei) { + bndSizei = rewriter.create( + loc, cast(boundarySizes[i].get())); + } + auto cmpOp = rewriter.create( + loc, arith::CmpIPredicate::slt, bndSizei, shapei); + accBase = rewriter.create(loc, accBase, cmpOp.getResult()) + .getResult(); + } + rewriter.create( + loc, accBase, [&](OpBuilder &builder, Location loc) { + builder.create(loc, ValueRange{padVal}, + ValueRange{allocOp}); + builder.create(loc); + }); + } + + auto srcSubView = + mlir::ConverterUtils::makeSubViewOp(ptr, boundarySizes, loc, rewriter); + auto dstSubview = mlir::ConverterUtils::makeSubViewOp( + allocOp, boundarySizes, loc, rewriter); + rewriter.create(loc, srcSubView, dstSubview); + + return this->toTensorAndReplace(op, tensorType, allocOp, loc, rewriter); + } + + if (!mask) { + assert(!other && "can not input 'other' when 'mask' is not set"); + if (auto unrealizedCastOp = + ptr.getDefiningOp()) { + // TODO : not support handle associate with "module" + // hint : can be handled in Linearize + } else { + // If last dimension stride equals 2, try deinterleave optimization. 
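+      // (For example, an access pattern like
+      //    x = tl.load(base + 2 * tl.arange(0, BLOCK))
+      //  typically analyzes to a last-dimension stride of 2; the helper below
+      //  attempts a dedicated deinterleave lowering for that case, and when it
+      //  does not apply we fall back to the plain memref.copy. `base` and
+      //  `BLOCK` are hypothetical names.)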
+ auto [ptrStrides, ptrOffsets] = getStridesAndOffset(memRefType); + if (ptrStrides.back() == 2 && (memRefShape.back() % 2 == 0) && + mlir::triton::DeinterleaveStatusOptimization(op, adaptor, rewriter) + .succeeded()) { + return success(); + } + rewriter.create(loc, ptr, allocOp); + } + + return this->toTensorAndReplace(op, tensorType, allocOp, loc, rewriter); + } + + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + if (isContMask.failed()) { + return rewriter.notifyMatchFailure( + op, "can not lower uncontinuout masked loads"); + } + + if (other) { + auto scalarOther = + mlir::ConverterUtils::getScalarValue(other, loc, rewriter); + assert( + scalarOther && + "other value used in masked load produced by unsupported instruction!"); + auto shape = memRefType.getShape(); + auto accBase = + rewriter.create(loc, rewriter.getBoolAttr(false)) + .getResult(); + for (size_t i = 0; i < memRefType.getShape().size(); i++) { + auto shapei = rewriter.create( + loc, rewriter.getIndexAttr(shape[i])); + Value dimi = dyn_cast(mstate.dims[i]); + if (!dimi) { + dimi = rewriter.create( + loc, cast(mstate.dims[i].get())); + } + auto cmpOp = rewriter.create( + loc, arith::CmpIPredicate::slt, dimi, shapei); + accBase = rewriter.create(loc, accBase, cmpOp.getResult()) + .getResult(); + } + + rewriter.create( + loc, accBase, [&](OpBuilder &builder, Location loc) { + builder.create(loc, ValueRange{scalarOther}, + ValueRange{allocOp}); + builder.create(loc); + }); + } + + // To enable deinterleave optimization with mask load, mask state along last + // dimension couldn't be split, which means `dims.back()` must be equal to + // origin type last dimension constant size and `offsets.back()` must be 0. + // + // The basis is that last dimension range comparison would generate + // unaccepted discontinuous mask. 
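+  // (Illustrative: a mask that only restricts the leading dimension, e.g.
+  //    mask = (rows < M)[:, None]
+  //  broadcast over a full last dimension, leaves offsets.back() == 0 and
+  //  dims.back() equal to the static last-dimension size, which is exactly
+  //  what the check below requires. `rows` and `M` are hypothetical names.)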
+ if (mstate.getRank() == memRefType.getRank() && + isConstantIntValue(mstate.offsets.back(), 0) && + isConstantIntValue(mstate.dims.back(), memRefType.getShape().back())) { + auto [ptrStrides, ptrOffsets] = getStridesAndOffset(memRefType); + if (ptrStrides.back() == 2 && (memRefType.getShape().back() % 2 == 0) && + DeinterleaveStatusWithMaskOptimization(op, adaptor, rewriter, mstate, + allocOp) + .succeeded()) { + return success(); + } + } + + if (auto unrealizedCastOp = ptr.getDefiningOp()) { + // TODO : not support handle associate with "module" + // hint : can be handled in Linearize + } else { + memref::SubViewOp srcSubView = mstate.getSubview(ptr, loc, rewriter); + memref::SubViewOp dstSubView = mstate.getSubview(allocOp, loc, rewriter); + rewriter.create(loc, srcSubView, dstSubView); + } + return this->toTensorAndReplace(op, tensorType, allocOp, loc, rewriter); +} + +AtomicRMWConverter::AtomicRMWConverter(MLIRContext *context) + : OpConversionPattern(context) {} + +// lowering tt.atomicRMW to linalg.generic +// If atomic op's return value is used by other op as it's the old value stored +// at the ptrwe will use tt.load to get it +// +// example: +// input: +// %return_value = tt.atomic_rmw fadd, acq_rel, gpu, +// %output_memref, %input_tensor, %mask : +// (tensor<256x!tt.ptr>, tensor<256xf32>, tensor<256xi1>) +// -> tensor<256xf32> +// +// output: +// memref.copy %output_memref, %ub_buf : memref to memref +// %17 = bufferization.to_tensor %alloc_3 restrict writable : memref<256xf32> +// linalg.generic +// {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} +// ins(%output_memref, %masked_input_memref : memref, memref) +// outs(%subview_2 : memref) +// attrs = {GenericAtomicRMW = "fadd", MemSemantic = "acq_rel", +// MemSyncScope = "gpu"} { +// ^bb0(%in: f32, %in_9: f32, %out: f32): +// %25 = arith.addf %in, %in_9 : f32 +// linalg.yield %25 : f32 +// } +LogicalResult +AtomicRMWConverter::matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // If the result of AtomicRMWOp is not used, we don't need to load the old + // data stored at the ptr + auto ptr = adaptor.getPtr(); + auto val = op.getVal(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + if (!resType) { + return rewriter.notifyMatchFailure( + op, "atomicRMWConverter: scalar will be handled by " + "ScalarAtomicRMWCanonicalizer"); + } + + auto rmwOp = op.getAtomicRmwOp(); + if (rmwOp == triton::RMWOp::UMAX || rmwOp == triton::RMWOp::UMIN) { + return rewriter.notifyMatchFailure( + op, "AtomicRMWConverter: unsupported atomic kind for now"); + } + + // 1. Simple case where no mask is used. + auto type = dyn_cast(ptr.getType()); + if (!type) { + // Seen when implicit broadcasting is done late in a chain of + // operations. The workaround is to broadcast the pointers early in the + // address calculation. A proper fix is complicated, but at least we can + // provide a better error message. + return rewriter.notifyMatchFailure( + op, "AtomicRMWOp expects a memref, not a memref of pointers"); + } + + auto dstMemref = ptr; + // Well, linalg structure op wouldn't support mixed tensor/buffer semantics + // any more in latest LLVM(triton LLVM dependency has involed this), so we + // need to convert tensor to buffer early. + auto dstType = dstMemref.getType(); + Value inputMemref = + rewriter.create(loc, dstType, val); + + // 2. 
handle the mask for the atomic op + MaskState mstate; + auto mask = op.getMask(); + + // When the dsl do not pass the mask to this op like + // `tl.atomic_add(out_ptr0 + xindex, tmp2)`, it will create a constant mask + // for this op by default, which is not supported by maskAnalysis, so we + // need to handle this situation + // + // This logic come from semantic.py: + // + // if not mask: + // mask_ir = builder.get_int1(True) + // mask_ty = tl.int1 + // if ptr.type.is_block(): + // mask_ir = \ + // builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + // mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + // mask = tl.tensor(mask_ir, mask_ty) + // + // ... + // + // return ptr, val, mask + // + auto constantMask = mask.getDefiningOp(); + if (!constantMask) { + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) { + return rewriter.notifyMatchFailure( + op, "Cannot lower continuous masked loads"); + } + dstMemref = mstate.getSubview(ptr, loc, rewriter); + inputMemref = mstate.getSubview(inputMemref, loc, rewriter); + } else { + if (!isConstantMaskTrue(mask)) { + rewriter.eraseOp(op); + return success(); + } + } + + // 3. If needed, handle the return value of atomic op + // + // tt.atomicRMW op has two part of feature + // 1. load the old data at the ptr + // 2. atomically store the data on ub to the ptr + // at the same time it perform the action it has been assigned + // So we lower this op to load + atomically store + // + // The first part is not necessary when the returned value of atomic op + // is not used, it will be deleted cause it's meaningless + // Here, we preemptively determine whether it will be used + // and decide whether it is necessary to create the load process based on + // this assessment. + // + // logic of handling is copied + // TODO: decoupling the logic of load, put it in the Utils + if (!op.getResult().use_empty()) { + auto tensorType = + RankedTensorType::get(type.getShape(), type.getElementType()); + auto alloc = rewriter.create( + loc, MemRefType::get(type.getShape(), type.getElementType())); + + // For the return value, don't need to care about mask for now + // this op don't support other, so we best not fill it + rewriter.create(loc, ptr, alloc); + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + } + + // create element-wise map + int64_t rank = type.getRank(); + SmallVector inputDims; + auto context = rewriter.getContext(); + + for (int i = 0; i < rank; i++) { + inputDims.push_back(getAffineDimExpr(i, context)); + } + + SmallVector indexingMaps; + // As mask has been erased for now + // the number of input must be 2 + // the input memref is also the output memref + // Thus, there are a total of three inputs and outputs. 
+ // so here we have 3 map to create + for (int i = 0; i < 3; i++) { + indexingMaps.push_back(AffineMap::get(rank, 0, inputDims, context)); + } + + auto linalgOp = rewriter.create( + loc, /* operands */ ValueRange{dstMemref, inputMemref}, + ValueRange{dstMemref}, indexingMaps, + mlir::ConverterUtils::getNParallelLoopsAttrs(rank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { + Value opResult = createAtomicBinaryOps(nestedBuilder, nestedLoc, op, + type.getElementType(), + blockArgs[0], blockArgs[1]); + nestedBuilder.create(nestedLoc, opResult); + }); + + // "library_call" + // indicating the actual semantic of this op + // TODO: If the hardware support the MemSemantic/MemSyncScope + // We pass them down + // otherwise they need to be deleted + const StringRef genericAtomicRMW = "GenericAtomicRMW"; + const StringRef memSemantic = "MemSemantic"; + const StringRef memSyncScope = "MemSyncScope"; + linalgOp->setAttr(genericAtomicRMW, + rewriter.getStringAttr(stringifyEnum(op.getAtomicRmwOp()))); + linalgOp->setAttr(memSemantic, + rewriter.getStringAttr(stringifyEnum(op.getSem()))); + linalgOp->setAttr(memSyncScope, + rewriter.getStringAttr(stringifyEnum(op.getScope()))); + + // Mark atomic_and/or/xor specially which need software simulation in terms + // of backend restriction + if (softwareAtomicKinds.contains(op.getAtomicRmwOp())) + linalgOp->setAttr("Software", rewriter.getUnitAttr()); + + // if the result hasn't been replace by load + // we need to erase it here + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + } + return success(); +} + +LogicalResult +ScalarStoreCanonicalizer::matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const { + + if (!op.getValue().getType().isIntOrIndexOrFloat()) { + return rewriter.notifyMatchFailure( + op, "ScalarStoreCanonicalizer handles scalar store scene!"); + } + + auto ptr = op.getPtr(); + auto ptrTy = RankedTensorType::get({(int64_t)1}, ptr.getType()); + auto ptrSplat = rewriter.create(op.getLoc(), ptrTy, ptr); + auto valTy = RankedTensorType::get({(int64_t)1}, op.getValue().getType()); + auto valSplat = + rewriter.create(op.getLoc(), valTy, op.getValue()); + + auto newStoreOp = rewriter.create( + op.getLoc(), ptrSplat, valSplat, op.getCache(), op.getEvict()); + rewriter.replaceOp(op, newStoreOp); + return success(); +} + +LogicalResult +ScalarAtomicRMWCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const { + + if (!op.getVal().getType().isIntOrIndexOrFloat()) { + return rewriter.notifyMatchFailure( + op, "ScalarAtomicRMWCanonicalizer handles scalar atomic rmw op scene!"); + } + + auto ptr = op.getPtr(); + auto ptrTy = RankedTensorType::get({(int64_t)1}, ptr.getType()); + auto ptrSplat = rewriter.create(op.getLoc(), ptrTy, ptr); + auto valTy = RankedTensorType::get({(int64_t)1}, op.getVal().getType()); + auto valSplat = + rewriter.create(op.getLoc(), valTy, op.getVal()); + auto maskTy = RankedTensorType::get({(int64_t)1}, op.getMask().getType()); + auto maskSplat = + rewriter.create(op.getLoc(), maskTy, op.getMask()); + + auto newAtomicOp = rewriter.create( + op.getLoc(), valTy, op.getAtomicRmwOp(), ptrSplat, valSplat, maskSplat, + op.getSem(), op.getScope()); + rewriter.replaceOp(op, newAtomicOp); + return success(); +} + +// The atomic max op with float input will be devided into +// two atomic max ops with integer input +// One handles the part of the tensor greater than zero +// the other deals with the part less than zero +// It will lead to maskAnalysis 
failure +// So here we need to revert the procedures in semantics.py +// The triton IR is like +// +// %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x256xf32> +// %1 = tt.bitcast %value : tensor<1x256xf32> -> tensor<1x256xi32> +// %2 = tt.bitcast %ptr : tensor<1x256x!tt.ptr> -> +// tensor<1x256x!tt.ptr> %3 = arith.cmpf oge, %1, %cst_0 %4 = arith.cmpf +// olt, %1, %cst_0 %5 = arith.andi %8, %3 %6 = tt.atomic_rmw max, acq_rel, gpu, +// %2, %1, %5 : +// (tensor<1x256x!tt.ptr>, tensor<1x256xi32>, tensor<1x256xi1>) -> +// tensor<1x256xi32> +// %7 = arith.andi %8, %4 +// %8 = tt.atomic_rmw umin, acq_rel, gpu, %2, %1, %7 : +// (tensor<1x256x!tt.ptr>, tensor<1x256xi32>, tensor<1x256xi1>) -> +// tensor<1x256xi32> +// +// it's hard to handle and meaningless complicated for our device +// so we revert it to +// %0 = tt.atomic_rmw max, acq_rel, gpu, %23, %21, %8 : +// (tensor<1x256x!tt.ptr>, tensor<1x256xf32>, tensor<1x256xi1>) -> +// tensor<1x256xf32> +LogicalResult +AtomicMaxMinCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const { + // Revert the op to its original form + auto ptrBitcastOp = op.getPtr().getDefiningOp(); + auto valueBitcastOp = op.getVal().getDefiningOp(); + if (!ptrBitcastOp || !valueBitcastOp) { + return failure(); + } + + // We only need to handle the op when the element type is float + auto elementType = + dyn_cast(valueBitcastOp.getSrc().getType()).getElementType(); + if (!isa(elementType)) { + return failure(); + } + + auto rmwOp = op.getAtomicRmwOp(); + // here we know that atomic UMAX/UMIN + // is created by special logic of triton right now + // so we can simply delete it + if (rmwOp == triton::RMWOp::UMAX || rmwOp == triton::RMWOp::UMIN) { + // if the return value of op is used, we can't simply erase it + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + return failure(); + } + + if (rmwOp != triton::RMWOp::MAX && rmwOp != triton::RMWOp::MIN) { + return failure(); + } + + // 1. Though semantic interpreter will generate full true tensor as original + // mask if atomicrmwOp don't have it, above float devision process will also + // generate positive and negative comparison mask, which will cause to fold + // true mask. + // 2. 
While if atomicrmwOp has original mask, there exists andiop between + // original mask and positive/negative comparison mask + // + // Here wanna extract original mask + Value originalMask = op.getMask(); + if (auto andOp = originalMask.getDefiningOp()) + // LHS is convention in semantic interpreter + originalMask = andOp.getLhs(); + else if (auto cmpOp = originalMask.getDefiningOp()) { + if (cmpOp.getPredicate() != mlir::arith::CmpFPredicate::OGE || + !matchPattern(cmpOp.getRhs(), + /*positive float zero matcher*/ m_PosZeroFloat())) + // Here recheck frontend interpreter generation in no manual mask state + return op->emitError("Illegal mask for atomicrmwOp of float type"); + // Restore original true mask + originalMask = rewriter.create( + op->getLoc(), + /*typed attr*/ DenseElementsAttr::get( + cast(originalMask.getType()), true)); + } else + return op->emitError("Illegal mask for atomicrmwOp of float type"); + + auto originAtomicOp = rewriter.create( + op.getLoc(), valueBitcastOp.getSrc().getType(), op.getAtomicRmwOp(), + ptrBitcastOp.getSrc(), valueBitcastOp.getSrc(), originalMask, op.getSem(), + op.getScope()); + + // if the return value of op is used + // we need to handle its usage + // In semantic.py, if the atomic Max/Min with float input is used + // It will use select + bitcast to get float value + // so here we need to revert it too + // + // For example: + // %0 = tt.atomic_rmw max, acq_rel, gpu, %gm, %input, %mask1 : + // (tensor<32x!tt.ptr>... %1 = tt.atomic_rmw umin, acq_rel, gpu, %gm, + // %input, %mask2 : (tensor<32x!tt.ptr>... %2 = arith.select + // %devidedMask, %0, %1 : tensor<32xi1>, tensor<32xi32> %3 = tt.bitcast %2 : + // tensor<32xi32> -> tensor<32xf32> tt.store %outputMemref, %3 : + // tensor<32x!tt.ptr> + // + // will be revert to: + // %0 = tt.atomic_rmw max, acq_rel, gpu, %gm, %input, %mask : + // (tensor<32x!tt.ptr>... tt.store %outputMemref, %0 : + // tensor<32x!tt.ptr> + // + if (!op.getResult().use_empty()) { + for (OpOperand &use : op->getUses()) { + auto selectOp = dyn_cast(use.getOwner()); + if (!selectOp) + continue; + + for (OpOperand &selectUse : selectOp->getUses()) { + if (auto bitcastOp = + dyn_cast(selectUse.getOwner())) { + bitcastOp.getResult().replaceAllUsesWith(originAtomicOp); + } + } + } + rewriter.replaceOp(op, originAtomicOp); + } else { + rewriter.eraseOp(op); + } + + return success(); +} + +StoreConverter::StoreConverter(MLIRContext *context) + : OpConversionPattern(context) {} + +LogicalResult +StoreConverter::matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // triton store op basic + auto mask = op.getMask(); + auto loc = op.getLoc(); + auto ptr = adaptor.getPtr(); + auto val = adaptor.getValue(); + + // 1. boundary size check + // auto boundaryCheck = op.getBoundaryCheck(); + // if (!boundaryCheck.empty()) { + // SmallVector sizes = getBoundarySizes( + // boundaryCheck, op.getPtr(), ptr, loc, rewriter); + + // auto srcSlice = getExtractSlice(val, sizes, loc, rewriter); + // auto dstSubview = getSubview(ptr, sizes, loc, rewriter); + // auto storeOp = + // rewriter.create( + // loc, srcSlice, dstSubview); + // storeOp.setWritable(true); + // rewriter.eraseOp(op); + // return success(); + // } + + // 2. Simple load with no mask + if (!mask) { + auto storeOp = rewriter.create( + loc, val, ptr); + storeOp.setWritable(true); + rewriter.eraseOp(op); + return success(); + } + + // 3. Continuous masked stores. 
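+  //    (e.g. a mask built as `offs < n` with `offs = tl.arange(0, BLOCK)`,
+  //     which MaskAnalysis resolves to a contiguous range starting at
+  //     offset 0; `offs`, `n` and `BLOCK` are hypothetical names).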
+ // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) { + return failure(); + } + LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); + auto srcSlice = mstate.getExtractSlice(val, loc, rewriter); + auto dstSubview = mstate.getSubview(ptr, loc, rewriter); + auto storeOp = rewriter.create( + loc, srcSlice, dstSubview); + storeOp.setWritable(true); + rewriter.eraseOp(op); + return success(); +} +} // namespace LoadStoreConverter diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp new file mode 100644 index 000000000..946b781f6 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp @@ -0,0 +1,543 @@ +#include "TritonToLinalg/MaskAnalysis.h" +// #include "triton-shared/Analysis/opFoldResultutils.h" +#include "Utils/Utils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include +#include + +#define DEBUG_TYPE "mask-analysis" + +namespace mlir { + +namespace triton { + +LogicalResult MaskState::parse(Value operand, const Location &loc, + OpBuilder &builder) { + if (isa(operand.getType())) { + return parseIntScalar(operand, loc, builder); + } + auto definingOp = operand.getDefiningOp(); + LLVM_DEBUG({ + llvm::dbgs() << "[MaskState]==> parse op\n" + << *definingOp << "\n[MaskState]<==\n"; + }); + return TypeSwitch(definingOp) + .Case( + [&](auto op) { return this->parseConstant(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseAdd(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseAnd(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseCmp(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseMakeRange(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseBroadcast(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseSplat(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseExpandDims(op, loc, builder); }) + .Case( + [&](auto op) { return this->parse(op.getIn(), loc, builder); }) + .Case( + [&](auto op) { return this->parseDiv(op, loc, builder); }) + .Default([&](Operation *op) { return failure(); }); +} + +// extractSlice +tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, + const Location &loc, + OpBuilder &builder) const { + auto sourceRType = cast(source.getType()); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + + auto dstRType = tensor::ExtractSliceOp::inferResultType(sourceRType, offsets, + dims, strides); + return builder.create(loc, dstRType, source, offsets, + dims, strides); +} + +tensor::InsertSliceOp MaskState::getInsertSlice(Value source, Value dest, + const Location &loc, + OpBuilder &builder) const { + auto sourceType = cast(source.getType()); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + return builder.create(loc, source, dest, offsets, dims, + strides); +} + +memref::SubViewOp MaskState::getSubview(Value source, const Location &loc, + OpBuilder &builder) const { + auto sourceType = cast(source.getType()); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(sourceType, offsets, 
dims, strides); + return builder.create(loc, cast(dstType), + source, offsets, dims, strides); +} + +static memref::SubViewOp createSubview(Value src, const Location &loc, + OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return builder.create(loc, cast(dstType), src, + offsets, sizes, strides); +} + +std::pair +MaskState::getSideBySideSubviews(Value block1, Value block2, + const Location &loc, + OpBuilder &builder) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult col1 = builder.create(loc, block1, 1).getResult(); + OpFoldResult subviewCol1 = + minOpFoldResult(col1, subviewColFull, loc, builder); + OpFoldResult subviewCol2 = + subOpFoldResult(subviewColFull, subviewCol1, loc, builder); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto sbv1 = createSubview(block1, loc, builder, offsets, + {subviewRowFull, subviewCol1}, strides); + auto sbv2 = createSubview(block2, loc, builder, offsets, + {subviewRowFull, subviewCol2}, strides); + return {sbv1, sbv2}; +} + +std::pair +MaskState::getStackedSubviews(Value block1, Value block2, const Location &loc, + OpBuilder &builder) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult row1 = builder.create(loc, block1, 0).getResult(); + OpFoldResult subviewRow1 = + minOpFoldResult(row1, subviewRowFull, loc, builder); + OpFoldResult subviewRow2 = + subOpFoldResult(subviewRowFull, subviewRow1, loc, builder); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto sbv1 = createSubview(block1, loc, builder, offsets, + {subviewRow1, subviewColFull}, strides); + auto sbv2 = createSubview(block2, loc, builder, offsets, + {subviewRow2, subviewColFull}, strides); + return {sbv1, sbv2}; +} + +// addstatescalar +LogicalResult MaskState::addStateScalar(const MaskState &state, + const OpFoldResult scalar, + const Location &loc, + OpBuilder &builder) { + start = addOpFoldResult(state.start, scalar, loc, builder); + end = addOpFoldResult(state.end, scalar, loc, builder); + dims = state.dims; + offsets = state.offsets; + return success(); +} + +LogicalResult MaskState::addStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (lhsState.scalar && rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; + return failure(); + } + if (!lhsState.scalar && !rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) + << "Unsupported scenario where neither lhs nor rhs is a scalar"; + return failure(); + } + + if (lhsState.scalar) { + return addStateScalar(rhsState, lhsState.scalar, loc, builder); + } else { + return addStateScalar(lhsState, rhsState.scalar, loc, builder); + } +} + +LogicalResult MaskState::divStateScalar(const MaskState &state, + const OpFoldResult scalar, + const Location &loc, + OpBuilder &builder) { + start = divOpFoldResult(state.start, scalar, loc, builder); + end = divOpFoldResult(state.end, scalar, loc, builder); + dims = state.dims; + offsets = state.offsets; + return success(); +} + +LogicalResult MaskState::divStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (!lhsState.scalar && rhsState.scalar) { + if (isZeroIndex(rhsState.scalar)) { + InFlightDiagnostic diag = + 
emitError(loc) + << "Unsupported scenario where rhs is zero constant in divide!"; + return failure(); + } + + return divStateScalar(lhsState, rhsState.scalar, loc, builder); + } + + InFlightDiagnostic diag = emitError(loc) + << "Supported scenario where only rhs is a scalar"; + return failure(); +} + +LogicalResult MaskState::minStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (lhsState.getRank() != rhsState.getRank()) { + InFlightDiagnostic diag = + emitError(loc) + << "Unexpected case where lhs and rhs have different ranks"; + return failure(); + } + + for (uint32_t i = 0; i < lhsState.getRank(); i++) { + auto lhsOffset = lhsState.offsets[i]; + auto rhsOffset = rhsState.offsets[i]; + auto newOffset = maxOpFoldResult(lhsOffset, rhsOffset, loc, builder); + auto lhsDim = lhsState.dims[i]; + auto rhsDim = rhsState.dims[i]; + auto lhsEnd = addOpFoldResult(lhsOffset, lhsDim, loc, builder); + auto rhsEnd = addOpFoldResult(rhsOffset, rhsDim, loc, builder); + auto newEnd = minOpFoldResult(lhsEnd, rhsEnd, loc, builder); + auto newDim = subOpFoldResult(newEnd, newOffset, loc, builder); + + offsets.push_back(newOffset); + dims.push_back(newDim); + } + return success(); +} + +// Helper func for MaskState::parse() +LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (isa(constOp.getValue())) { + auto attr = cast(constOp.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType) && + "All elements must share a single integer constant value"); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = builder.getIndexAttr(value.getSExtValue()); + auto op = arith::ConstantOp::materialize(builder, constAttr, + builder.getIndexType(), loc); + this->scalar = op.getValue(); + } else { + auto value = cast(constOp.getValue()).getInt(); + this->scalar = builder.getIndexAttr(value); + } + return success(); +} + +// parseIntScalar +LogicalResult MaskState::parseIntScalar(Value scalar, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + Value castOp; + + if (scalar.getType().isInteger(1)) { + castOp = builder.create(loc, builder.getIndexType(), + scalar); + } else { + castOp = + builder.create(loc, builder.getIndexType(), scalar); + } + this->scalar = castOp; + return success(); +} + +LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) { + return failure(); + } + return this->addStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseDiv(arith::DivSIOp divOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(divOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(divOp.getRhs(), loc, builder))) { + return failure(); + } + return this->divStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || + !lhsState.isMask()) { + return failure(); + } + + MaskState 
rhsState; + if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || + !rhsState.isMask()) { + return failure(); + } + return this->minStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (cmpOp.getPredicate() != arith::CmpIPredicate::slt && + cmpOp.getPredicate() != arith::CmpIPredicate::sge && + cmpOp.getPredicate() != arith::CmpIPredicate::eq) { + LLVM_DEBUG({ llvm::dbgs() << "Unsupported cmpi predicate\n"; }); + return failure(); + } + MaskState lhsState; + if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) { + return failure(); + } + + if (!(!lhsState.scalar && rhsState.scalar)) { + cmpOp->emitRemark("[MaskState] Unsupported cmpi scenario"); + return failure(); + } + + int32_t cmpDim = -1; + for (int32_t i = 0; i < lhsState.getRank(); i++) { + auto dimIntAttr = makeIntAttr(lhsState.dims[i]); + if (!dimIntAttr || dimIntAttr.value() != 1) { + if (cmpDim != -1) { + InFlightDiagnostic diag = emitError(loc) + << "Unsupported cmpi with more than one " + "dimension with size larger than 1"; + return failure(); + } + cmpDim = i; + } + } + + assert(cmpDim != -1 && + "Unexpected case where no dimension has size larger than 1"); + + this->offsets = lhsState.offsets; + this->dims = lhsState.dims; + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::slt: { + auto realBound = + maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); + auto newEnd = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newDim = subOpFoldResult(newEnd, lhsState.start, loc, builder); + + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::sge: { + auto realBound = + maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); + auto newStart = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newOffset = subOpFoldResult(newStart, lhsState.start, loc, builder); + auto newDim = subOpFoldResult(lhsState.end, newStart, loc, builder); + + this->offsets[cmpDim] = newOffset; + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::eq: { + auto newOffset = + subOpFoldResult(rhsState.scalar, lhsState.start, loc, builder); + auto newDim = builder.getIndexAttr(1); + + this->offsets[cmpDim] = newOffset; + this->dims[cmpDim] = newDim; + break; + } + default: + return failure(); + } + return success(); +} + +LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto shape = cast(rangeOp.getType()).getShape(); + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + + if (stride != 1) { + InFlightDiagnostic diag = + emitError(loc) + << "stride must be 1 for make_range whose result is used " + "as load or store masks"; + return failure(); + } + + this->start = builder.getIndexAttr(start); + this->end = builder.getIndexAttr(end); + this->dims.push_back(builder.getIndexAttr(shape[0])); + this->offsets.push_back(builder.getIndexAttr(0)); + return success(); +} + +LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + 
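+  // Illustrative sketch (not taken from a test case): broadcasting a mask such
+  // as tensor<1x128xi1> -> tensor<64x128xi1> leaves the recorded offsets
+  // untouched; the loop below only widens dims[i] for dimensions whose source
+  // extent is 1.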
auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (failed(parse(src, loc, builder))) { + return failure(); + } + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) { + this->dims[i] = builder.getIndexAttr(dstShape[i]); + } else { + llvm_unreachable("unexpected dimensions used in broadcast"); + } + } + return success(); +} + +LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, + const Location &loc, OpBuilder &builder) { + assert(this->isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + if (!isa(src.getType())) { + InFlightDiagnostic diag = + emitError(loc) + << "splat source must be an integer scalar for load/store masks"; + return failure(); + } + + if (failed(this->parse(src, loc, builder))) + return failure(); + + auto splatAsMask = [&](Operation *userOp) -> bool { + return TypeSwitch(userOp) + .Case([&](arith::AndIOp andOp) { return true; }) + .Case([&](arith::SelectOp selectOp) { + return selectOp.getCondition() == dst; + }) + .Case( + [&](triton::LoadOp loadOp) { return loadOp.getMask() == dst; }) + .Case( + [&](triton::StoreOp storeOp) { return storeOp.getMask() == dst; }) + .Default([&](Operation *op) { return false; }); + }; + + if (src.getType().isInteger(1) && !splatOp->use_empty() && + llvm::all_of(splatOp->getUsers(), splatAsMask)) { + for (auto s : dstShape) { + auto currentDim = + mulOpFoldResult(builder.getIndexAttr(s), this->scalar, loc, builder); + this->dims.push_back(currentDim); + this->offsets.push_back(builder.getIndexAttr(0)); + } + + this->scalar = nullptr; + return success(); + } + + for (auto s : dstShape) { + this->dims.push_back(builder.getIndexAttr(s)); + this->offsets.push_back(builder.getIndexAttr(0)); + } + return success(); +} + +LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) { + return failure(); + } + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + assert(dstShape[axis] == 1 && + "Expect changed dimention to be 1 in expand_dims"); + this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); + this->offsets.insert(this->offsets.begin() + axis, builder.getIndexAttr(0)); + + return success(); +} + +void MaskState::eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter) { + auto moduleOp = rawOp->getParentOfType(); + SmallVector worklist; + moduleOp->walk([&](Operation *op) { + if (isOpTriviallyDead(op)) + worklist.push_back(op); + }); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + if (!isOpTriviallyDead(op)) + continue; + for (Value value : op->getOperands()) { + if (auto defOp = value.getDefiningOp()) + worklist.push_back(defOp); + } + LLVM_DEBUG({ + llvm::dbgs() << "[MaskState]==> inserted op: \n" + << *op << "\n[MaskState]<== is removed\n"; + }); + rewriter.eraseOp(op); + } +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp new file mode 100644 index 000000000..ae05f213d --- /dev/null +++ 
b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp @@ -0,0 +1,1149 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "TritonToLinalg/TritonOpConverter.h" +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/MaskAnalysis.h" +#include "Utils/Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/ValueRange.h" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +LogicalResult +AssertConverter::matchAndRewrite(triton::AssertOp op, + PatternRewriter &rewriter) const { + // TODO: update assert converter to support llvm20 + LLVM_DEBUG(llvm::dbgs() + << "we do not support assertion in kernel in llvm-20 yet \n"); + rewriter.eraseOp(op); + return success(); +} + +LogicalResult +BitcastConverter::matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto arithBitcast = rewriter.create( + op.getLoc(), op.getType(), op.getOperand()); + rewriter.replaceOp(op, arithBitcast.getResult()); + return success(); +} + +LogicalResult +TransposeConverter::matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto srcRank = cast(src.getType()).getRank(); + auto res = ConverterUtils::getTransposedValue(src, op.getLoc(), rewriter, + op.getOrder()); + rewriter.replaceOp(op, res); + return success(); +} + +LogicalResult +YieldConverter::matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); +} + +LogicalResult +LoopConverter::matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm::SmallDenseMap known; + BlockDataParser::IndexMapSet + levelToBlockArgIndex; // level -> set of block arg index to be replaced + + BlockDataParser::rewriteForOp(op, rewriter, levelToBlockArgIndex, 0, known); + return success(); +} + +LogicalResult +AdvanceConverter::matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm::SmallDenseMap known; + BlockDataParser::rewriteAdvanceOp(op, rewriter, known); + return success(); +} + +void MakeTensorPtrConverter::populateVectorAsIndex( + SmallVector &vec, Operation::operand_range ops, + ConversionPatternRewriter &rewriter, Location loc) const { + for (auto opnd : ops) { + if (isa(opnd.getType())) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), opnd); + vec.push_back(castOp.getResult()); + } else { + assert(isa(opnd.getType())); + vec.push_back(opnd); + } + } +} + +OpFoldResult MakeTensorPtrConverter::accumulatePotentialOffsetOnBase( + triton::MakeTensorPtrOp op, Value base, OpFoldResult offset, + ConversionPatternRewriter &rewriter) const { + if (auto baseRecast = base.getDefiningOp()) 
{ + assert(isa(op.getBase().getDefiningOp()) && + "base of MakeTensorPtrOp only comes from native ptr or AddPtrOp"); + + return addOpFoldResult(offset, baseRecast.getConstifiedMixedOffset(), + op.getLoc(), rewriter); + } + + return offset; +} + +// Design for load/store boundary_check. +memref::ReinterpretCastOp +MakeTensorPtrConverter::createRedundantOp(triton::MakeTensorPtrOp op, + ConversionPatternRewriter &rewriter, + BlockData &data) const { + auto loc = op.getLoc(); + // to do boundary_check in tt.load, we need to keep the parent tensor's + // shape info in the IR. + // use the parent tensor's shape to create a cast + auto resultSizes = data.getSizes(); + data.getSizesRef().clear(); + populateVectorAsIndex(data.getSizesRef(), op.getShape(), rewriter, loc); + SmallVector staticShapes; + SmallVector dynamicShapes; + dispatchIndexOpFoldResults(data.getSizesRef(), dynamicShapes, staticShapes); + auto castOp = data.createCastOp(staticShapes, loc, rewriter); + // restore sizes + data.getSizesRef().clear(); + for (auto &s : resultSizes) { + data.getSizesRef().push_back(s); + } + return castOp; +} + +// ToDo: +// 1. Refactor MakeTensorPtrConverter and AdvanceConverter with +// memref::ReinterpretCastOp and memref::SubViewOp. +// Use recast to describe full shape of tensor, and use subview to represent +// current block tensor. +// 2. Support boundary_check & padding_option for load/store, while current +// method with redundant recast is just enabled in load and drops padding_option +LogicalResult MakeTensorPtrConverter::matchAndRewrite( + triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + BlockData data; + + auto orderSize = op.getOrder().size(); + if (orderSize > 1) { + // Declaration of llvm::ArrayRef::slice(n, m) + // - Chop off the first N elements of the array, and keep M elements + // in the array. + // Take care that 'm' means chunk length + for (auto [first, second] : + llvm::zip(op.getOrder().slice(0, orderSize - 1), + op.getOrder().slice(1, orderSize - 1))) { + if (first != second + 1) { + op->emitError("Currently only support default order on block pointers"); + return failure(); + } + } + } + + // Handle base is defined by tt.bitcast + llvm::SmallDenseMap known; + BlockDataParser::parse(op.getBase(), data, loc, rewriter, known); + if (data.hasResElemTy()) { + auto memrefType = dyn_cast(data.getSourceRef().getType()) + .cloneWith(std::nullopt, data.getResElemTyRef()); + UnrealizedConversionCastOp castOp = + rewriter.create(loc, memrefType, + data.getSourceRef()); + data.setSource(castOp.getOutputs()[0]); + } else { + data.setSource(rewriter.getRemappedValue(op.getBase())); + } + + populateVectorAsIndex(data.getOffsetsRef(), op.getOffsets(), rewriter, loc); + populateVectorAsIndex(data.getStridesRef(), op.getStrides(), rewriter, loc); + + SmallVector newOffsets; + for (auto [offset, stride] : + llvm::zip(data.getOffsetsRef(), data.getStridesRef())) + newOffsets.push_back(mulOpFoldResult(offset, stride, loc, rewriter)); + + // 1. Consider that current base ptr may comes from `triton::AddPtrOp`, + // which have been converted to `memref::ReinterpretCastOp` with 1D + // shape([1,]) by `AddPtrConverter`. + // 2. While here would also convert `triton::MakeTensorPtrOp` to + // `memref::ReinterpretCastOp`, it will create use-def on double recast + // which means offset&size&stride info of first one will be dropped in terms + // of memref recast op specification. 
+ // + // Conclusion with above two: + // Base of MakeTensorPtrOp has been seen as origin base, so it should + // reserve offset of first recast if it exists. + // Here extract the offset of first recastr and add it to highest dimension + newOffsets.front() = accumulatePotentialOffsetOnBase( + op, adaptor.getBase(), newOffsets.front(), rewriter); + + data.getOffsetsRef().clear(); + + for (auto offset : newOffsets) { + data.getOffsetsRef().push_back(offset); + } + + ArrayRef resultShape; + auto pointerType = cast(op.getResult().getType()); + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + for (auto dim_size : resultShape) { + data.getSizesRef().push_back( + IntegerAttr::get(IntegerType::get(op.getContext(), 64), dim_size)); + } + } else { + // scalar pointer, should produce a one dimensional memref + SmallVector scalarShape(1, 1); + resultShape = scalarShape; + assert(data.getRank() == 1); + } + + // special handling for davinci + // create redundant reinterpret_cast op for record shape info + auto redundantOp = createRedundantOp(op, rewriter, data); + redundantOp->setAttr("tensor_ptr_attr", rewriter.getStringAttr("shape")); + + // create reinterpret_cast op for the target block + data.setSource(redundantOp.getResult()); + auto castOp = data.createCastOp(resultShape, loc, rewriter); + rewriter.replaceOp(op, castOp.getResult()); + return success(); +} + +LogicalResult PreciseDivConverter::matchAndRewrite( + triton::PreciseDivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value opa = op.getX(); + Value opb = op.getY(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + auto divOp = rewriter.create(loc, resType, opa, opb); + + rewriter.replaceOp(op, divOp); + return success(); +} + +/* + * Rewrite arith.select with contiguouse mask to + * tensor.extract_slice/insert_slice. + */ + +LogicalResult +SelectConverter::matchAndRewrite(arith::SelectOp op, + PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + // 0. Shortcut for scalars + auto type = dyn_cast(op.getResult().getType()); + if (!type) { + // do nothing non-tensor select + return failure(); + } + auto mask = op.getCondition(); + if (!isa(mask.getType())) { + // do nothing for scalar mask + return failure(); + } + + // 1. Check for continuous masked loads. + // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) { + mstate.eraseInsertedOps(op, rewriter); + return rewriter.notifyMatchFailure( + op, "Cannot lower continuous masked selects"); + } + + // 2. Slice out the masked part of true tensor + auto trueTensor = op.getTrueValue(); + auto trueSlice = mstate.getExtractSlice(trueTensor, loc, rewriter); + + // 3. Insert out the sliced true tensor into false tensor + auto falseTensor = op.getFalseValue(); + auto result = mstate.getInsertSlice(trueSlice, falseTensor, loc, rewriter); + + rewriter.replaceOp(op, result); + return success(); +} + +/* + * Move tt.bitcast to a previous location if tt.bitcast is not directly applied + * on function arguments + */ +LogicalResult +BitcastCanonicalizer::matchAndRewrite(triton::BitcastOp bitcastOp, + PatternRewriter &rewriter) const { + Value castSrc = bitcastOp.getSrc(); + Value castRes = bitcastOp.getResult(); + Type castSrcTy = castSrc.getType(); + Type castSrcPtrTy = isa(castSrcTy) + ? 
cast(castSrcTy).getElementType() + : castSrcTy; + if (!isa(castSrcPtrTy)) + return failure(); + + auto origBitwidth = getPointeeBitWidth(castSrc.getType()); + auto castBitwidth = getPointeeBitWidth(castRes.getType()); + + if (origBitwidth == 1) + origBitwidth = 8; + if (castBitwidth == 1) + castBitwidth = 8; + if (origBitwidth != castBitwidth) { + bitcastOp.emitError() << "Casting pointers with unmatched bitwidth!\n"; + return failure(); + } + + Operation *beforeCastOp = castSrc.getDefiningOp(); + if (beforeCastOp == nullptr) { + return failure(); + } + + auto newRes = + TypeSwitch>(beforeCastOp) + // before: addptr - bitcast - load/store + // after: bitcast - addptr - load/store + .Case([&](triton::AddPtrOp addptrOp) { + auto newCastOp = rewriter.create( + bitcastOp.getLoc(), castRes.getType(), addptrOp.getPtr()); + return rewriter.create( + bitcastOp.getLoc(), castRes.getType(), newCastOp.getResult(), + addptrOp.getOffset()); + }) + .Case([&](triton::SplatOp splatOp) { + Type newCastSrcTy = + cast(castRes.getType()).getElementType(); + + Value splatSrc = splatOp.getSrc(); + Type splatSrcTy = splatSrc.getType(); + if (auto splatSrcTensorTy = dyn_cast(splatSrcTy)) + newCastSrcTy = + splatSrcTensorTy.cloneWith(std::nullopt, newCastSrcTy); + auto newCastOp = rewriter.create( + bitcastOp.getLoc(), newCastSrcTy, splatSrc); + return rewriter.create( + bitcastOp.getLoc(), castRes.getType(), newCastOp); + }) + // before: bitcast - bitcast + // after(fusion optimization): bitcast + .Case([&](triton::BitcastOp prevCastOp) { + return rewriter.create( + bitcastOp.getLoc(), castRes.getType(), prevCastOp.getSrc()); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(bitcastOp, + "Unknown bitcast pattern"); + }); + if (succeeded(newRes)) { + rewriter.replaceOp(bitcastOp, newRes.value()); + if (beforeCastOp->use_empty()) { + rewriter.eraseOp(beforeCastOp); + } + return success(); + } + return failure(); +} + +LogicalResult DenseConstantConverter::matchAndRewrite( + arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto denseAttr = cast(op.getValue()); + auto loc = op.getLoc(); + auto constSplatOp = arith::ConstantOp::materialize( + rewriter, denseAttr.getSplatValue(), + denseAttr.getElementType(), loc); + auto emptyOp = rewriter.create( + loc, cast(op.getResult().getType()).getShape(), + denseAttr.getElementType()); + + rewriter.replaceOpWithNewOp(op, ValueRange{constSplatOp}, + ValueRange{emptyOp}); + + return success(); +} + +LogicalResult +MakeRangeConverter::matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto type = cast(op.getResult().getType()); + auto shape = type.getShape(); + auto elementType = type.getElementType(); + auto context = op.getContext(); + + assert(type.getShape().size() == 1 && + isa(type.getElementType()) && + type.getElementType().getIntOrFloatBitWidth() == 32 && + "make range can only return 1D int32 tensor"); + + SmallVector indexingMaps{AffineMap::get( + /* dimCount */ 1, /* symbolCount */ 0, + {mlir::getAffineDimExpr(0, context)}, context)}; + + auto init = rewriter.create(loc, shape, elementType); + + auto nestedBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value index = nestedBuilder.create(loc, 0); + Value res = nestedBuilder.create( + loc, type.getElementType(), index); + nestedBuilder.create(loc, res); + }; + + auto linalgOp = rewriter.create( + loc, op->getResultTypes(), /* operands */ 
ValueRange{}, ValueRange{init}, + indexingMaps, ConverterUtils::getNParallelLoopsAttrs(1), nestedBody); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); +} + +LogicalResult +SplatConverter::matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto init = rewriter.create(loc, op.getType().getShape(), + op.getType().getElementType()); + rewriter.replaceOpWithNewOp(op, ValueRange{adaptor.getSrc()}, + ValueRange{init}); + return success(); +} + +LogicalResult +ReshapeConverter::matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto src = op.getSrc(); + auto dst = op.getResult(); + Value shape = rewriter.create( + loc, + rewriter.getI64TensorAttr(cast(dst.getType()).getShape())); + auto reshapeOp = + rewriter.create(loc, dst.getType(), src, shape); + rewriter.replaceOp(op, reshapeOp.getResult()); + return success(); +} + +LogicalResult ExpandDimsConverter::matchAndRewrite( + triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto src = op.getSrc(); + auto resShape = cast(op.getResult().getType()).getShape(); + auto axis = op.getAxis(); + + SmallVector reassociation; + + auto src_last_dim = resShape.size() - 2; + auto map_func = [&](unsigned i) -> ReassociationIndices { + if (i < axis) { + return i == src_last_dim ? ReassociationIndices{i, i + 1} + : ReassociationIndices{i}; + } + return i == axis ? ReassociationIndices{i, i + 1} + : ReassociationIndices{i + 1}; + }; + + reassociation = llvm::to_vector( + llvm::map_range(llvm::seq(0, src_last_dim + 1), map_func)); + + auto expandShapeOp = rewriter.create( + op.getLoc(), op.getResult().getType(), src, reassociation); + rewriter.replaceOp(op, expandShapeOp.getResult()); + return success(); +} + +LogicalResult +ClampFConverter::matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto input = adaptor.getX(); + auto min_para = adaptor.getMin(); + auto max_para = adaptor.getMax(); + auto propagateNan_para = adaptor.getPropagateNan(); + + if (auto input_type = dyn_cast(input.getType())) { + if (isa(min_para.getType())) { + auto minEmptyTensor = rewriter.create( + loc, input_type.getShape(), input_type.getElementType()); + min_para = rewriter + .create(loc, ValueRange{min_para}, + ValueRange{minEmptyTensor}) + .result(); + } + if (isa(max_para.getType())) { + auto maxEmptyTensor = rewriter.create( + loc, input_type.getShape(), input_type.getElementType()); + max_para = rewriter + .create(loc, ValueRange{max_para}, + ValueRange{maxEmptyTensor}) + .result(); + } + } + + if (propagateNan_para == PropagateNan::NONE) { + auto minOp = rewriter.create(loc, input, max_para); + auto maxOp = rewriter.create(loc, min_para, minOp); + rewriter.replaceOp(op, ValueRange{maxOp}); + } else if (propagateNan_para == PropagateNan::ALL) { + auto minOp = rewriter.create(loc, input, max_para); + auto maxOp = rewriter.create(loc, min_para, minOp); + rewriter.replaceOp(op, ValueRange{maxOp}); + } else { + return failure(); + } + + return success(); +} + +// Here convert tt.broadcast to linalg.broadcast +// +// before +// %out = tt.broadcast %in : tensor<1x4x8xf32> -> tensor<128x4x8xf32> +// +// after +// %collpased = tensor.collapse_shape %in [[0, 1], [2]] : +// tensor<1x4x8xf32> into tensor<4x8xf32> +// %out = linalg.broadcast ins(%collpased : 
tensor<4x8xf32>) +// outs(%empty : tensor<128x4x8xf32>) dimensions = [0] +LogicalResult +BroadcastConverter::matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(op->getNumResults() == 1 && "BroadcastOp assumes single result"); + + RankedTensorType sourceType = + cast(adaptor.getSrc().getType()); + RankedTensorType resultType = cast(op.getType()); + auto elementType = resultType.getElementType(); + size_t resultRank = resultType.getRank(); + auto loc = op.getLoc(); + + auto initEmpty = + rewriter.create(loc, resultType.getShape(), elementType); + + SmallVector broadcastDims = + ConverterUtils::getBroadcastDims(sourceType, resultType); + SmallVector unbroadcastDims = + ConverterUtils::getUnbroadcastDims(sourceType, resultType); + + SmallVector collapseReassociationIndices; + auto collapseReassociationIndicesOptional = + getReassociationIndicesForCollapse(sourceType.getShape(), + unbroadcastDims); + if (!collapseReassociationIndicesOptional.has_value()) { + return rewriter.notifyMatchFailure( + op, "Failure with getReassociationIndicesForCollapse call"); + } + collapseReassociationIndices = collapseReassociationIndicesOptional.value(); + + RankedTensorType collapseResultType = + RankedTensorType::get(unbroadcastDims, sourceType.getElementType()); + + auto collpasedOp = rewriter.create( + loc, collapseResultType, adaptor.getSrc(), collapseReassociationIndices); + + auto broadcastOp = rewriter.create( + loc, collpasedOp, initEmpty, + rewriter.getDenseI64ArrayAttr(broadcastDims)); + + rewriter.replaceOp(op, broadcastOp.getResults()); + return success(); +} + +// Reduce Converter +llvm::SmallVector +ReduceConverter::getRedOps(triton::ReduceOp redOp) const { + auto reduceBlock = redOp.getBody(); + return llvm::map_to_vector(reduceBlock->without_terminator(), + [](Operation &op) { return &op; }); +} + +bool ReduceConverter::isReductionOpSupported(Operation *redOp) const { + return isa(redOp); +} + +arith::ConstantOp +ReduceConverter::getRedBaseConstOp(ConversionPatternRewriter &rewriter, + Operation *redOp, Type constantType) const { + const int64_t bitWidth = constantType.getIntOrFloatBitWidth(); + + auto attr = llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + return rewriter.getFloatAttr(constantType, 0.f); + }) + .Case([&](arith::AddIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](auto) { + return rewriter.getFloatAttr( + constantType, -std::numeric_limits::infinity()); + }) + .Case([&](auto) { + return rewriter.getFloatAttr( + constantType, std::numeric_limits::infinity()); + }) + .Case([&](arith::MinSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxIntN(bitWidth)); + }) + .Case([&](arith::MinUIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxUIntN(bitWidth)); + }) + .Case([&](arith::MaxSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::minIntN(bitWidth)); + }) + .Case([&](arith::MaxUIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::OrIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::AndIOp) { + return rewriter.getIntegerAttr(constantType, 1); + }) + .Case([&](arith::XOrIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not supported yet"); + return nullptr; + }); + + return rewriter.create(redOp->getLoc(), constantType, + attr); +} + +bool ReduceConverter::requiresF32Conversion(const Type elemType, + 
Operation *redOp) const { + return isa(elemType) && + elemType.getIntOrFloatBitWidth() < + Float32Type::get(elemType.getContext()).getWidth() && + isa(redOp); +} + +Value ReduceConverter::getRedElement( + Value lhs, Value rhs, const Location loc, Operation *redOp, OpBuilder &b, + const bool convertLhsToF32Precision) const { + return llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + if (convertLhsToF32Precision) { + lhs = b.create(loc, Float32Type::get(b.getContext()), + lhs); + } + return b.create(loc, lhs, rhs); + }) + .Case( + [&](auto redOp) { return b.create(loc, lhs, rhs); }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); +} + +LogicalResult ReduceConverter::convertToLinalgReduce( + triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto source = adaptor.getOperands().front(); + auto sourceType = cast(source.getType()); + auto elemType = sourceType.getElementType(); + auto resType = op.getResult().front().getType(); + auto loc = op.getLoc(); + auto reductionOps = getRedOps(op); + + // Reduction of arbitrary operations isn't supported because using the first + // element across the reduction dimension requires us to iterate over a + // subview that skips over each first element. + if (reductionOps.size() != 1 || + !isReductionOpSupported(reductionOps.front())) { + return rewriter.notifyMatchFailure( + op, "Only support lowering reduction with body " + "containing 1 max(i/f) or addf."); + } + + auto rop = reductionOps.front(); + auto axis = op.getAxis(); + auto isVectorReduce = sourceType.getRank() == 1; + + auto constantType = elemType; + + auto accBaseConstOp = getRedBaseConstOp(rewriter, rop, constantType); + Value initTensor; + + if (isVectorReduce) { + auto holder = rewriter.create( + loc, RankedTensorType::get({}, constantType), ValueRange{}); + initTensor = rewriter + .create(loc, accBaseConstOp.getResult(), + holder.getResult()) + .getResult(0); + } else { + Value init = rewriter.create( + loc, cast(resType).getShape(), constantType); + initTensor = + rewriter.create(loc, accBaseConstOp.getResult(), init) + .getResult(0); + } + + Value finalResult = + rewriter + .create( + loc, ValueRange{source}, ValueRange{initTensor}, + SmallVector{axis}, + [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { + assert(inputs.size() == 2); + Value result = getRedElement(inputs[0], inputs[1], loc, rop, + opBuilder, false); + opBuilder.create(loc, result); + }) + .getResult(0); + + if (sourceType.getRank() == 1) { + finalResult = + rewriter.create(loc, constantType, finalResult); + } + + rewriter.replaceOp(op, finalResult); + return success(); +} + +LogicalResult ReduceConverter::convertToLinalgReduceExtended( + ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto elemTypes = op.getElementTypes(); + + auto valueResultType = dyn_cast(op.getType(0)); + const auto isScalarReduce = valueResultType == nullptr; + + SmallVector outputs; + for (auto i = 0; i < op.getResult().size() && i < elemTypes.size(); i++) { + auto result = dyn_cast(op.getType(i)); + SmallVector resultShape{ + isScalarReduce ? 
SmallVector{} + : SmallVector(result.getShape())}; + outputs.push_back( + rewriter.create(loc, resultShape, elemTypes[i])); + } + + auto linalgOp = rewriter.create( + loc, adaptor.getOperands(), outputs, + SmallVector{adaptor.getAxis()}, + [&](OpBuilder &b, Location loc, ValueRange inputs) { + auto tritonReduceBlock = op.getBody(); + IRMapping mapping; + mapping.map(tritonReduceBlock->getArguments(), inputs); + + for (auto &op : tritonReduceBlock->without_terminator()) { + b.clone(op, mapping); + } + + auto tritonYield = tritonReduceBlock->getTerminator(); + auto results = + llvm::map_to_vector(tritonYield->getOperands(), + [&](Value val) { return mapping.lookup(val); }); + b.create(loc, results); + }); + + if (failed(addReduceWithIndexAttrIfNeeded(rewriter, linalgOp))) { + return rewriter.notifyMatchFailure(op, "meaningless reduce operation"); + } + + if (isScalarReduce) { + SmallVector reduceResults; + for (auto i = 0; i < linalgOp.getResults().size() && i < elemTypes.size(); + i++) { + reduceResults.push_back(rewriter.create( + loc, elemTypes[i], linalgOp.getResults()[i], ValueRange{})); + } + rewriter.replaceOp(op, reduceResults); + } else { + rewriter.replaceOp(op, linalgOp); + } + return success(); +} + +LogicalResult +ReduceConverter::matchAndRewrite(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto sourceType = + cast(adaptor.getOperands().front().getType()); + assert(sourceType.hasRank() && "Expected input is " + "ranked"); + + int64_t axis = op.getAxis(); + assert(axis >= 0 && axis < sourceType.getRank() && + "Expected reduction " + "axis is within " + "operand's rank"); + + auto reductionOps = getRedOps(op); + if (reductionOps.size() == 1) { + return convertToLinalgReduce(op, adaptor, rewriter); + } + return convertToLinalgReduceExtended(op, adaptor, rewriter); +} + +LogicalResult ExternElementwiseClOpConverter::matchAndRewrite( + triton::ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + if (!op.getPure()) { + op->emitWarning() << "impure elementwise op!"; + return failure(); + } + if (op.getSymbol().contains("__hmf_")) { + // 1. get or create the declaration of external elementwise function + Type dstTy = op.getResult().getType(); + bool isDstScalar = !isa(dstTy); + Type dstElemTy = + isDstScalar ? dstTy : cast(dstTy).getElementType(); + SmallVector srcElemTys; + SmallVector srcs; + for (auto src : op.getSrcs()) { + if (!isa(src.getType())) { + src = rewriter.create( + op.getLoc(), RankedTensorType::get({(int64_t)1}, src.getType()), + src); + } + srcs.push_back(src); + srcElemTys.push_back( + cast(src.getType()).getElementType()); + } + FunctionType elemFuncType = + FunctionType::get(rewriter.getContext(), srcElemTys, {dstElemTy}); + auto mod = SymbolTable::getNearestSymbolTable(op); + auto extFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(mod, op.getSymbol())); + if (!extFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&mod->getRegion(0).front()); + extFunc = rewriter.create(rewriter.getUnknownLoc(), + op.getSymbol(), elemFuncType); + extFunc.setPrivate(); + extFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), + UnitAttr::get(rewriter.getContext())); + } + assert(isa( + SymbolTable::lookupSymbolIn(mod, op.getSymbol()))); + // 2. 
prepare the output tensor + Value output; + if (isDstScalar) { + dstTy = RankedTensorType::get({(int64_t)1}, dstElemTy); + } + bool found = false; + for (Value v : srcs) { + if (v.getType() == dstTy) { + found = true; + output = v; + break; + } + } + if (!found) { + output = rewriter.create( + op.getLoc(), cast(dstTy).getShape(), dstElemTy); + } + // 3. create the linalg.map op + auto mapOp = rewriter.create( + loc, + /*inputs=*/srcs, + /*init=*/output, + /*bodyBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { + auto elemOp = builder.create(loc, + /*name=*/op.getSymbol(), + /*resultType=*/dstElemTy, + /*operands=*/regionArgs); + builder.create(loc, elemOp->getResults()); + }); + if (isDstScalar) { + // need to convert tensor back to scalar + auto indexType = rewriter.getIndexType(); + Value zeroConstant = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, 0)); + auto extractOp = rewriter.create( + loc, mapOp.getResults()[0], zeroConstant); + rewriter.replaceOp(op, extractOp); + } else { + rewriter.replaceOp(op, mapOp); + } + return success(); + } + return failure(); +} + +LogicalResult UnrealizedCastConverter::matchAndRewrite( + UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.eraseOp(op); + return success(); +} + +LogicalResult +JoinConverter::matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value opa = op.getLhs(); + Value opb = op.getRhs(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + Value emptyOp = rewriter.create(loc, resType.getShape(), + resType.getElementType()); + + auto shape = dyn_cast(opa.getType()).getShape(); + auto sizes = llvm::map_to_vector(shape, [&](int64_t t) { + return OpFoldResult(rewriter.getI64IntegerAttr(t)); + }); + sizes.push_back(rewriter.getI64IntegerAttr(1)); + + int64_t rank = resType.getRank(); + + // Set last dimension stride to 2 in layout + // As last dimension size is always 1, last dimension stride here could be + // either 1 or 2, while stride `2` could carry interleave trait and it's + // convenient for next lower. + SmallVector strides(rank, rewriter.getIndexAttr(1)); + strides.back() = rewriter.getIndexAttr(2); + + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + + auto insert0 = rewriter.create( + loc, opa, emptyOp, offsets, sizes, strides); + + offsets.back() = rewriter.getIndexAttr(1); + auto insert1 = rewriter.create( + loc, opb, insert0, offsets, sizes, strides); + rewriter.replaceOp(op, insert1); + return success(); +} + +LogicalResult +CatConverter::matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value opa = op.getLhs(); + Value opb = op.getRhs(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + auto emptyOp = rewriter.create(loc, resType.getShape(), + resType.getElementType()); + + auto rank = resType.getRank(); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + + auto inputType = dyn_cast(opa.getType()); + + SmallVector sizes = + llvm::map_to_vector(inputType.getShape(), [&](int64_t t) { + return OpFoldResult(rewriter.getI64IntegerAttr(t)); + }); + + auto insert0 = rewriter.create( + loc, opa, emptyOp, offsets, sizes, strides); + + offsets[0] = + rewriter.getIndexAttr(inputType.getRank() ? 
inputType.getShape()[0] : 1); + auto insert1 = rewriter.create( + loc, opb, insert0, offsets, sizes, strides); + + rewriter.replaceOp(op, insert1); + return success(); +} + +/// @brief Convert tt.gather to func.call. BiShengIR captures the func +/// with assumed semantics. +/// @param op The `triton::GatherOp` operation to be rewritten. +/// @param adaptor An adaptor for the operation's operands. +/// @param rewriter A pattern rewriter used to modify the IR. +/// @return A `LogicalResult` indicating whether the rewrite was successful. +LogicalResult +GatherConverter::matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value src = adaptor.getSrc(); + Value idx = adaptor.getIndices(); + Value res = op.getResult(); + auto gatherAxis = op.getAxis(); + + auto moduleOp = op->getParentOfType(); + rewriter.setInsertionPoint(moduleOp.getBody(), + std::prev(moduleOp.getBody()->end())); + + llvm::SmallString<128> funcName = gatherFuncNameBase; + int uniqueId = 0; + while (SymbolTable::lookupSymbolIn(moduleOp, funcName)) { + funcName += "_" + std::to_string(uniqueId++); + } + + auto resTy = res.getType(); + auto libFnType = rewriter.getFunctionType( + {src.getType(), idx.getType(), rewriter.getI32Type()}, {resTy}); + auto funcOp = rewriter.create(loc, funcName.str(), libFnType); + SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); + + rewriter.setInsertionPoint(op); + Value axis = rewriter.create(loc, gatherAxis, 32); + auto callOp = rewriter.create(loc, funcOp.getSymNameAttr(), + TypeRange({resTy}), + ValueRange({src, idx, axis})); + + rewriter.replaceOp(op, callOp); + + return success(); +} + +LogicalResult +SplitConverter::matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = op.getSrc(); + auto loc = op.getLoc(); + auto inputType = cast(input.getType()); + + int64_t rank = inputType.getRank(); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + // Similar to JoinConverter, here adjust last dimension stride + SmallVector strides(rank, rewriter.getIndexAttr(1)); + strides.back() = rewriter.getIndexAttr(2); + + auto outType = dyn_cast(op.getOutLHS().getType()); + auto sizes = llvm::map_to_vector(outType.getShape(), [&](int64_t t) { + return OpFoldResult(rewriter.getIndexAttr(t)); + }); + sizes.push_back(rewriter.getIndexAttr(1)); + + auto slice0 = rewriter.create( + loc, outType, input, offsets, sizes, strides); + + offsets.back() = rewriter.getIndexAttr(1); + auto slice1 = rewriter.create( + loc, outType, input, offsets, sizes, strides); + + SmallVector slices = {slice0.getResult(), slice1.getResult()}; + rewriter.replaceOp(op, ValueRange(slices)); + return success(); +} + +/* +the element-wise most significant N bits of the 2N-bit product of x and y +%x:2 = arith.mulsi_extended %y, %z : tensor<4x?xi32> +*/ +LogicalResult TritonMulhiuiConverter::matchAndRewrite( + triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value opl = op.getX(); + Value opr = op.getY(); + Value res = op.getResult(); + auto newMulOp = rewriter.create( + loc, res.getType(), res.getType(), opl, opr); + // triton only need the high value + rewriter.replaceOp(op, ValueRange{newMulOp.getHigh()}); + return success(); +} + +LogicalResult TritonPreciseSqrtConverter::matchAndRewrite( + triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + 
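+  // Straight 1:1 rewrite: tt.precise_sqrt is replaced with a sqrt over the
+  // already-converted operands; no additional precision handling happens here.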
rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); +} + +LogicalResult DevicePrintConverter::matchAndRewrite( + triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto moduleOp = op->getParentOfType(); + rewriter.setInsertionPoint(moduleOp.getBody(), + std::prev(moduleOp.getBody()->end())); + SmallVector inputTypes; + for (auto arg : op.getArgs()) { + inputTypes.push_back(arg.getType()); + } + auto libFnType = rewriter.getFunctionType(inputTypes, {}); + auto funcOp = + rewriter.create(op.getLoc(), printFuncNameBase, libFnType); + SymbolTable symTab(moduleOp); + auto maybePrintFuncNameAttr = symTab.renameToUnique(funcOp, {&symTab}); + if (failed(maybePrintFuncNameAttr)) { + return op->emitError( + "failed to create a unique func name for device_print"); + } + SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); + auto prefixAttr = op.getPrefixAttr(); + funcOp->setAttr(prefixAttrName, prefixAttr); + auto hexAttr = op.getHexAttr(); + funcOp->setAttr(hexAttrName, hexAttr); + + rewriter.setInsertionPoint(op); + rewriter.create(op.getLoc(), funcOp, op.getArgs()); + + rewriter.eraseOp(op); + return success(); +} + +LogicalResult +MatmulConverter::matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto opa = adaptor.getA(); + auto opb = adaptor.getB(); + auto opc = adaptor.getC(); + auto dstType = cast(op.getType()); + auto inputPrec = op.getInputPrecision(); + + if (dstType.getRank() == 2) { + auto matmulOp = rewriter.replaceOpWithNewOp( + op, ValueRange{opa, opb}, ValueRange{opc}); + matmulOp->setAttr( + "input_precison", + rewriter.getStringAttr(stringifyInputPrecision(inputPrec))); + } else if (dstType.getRank() == 3) { + auto matmulOp = rewriter.replaceOpWithNewOp( + op, ValueRange{opa, opb}, ValueRange{opc}); + matmulOp->setAttr( + "input_precison", + rewriter.getStringAttr(stringifyInputPrecision(inputPrec))); + } else { + llvm_unreachable("Datatype of DotOp operands could only be 2D or 3D"); + } + return success(); +} +} // namespace TTOpConverters diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp new file mode 100644 index 000000000..9f7959074 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp @@ -0,0 +1,544 @@ +#include "TritonToLinalg/TritonToLinalgPass.h" +#include "TritonToLinalg/ArgMinMaxConverter.h" +#include "TritonToLinalg/FunctionConverter.h" +#include "TritonToLinalg/LoadStoreConverter.h" +#include "TritonToLinalg/TritonOpConverter.h" +#include "TritonToLinalg/UseAnalysis.h" +#include "Utils/InterleaveOptimization.h" +#include "Utils/Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" + +#include +#include + +#define 
DEBUG_TYPE "triton-to-linalg" + +using namespace mlir; +using namespace triton; + +TritonTypeConverter::TritonTypeConverter() { + addConversion([](Type type) { return type; }); + + addConversion([](triton::PointerType ptrType) { + return MemRefType::get({ShapedType::kDynamic}, ptrType.getPointeeType()); + }); + + addConversion([](TensorType tensorType) -> Type { + auto elemType = tensorType.getElementType(); + if (auto ptrType = dyn_cast(elemType)) { + elemType = ptrType.getPointeeType(); + } + return MemRefType::get(tensorType.getShape(), elemType); + }); +} + +void TritonToLinalgPass::addProgramInfo(triton::FuncOp func, + bool globalKernel) { + OpBuilder b(func); + + auto origFuncType = func.getFunctionType(); + auto origInputTypes = origFuncType.getInputs(); + SmallVector newInputTypes(origInputTypes); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); + + auto newFuncType = + b.getFunctionType(newInputTypes, origFuncType.getResults()); + + func.setFunctionType(newFuncType); + + // 如果需要,给参数新增属性 + if (func.getAllArgAttrs()) { + SmallVector newArgAttrs; + func.getAllArgAttrs(newArgAttrs); + newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); + func.setAllArgAttrs(newArgAttrs); + } + + // 添加对应参数到函数体中 + for (unsigned i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { + func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); + } + + if (globalKernel) { + func->setAttr(globalKernelAttr, b.getStringAttr("")); + } else { + func->setAttr(globalKernelAttr, b.getStringAttr("local")); + } +} + +void TritonToLinalgPass::convertTTFunc(triton::FuncOp func, + const bool existDot) { + OpBuilder builder(func); + + auto name = func.getName(); + auto type = func.getFunctionType(); + + SmallVector argAttrs, resAttrs; + func.getAllArgAttrs(argAttrs); + func.getAllResultAttrs(resAttrs); + + // bit-casted tt.ptr的特殊处理 + SmallVector inputTypes{type.getInputs()}; + SmallVector retTypes{type.getResults()}; + if (func.getSymVisibility() == "public" && !func.isDeclaration()) { + for (size_t i = 0; i < func.getNumArguments(); ++i) { + auto arg = func.getArgument(i); + if (!isa(arg.getType())) { + continue; + } + // FIXME: Why arg.getUsers() cannot return the user inside scf.for? + llvm::SmallVector arg_users; + func.walk([&](Operation *op) { + if (op->use_empty()) { + return WalkResult::advance(); + } + for (auto operand : op->getOperands()) { + if (operand == arg) { + arg_users.push_back(op); + } + } + return WalkResult::advance(); + }); + + bool arg_use_empty = arg_users.size() == 0; + if (!arg_use_empty) { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << arg << " has users:\n"; + int cnt = 0; + for (auto it : arg_users) { + os << "users[" << cnt++ << "] = " << *it; + } + }); + if (llvm::all_of(arg_users, [](Operation *userOp) { + return isa(userOp); + })) { + auto castOp = cast(*arg_users.begin()); + if (castOp.getInputs().size() == 1 && + castOp.getOutputs().size() == 1) { + arg.setType(castOp.getOutputs()[0].getType()); + inputTypes[i] = arg.getType(); + } + } + } else { + // Process unused bool ptr type specially, which guarantees bool pointer + // argument's type is realistic and don't mislead backend compiler. 
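+        // For example, an unused !tt.ptr<i1> argument that was converted to
+        // memref<?xi1> is retyped here to memref<?xi8>, since the backend
+        // stores bool values in 8 bits.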
+ BaseMemRefType argType = dyn_cast(arg.getType()); + if (argType.getElementTypeBitWidth() == 1) { + // realistic memory layout of bool pointer is 8 bit width + auto memType = argType.cloneWith(std::nullopt, builder.getI8Type()); + arg.setType(memType); + inputTypes[i] = arg.getType(); + } + } + } + } + auto castType = FunctionType::get(func.getContext(), inputTypes, retTypes); + + auto funcFunc = builder.create(func.getLoc(), name, castType); + funcFunc.setAllArgAttrs(argAttrs); + funcFunc.setAllResultAttrs(resAttrs); + auto kernelAttr = func->getAttr(globalKernelAttr); + if (kernelAttr) { + funcFunc->setAttr(globalKernelAttr, kernelAttr); + } + std::string kernelMixMode = "aiv"; + if (existDot) { + // mix also works for pure cube kernel by using the same MAGIC_ELF keyword + kernelMixMode = "mix"; + } + // Set mix_mode in the func attrs so that the backend could know + // the mix_mode by parse the func attrs. + // The backend needs to know the mix_mode because the host wrapper + // needs to set the devbin.magic. Check npu_utils.cpp. + funcFunc->setAttr(kernelMixModeName, builder.getStringAttr(kernelMixMode)); + + auto &funcFuncBody = funcFunc.getBody(); + auto &funcBody = func.getBody(); + + IRMapping map; + funcBody.cloneInto(&funcFuncBody, map); + + for (Block &block : funcFuncBody.getBlocks()) { + auto term = block.getTerminator(); + builder.setInsertionPoint(term); + builder.create(func.getLoc(), term->getOperands()); + term->erase(); + } + func.erase(); +} + +void TritonToLinalgPass::addDynamicLegal( + ConversionTarget &target, TritonTypeConverter &tritonTypeConverter) { + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, memref::MemRefDialect>(); + + // add legal dialect on condition + target.addLegalOp(); + + // 根据条件判断需要转换的OP + target.addDynamicallyLegalOp( + [](mlir::Operation *op) { + if (op->use_empty()) { + return false; + } else { + return true; + } + }); + + target.addDynamicallyLegalOp([&](triton::FuncOp op) { + return tritonTypeConverter.isSignatureLegal(op.getFunctionType()); + }); + + target.addDynamicallyLegalOp([](arith::ConstantOp op) { + auto res = op.getResult(); + if (!isa(res.getType())) { + return true; + } + + if (auto denseAttr = dyn_cast(op.getValue())) { + if (!denseAttr.isSplat() || + !isa(denseAttr.getElementType())) { + return true; + } + if (res.hasOneUse() && isa(*res.user_begin())) { + return true; + } + return false; + } + return true; + }); + + target.addDynamicallyLegalOp([](Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type t) { + if (isa(t)) { + return false; + } + if (auto shapedType = dyn_cast(t)) { + return shapedType.getElementType().isIntOrFloat(); + } + assert(t.isIntOrIndexOrFloat()); + return true; + }); + }); + + target.addDynamicallyLegalDialect( + [this](Operation *op) { + if (op->hasAttr("MetaUse")) { + return false; + } + + if (isa(op)) { + return true; + } + + bool operateOnTensors = + llvm::all_of(op->getOperandTypes(), + [](Type type) { return isa(type); }); + + return this->namedOps || !operateOnTensors; + }); +} + +void TritonToLinalgPass::populateTritonToLinalgCanonicalizationPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add, + LoadStoreConverter::LoadStoreCanonicalizer>( + patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + 
patterns.add( + patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer + // By test, the following ops do not need canonicalization. 
+ // TTOpConverters::ScalarMathCanonicalizer + // TTOpConverters::ScalarMathCanonicalizer + // TTOpConverters::ScalarMathCanonicalizer + >(patterns.getContext()); +} + +void TritonToLinalgPass::populateTritonToLinalgConversionPatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns, + unsigned int launchGridRank) { + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + // reduce converters + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + if (!this->namedOps) { + linalg::populateElementwiseToLinalgConversionPatterns(patterns); + } +} + +void TritonToLinalgPass::getDependentDialects(DialectRegistry ®istry) const { + registry.insert(); +} + +void TritonToLinalgPass::runOnOperation() { + auto moduleOp = getOperation(); + + // Check if the kernel contains tl.dot. Without tl.dot, + // the kernel would be pure AIV kernel. 
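+  // Kernels that do contain tl.dot are compiled in "mix" mode instead (see
+  // convertTTFunc); the assumption is that tl.dot engages the cube unit,
+  // while everything else runs on the vector (AIV) cores.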
+  bool existDot = false;
+  moduleOp.walk([&](triton::DotOp dotOp) {
+    existDot = true;
+    return WalkResult::interrupt();
+  });
+
+  RewritePatternSet canonicalizerPatterns(&getContext());
+  // 1. Canonicalize load/store ops (LoadStore and ScalarStore canonicalizers).
+  this->populateTritonToLinalgCanonicalizationPatterns(canonicalizerPatterns);
+  if (failed(applyPatternsAndFoldGreedily(moduleOp,
+                                          std::move(canonicalizerPatterns)))) {
+    moduleOp->emitError("failed to apply Canonicalizer Patterns");
+    signalPassFailure();
+  }
+
+  // 2. Run use analysis.
+  moduleOp.walk([this](triton::FuncOp op) {
+    if (failed(runUseAnalysis(op))) {
+      signalPassFailure();
+    }
+  });
+
+  RewritePatternSet patterns(&getContext());
+  ConversionTarget target(getContext());
+  TritonTypeConverter tritonTypeConverter{};
+
+  // 3. Mark the legal dialects and ops.
+  this->addDynamicLegal(target, tritonTypeConverter);
+
+  // 5. Register converters for the illegal ops.
+  this->populateTritonToLinalgConversionPatterns(tritonTypeConverter, patterns,
+                                                 LAUNCH_GRID_RANK);
+
+  // 6. Walk the functions in the kernel and add the program id and
+  //    number-of-programs arguments.
+  for (auto func : getOperation().getOps()) {
+    addProgramInfo(func, globalKernel);
+  }
+
+  // 7. Perform the op conversion.
+  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+    moduleOp->emitError("failed to apply Conversion Patterns");
+    signalPassFailure();
+  }
+
+  // 8. Convert the function header and footer (tt.func and its terminator).
+  moduleOp.walk(
+      [&](triton::FuncOp func) { this->convertTTFunc(func, existDot); });
+
+  // 9. Remove dead code and simplify the IR.
+  PassManager pm(&getContext(), moduleOp.getOperationName());
+  pm.addPass(createCSEPass());
+  pm.addPass(createCanonicalizerPass());
+  if (failed(runPipeline(pm, getOperation()))) {
+    signalPassFailure();
+  }
+
+  // Try interleave optimization
+  llvm::DenseMap> interleaveCandidate;
+  llvm::DenseMap>
+      interleaveCandidateWithMask;
+  moduleOp.walk([&](bufferization::MaterializeInDestinationOp materializeOp) {
+    if (auto reinterpretCastOp =
+            materializeOp.getDest()
+                .getDefiningOp()) {
+      if (llvm::isa(reinterpretCastOp.getSource()) &&
+          reinterpretCastOp.getStaticStrides().back() == 2) {
+        interleaveCandidate[llvm::cast(
+                                reinterpretCastOp.getSource())]
+            .push_back(materializeOp);
+      }
+    }
+
+    // The difference is that the converted op chain of a masked store
+    // contains a `memref::SubViewOp`.
+    if (auto subviewOp =
+            materializeOp.getDest().getDefiningOp()) {
+      if (!llvm::isa(
+              materializeOp.getSource().getDefiningOp()))
+        return WalkResult::advance();
+
+      if (auto reinterpretCastOp =
+              subviewOp.getSource()
+                  .getDefiningOp()) {
+        if (llvm::isa(reinterpretCastOp.getSource()) &&
+            reinterpretCastOp.getStaticStrides().back() == 2) {
+          interleaveCandidateWithMask[llvm::cast(
+                                          reinterpretCastOp.getSource())]
+              .push_back(materializeOp);
+        }
+      }
+    }
+
+    return WalkResult::advance();
+  });
+
+  for (auto [blockArg, materializeVec] : interleaveCandidate) {
+    // Only enable the optimization when there are exactly two materializeOps
+    // with the same block-argument destination.
+    if (materializeVec.size() != 2)
+      continue;
+    auto result = InterleaveStatusOptimization(materializeVec);
+  }
+
+  for (auto [blockArg, materializeVec] : interleaveCandidateWithMask) {
+    if (materializeVec.size() != 2)
+      continue;
+    auto result = InterleaveStatusWithMaskOptimization(materializeVec);
+  }
+
+  // Force to add an argument at the beginning of the function arguments,
+  // which represents a stub arg for the workspace.
Default type is memref + for (auto func : getOperation().getOps()) { + if (!func->hasAttr("global_kernel")) + continue; + + auto context = func.getContext(); + constexpr int64_t workspaceArgIdx = 0; + MemRefType workspaceArgType = + MemRefType::get(SmallVector(1, ShapedType::kDynamic), + IntegerType::get(context, 8)); + NamedAttribute workspaceArgAttr(StringAttr::get(context, "workspace"), + UnitAttr::get(context)); + + func.insertArgument(/*argIndex*/ workspaceArgIdx, + /*argType*/ workspaceArgType, + /*dicAttr*/ nullptr, func->getLoc()); + func->setAttr("WorkspaceArgIdx", + IntegerAttr::get(IntegerType::get(&getContext(), 64), 0)); + } + + // Fix the Location info + moduleOp.walk([&](Operation *op) { + auto loc = op->getLoc(); + if (isa(loc)) { + llvm::SmallPtrSet stopOps; + traverseForwardUpdateUserChainIf( + op, + /*conditionFn*/ + [](Operation *curOp) { return false; }, + /*stopFn*/ + [](Operation *curOp) { return !isa(curOp->getLoc()); }, + /*actionFn*/ + nullptr, stopOps); + if (stopOps.empty()) { + op->emitWarning() << *op << " and its users all have no location!"; + } else { + Operation *goodOp = *stopOps.begin(); + op->setLoc(goodOp->getLoc()); + } + } + return WalkResult::advance(); + }); +} + +std::unique_ptr> triton::createTritonToLinalgPass() { + return std::make_unique(); +} diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp new file mode 100644 index 000000000..4b096316d --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp @@ -0,0 +1,362 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "TritonToLinalg/UseAnalysis.h" +#include "Utils/Utils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace triton; +using namespace dataflow; + +#define DEBUG_TYPE "triton-use-analysis" + +std::string stringifyUseType(UseType useTy) { + std::string ret; + if (useTy == UseType::MetaUse) { + ret = "MetaUse"; + } else if (useTy == UseType::DataUse) { + ret = "DataUse"; + } else if (useTy == UseType::MixUse) { + ret = "MixUse"; + } else if (useTy == UseType::Undefined) { + ret = "Undefined"; + } + return ret; +} + +#if LLVM_VERSION_MAJOR >= 20 +LogicalResult +triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) { +#else +void triton::UseAnalysis::visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) { +#endif + + if (op->getResults().size() == 1) { + auto resultType = dyn_cast(op->getResult(0).getType()); + if (resultType && isa(resultType.getElementType())) { + for (auto opnd : operands) { + propagateUse(opnd, UseType::MetaUse); + } + } + } + + TypeSwitch(op) + .Case([&](auto load) { + propagateUse(operands[0], UseType::MetaUse); + auto mask = load.getMask(); + auto other = load.getOther(); + if (mask) { + assert(mask != other && "mask and other cannot be the same"); + propagateUse(operands[1], UseType::MetaUse); + } + if (other) { + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto store) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = store.getValue(); + auto mask = store.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + // Consider triton::AtomicRMWOp as store operation + .Case([&](auto atomicOp) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = atomicOp.getVal(); + auto mask = atomicOp.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto dot) { + propagateResults(operands[0], results); + propagateResults(operands[1], results); + + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) { + splat = opc.template getDefiningOp(); + } + + if (opc && splat && splat.getSrc().getDefiningOp()) { + propagateUse(operands[2], UseType::MetaUse); + } else { + propagateUse(operands[2], UseType::DataUse); + } + }) + .Default([&](Operation *op) { + // this condition account for tt.addptr + for (auto operand : operands) { + propagateResults(operand, results); + } + }); +#if LLVM_VERSION_MAJOR >= 20 + return success(); +#endif +} + +LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) { + MLIRContext *context = funcOp.getContext(); + SymbolTableCollection symbolTable; + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(symbolTable); + if (failed(solver.initializeAndRun(funcOp))) { + return failure(); + } + auto &os = llvm::dbgs(); + // Walk the func op, convert tags on operands to tags on operations + funcOp.walk([&](Operation *op) { + LLVM_DEBUG({ os << "[UseAnalysis] op is " << *op << "\n"; }); + UseType useType = UseType::Undefined; + for (auto result 
: op->getResults()) { + LLVM_DEBUG({ os << "[UseAnalysis] ===> result is " << result << "\n"; }); + auto use = solver.lookupState(result); + assert(use && "Lattice value not found"); + auto thisUseType = use->type; + LLVM_DEBUG({ + os << "[UseAnalysis] ==========> useType is " + << stringifyUseType(thisUseType) << "\n"; + }); + if (thisUseType == UseType::Undefined) { + continue; + } + if (useType == UseType::Undefined) { + useType = thisUseType; + } + if (thisUseType == UseType::MixUse || thisUseType != useType) { + useType = UseType::MixUse; + break; + } + } + + if (useType == UseType::Undefined) { + LLVM_DEBUG({ op->setAttr("Undefined", UnitAttr::get(context)); }); + return; + } else if (useType == UseType::MetaUse) { + if (!isa(op)) { + assert(op->getNumResults() == 1 && + "Ops used for meta computation are expected to have one result"); + } + for (auto it = 0; it < op->getNumResults(); ++it) { + // Only set the tag if the operation uses tensors + if (isa(op->getResult(it).getType()) || + (isa(op) && + isa(op->getResult(it).getType()))) { + // Setting tag for erasing op later + op->setAttr("MetaUse", UnitAttr::get(context)); + } + } + return; + } else if (useType == UseType::DataUse) { + LLVM_DEBUG({ op->setAttr("DataUse", UnitAttr::get(context)); }); + return; + } + + assert(useType == UseType::MixUse); + + // If the operation only produces scalars, no need to clone it + bool shapedResult = true; + for (auto result : op->getResults()) + shapedResult &= isa(result.getType()); + if (!shapedResult) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + llvm::SetVector metaUsers; + for (auto result : op->getResults()) { + for (auto user : result.getUsers()) { + TypeSwitch(user) + .Case([&](auto load) { + auto ptr = load.getPtr(); + auto mask = load.getMask(); + auto other = load.getOther(); + if (result == ptr || result == mask || result == other) { + metaUsers.insert(user); + } + }) + .Case([&](auto store) { + auto ptr = store.getPtr(); + auto mask = store.getMask(); + if (result == ptr || result == mask) { + metaUsers.insert(user); + } + }) + .Case([&](auto atomicOp) { + auto ptr = atomicOp.getPtr(); + auto mask = atomicOp.getMask(); + if (result == ptr || result == mask) + metaUsers.insert(user); + }) + .Case([&](auto dot) { + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) { + splat = opc.template getDefiningOp(); + } + + if (opc && splat && + splat.getSrc().getDefiningOp()) { + metaUsers.insert(user); + } + }) + .Default([&](Operation *op) { + bool allMeta = true; + for (auto res : op->getResults()) { + auto resUse = solver.lookupState(res); + if (resUse->type != UseType::MetaUse) { + allMeta = false; + break; + } + } + if (allMeta) { + metaUsers.insert(user); + } + }); + } + } + + // If the operation doesn't have direct meta users, no need to clone it + if (metaUsers.empty()) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + // Clone the operation; switch all meta users to use the clone + OpBuilder builder(op); + auto clone = builder.clone(*op); + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + + // Setting tag for erasing op later + clone->setAttr("MetaUse", UnitAttr::get(context)); + + for (auto [res_i, result] : llvm::enumerate(op->getResults())) { + for (auto user : metaUsers) { + for (auto &operand : user->getOpOperands()) { + if (operand.get() == result) { + operand.set(clone->getResult(res_i)); + } + } + } + } + }); + LLVM_DEBUG({ + os << "[UseAnalysis] Before post-process, funcOp 
is " << *funcOp << "\n"; + }); + // Post-process + funcOp.walk([&](Operation *op) { + // Handle indirect load case. + // For example, load(1st) -> computeOp -> load(2nd). + // The first load is IndirectLoadInterfaceOp. + // Do not inplace replace MetaUse by MixUse. Because the condition checking + // depends on that the op has the attr of MetaUse. + // Handle the indirect load interface op + // We first trace from the 1st load to the 2nd load with the ops between + // them marked as MixUse. Then we traceback from the 2nd load to mark defs + // MixUse. + if (opIsIndirectLoad(op) || opIsIndirectCalc(op)) { + LLVM_DEBUG({ + os << "[UseAnalysis] Found indirect load interface op: " << *op << "\n"; + }); + llvm::SmallPtrSet stopOps; + // Modify the users of this op's result. + traverseForwardUpdateUserChainIf( + op, + /*conditionFn*/ + [op](Operation *curOp) { return isMetaUse(curOp) && curOp != op; }, + /*stopFn*/ + [&](Operation *curOp) { + // triton::LoadOp without MetaUse means it is an indirect load + // instead of the load providing the offset. + // The pattern is as follows, + // load -> ops -> load + // We need to ensure the intermediate ops are marked MixUse + // so that they will be replaced instead of be erased without + // conversion. + return isa(curOp) && !curOp->hasAttr("MetaUse"); + }, + /*actionFn*/ + [](OpBuilder &b, Operation *op) { + op->setAttr("MixUse", UnitAttr::get(op->getContext())); + }, + stopOps); + LLVM_DEBUG({ + os << "[UseAnalysis] stopOps are \n"; + int i = 0; + for (auto it = stopOps.begin(); it != stopOps.end(); it++) { + os << i++ << ": " << *(*it) << "\n"; + } + }); + LLVM_DEBUG({ + os << "[UseAnalysis] After trace, funcOp is " << *funcOp << "\n"; + }); + for (auto it = stopOps.begin(); it != stopOps.end(); it++) { + auto stopOp = *it; + traverseBackwardUpdateOperandChainIf( + stopOp, + [stopOp](Operation *curOp) { + return isMetaUse(curOp) && curOp != stopOp; + }, + [](OpBuilder &b, Operation *op) { + op->setAttr("MixUse", UnitAttr::get(op->getContext())); + }); + } + LLVM_DEBUG({ + os << "[UseAnalysis] After traceback of stopOp, funcOp is " << *funcOp + << "\n"; + }); + // Modify this op. 
+ op->setAttr("MixUse", UnitAttr::get(op->getContext())); + } + }); + // Remove MetaUse in case of MixUse existing in the op + funcOp.walk([&](Operation *op) { + if (isMetaUse(op) && isMixUse(op)) { + op->removeAttr("MetaUse"); + } + }); + LLVM_DEBUG({ + os << "[UseAnalysis] After post-process, funcOp is " << *funcOp << "\n"; + }); + return success(); +} + +MetaUseEraser::MetaUseEraser(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, context) {} + +LogicalResult MetaUseEraser::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + if (isa(op)) { + return rewriter.notifyMatchFailure(op, + "AddPtrOp will be handled separately"); + } + if (isMetaUse(op)) { + rewriter.eraseOp(op); + return success(); + } + return rewriter.notifyMatchFailure(op, "requires meta ops"); +} diff --git a/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt new file mode 100644 index 000000000..b6aa5164b --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_library(MLIRTritonNPUUtils + Utils.cpp + InterleaveOptimization.cpp + + LINK_LIBS PUBLIC + MLIRIR + TritonIR +) diff --git a/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp b/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp new file mode 100644 index 000000000..ec3e4d3d5 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp @@ -0,0 +1,662 @@ +//===- InterleaveOptimization.cpp -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Utils/InterleaveOptimization.h" +#include "Utils/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Support/LogicalResult.h" + +#include "mlir/IR/Operation.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include +#include + +namespace mlir { +namespace triton { +// For origin MemRefType of ReinterpretCastOp under interleave state, here wanna +// adjust its shape info by expanding last dimension double. +MemRefType expandInterleaveMemRefType(MemRefType originType) { + // Double the last dimension shape + SmallVector shape(originType.getShape()); + shape.back() = shape.back() * 2; + + // Adjuest layout attribute + StridedLayoutAttr originLayout = + llvm::dyn_cast(originType.getLayout()); + // If offset is static, just reset it to 0 + auto offset = originLayout.getOffset() == ShapedType::kDynamic + ? 
originLayout.getOffset() + : 0; + // Set last dimension stride to 1 + SmallVector stride(originLayout.getStrides()); + stride.back() = 1; + + return MemRefType::get( + shape, originType.getElementType(), + StridedLayoutAttr::get(originType.getContext(), offset, stride)); +} + +// ********************* +// ** NOTE ** +// ********************* +// How to determine new offset is a little tricky and specific +// Here just consider this state in triton language: +// +// dim_range = tl.arange(0, BLOCK // 2) +// last_dim_even_range = dim_range * 2 +// last_dim_odd_range = dim_range * 2 + 1 +// +// Here `multiply two` represents that last dimension stride is 2, and +// `add constant one` represents whether it's odd index part of +// deinterleave result. +// +// Therefore, how to distinguish interleave/deinterleave on even index or odd +// index is whether last dimension range explicitly `add constant one` without +// any other operation. In IR it's shown that whether defining op of +// `castOffset` is an arith::addOp, as this arith::addOp would contain above +// `add constant one` opeartion after LegacyAddPtrConverter. +// +// Well, index mode should be passed to interleave/deinterleave, in other words, +// `add constant one` should work on offset of next insert_slice/extract_slic. +// The new reinterpretcast just wanna describe whole tensor, so new castOffset +// is just from non-last diemsnion accumulation and remove `add constant one` +std::pair +recountReinterpretCastOffset(OpFoldResult originOffset, Builder &builder) { + // To trace value type offset + std::function traceOffset = [&](Operation *op) -> bool { + // Consider constant one in `add constant one` operation + if (llvm::isa(op)) + return false; + + if (llvm::isa(op)) { + auto addOp = llvm::cast(op); + if (auto constLHS = addOp.getLhs().getDefiningOp()) { + assert(dyn_cast(constLHS.getValueAttr()).getInt() == 1 && + "Arith::constant value of addi's operand must be 1 when " + "calculate deinterleave offset"); + return false; + } + if (auto constRHS = addOp.getRhs().getDefiningOp()) { + assert(dyn_cast(constRHS.getValueAttr()).getInt() == 1 && + "Arith::constant value of addi's operand must be 1 when " + "calculate deinterleave offset"); + return false; + } + } + return true; + }; + + IndexMode evenOrOdd = IndexMode::EVEN_MODE; + // Reuse origin offset if there's no 'add constant one' + OpFoldResult newOffset = originOffset; + if (llvm::isa(originOffset)) { + // If offset is constant int(IndexAttr), + // the int value could only be 0 or 1 + int64_t intOffset = getConstantIntValue(originOffset).value(); + assert((intOffset == 0 || intOffset == 1)); + if (intOffset == 1) { + evenOrOdd = IndexMode::ODD_MODE; + newOffset = builder.getIndexAttr(0); + } + } else if (llvm::isa(originOffset)) { + if (!traceOffset(originOffset.get().getDefiningOp())) { + evenOrOdd = IndexMode::ODD_MODE; + Operation *traceResult = findFirstMatchingOperandDef( + originOffset.get().getDefiningOp(), traceOffset); + assert(traceResult->getNumResults() == 1 && + "Offset defining operation must have one result"); + newOffset = traceResult->getResult(0); + } + } + + return {newOffset, evenOrOdd}; +} + +LogicalResult +DeinterleaveStatusOptimization(triton::LoadOp op, + triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) { + auto ptr = adaptor.getPtr(); + if (auto reinterpretCast = ptr.getDefiningOp()) { + auto loc = op.getLoc(); + + // 1. Get new source memref type + auto srcType = expandInterleaveMemRefType(reinterpretCast.getType()); + + // 2. 
Create new ReinterpretCastOp + auto originCastOffset = reinterpretCast.getConstifiedMixedOffset(); + auto castSize = reinterpretCast.getConstifiedMixedSizes(); + auto castStride = reinterpretCast.getConstifiedMixedStrides(); + // Actually, `castSize` is always constant value as `MemRefType` result + if (auto lastDimSize = makeIntAttr(castSize.back())) { + castSize.back() = rewriter.getIndexAttr(lastDimSize.value() * 2); + } else { + return failure(); + } + // Last element of castStride is also constant value as prerequisite + // is that last dimension stride of casted memref type is always 2. + castStride.back() = rewriter.getIndexAttr(1); + auto [castOffset, indexMode] = + recountReinterpretCastOffset(originCastOffset, rewriter); + auto newCastOp = rewriter.create( + loc, srcType, reinterpretCast.getViewSource(), castOffset, castSize, + castStride); + + // 3. Create new memref allocOp + auto newAllocOp = rewriter.create( + loc, MemRefType::get(srcType.getShape(), srcType.getElementType())); + + // 4. Implement memref copy and bufferization back to tensor + rewriter.create(loc, newCastOp.getResult(), newAllocOp); + Value newTensor = rewriter.create( + loc, + RankedTensorType::get(srcType.getShape(), srcType.getElementType()), + newAllocOp, true /* restrict */, true /* writable */); + + // 5. Implement tensor extract_slice to represent deinterleave + // Here use `castOffset` to determine whether even index deinterleave or + // odd index. + SmallVector extractOffsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector extractStrides(srcType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector extractSizes = llvm::to_vector( + llvm::map_range(srcType.getShape(), [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + // Adjust extract_slice shape + switch (indexMode) { + case IndexMode::EVEN_MODE: + extractOffsets.back() = rewriter.getIndexAttr(0); + break; + case IndexMode::ODD_MODE: + extractOffsets.back() = rewriter.getIndexAttr(1); + break; + } + extractStrides.back() = rewriter.getIndexAttr(2); + extractSizes.back() = rewriter.getIndexAttr(srcType.getShape().back() / 2); + + Value deinterleaveSlice = rewriter.create( + loc, newTensor, extractOffsets, extractSizes, extractStrides); + + rewriter.replaceOp(op, deinterleaveSlice); + return success(); + } + + return failure(); +} + +LogicalResult DeinterleaveStatusWithMaskOptimization( + triton::LoadOp op, triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter, MaskState &mstate, + memref::AllocOp originAllocOp) { + auto ptr = adaptor.getPtr(); + if (auto reinterpretCast = ptr.getDefiningOp()) { + auto loc = op.getLoc(); + + // 1. Get new source memref type + auto srcType = expandInterleaveMemRefType(reinterpretCast.getType()); + + // 2. Create new ReinterpretCastOp + auto originCastOffset = reinterpretCast.getConstifiedMixedOffset(); + auto castSize = reinterpretCast.getConstifiedMixedSizes(); + auto castStride = reinterpretCast.getConstifiedMixedStrides(); + + if (auto lastDimSize = makeIntAttr(castSize.back())) { + castSize.back() = rewriter.getIndexAttr(lastDimSize.value() * 2); + } else { + return failure(); + } + castStride.back() = rewriter.getIndexAttr(1); + auto [castOffset, indexMode] = + recountReinterpretCastOffset(originCastOffset, rewriter); + + auto newCastOp = rewriter.create( + loc, srcType, reinterpretCast.getViewSource(), castOffset, castSize, + castStride); + + // 3. 
Create new memref allocOp + // To reuse existing linalg::fill, here need to change insertion point + auto savedInsertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(originAllocOp); + auto newAllocOp = rewriter.create( + loc, MemRefType::get(srcType.getShape(), srcType.getElementType())); + rewriter.restoreInsertionPoint(savedInsertPoint); + + // 4. Broadcast other value by linalg.fill if necessary + auto other = op.getOther(); + // While deinterleave optimization will just adjust last dimension info + // and origin mask state wouldn't involve last dimension. Therefore in + // current `scf.if + linalg.fill` combination, condition of `if` could be + // kept and just replace linalg.fill' + if (other) { + assert(originAllocOp->hasOneUse() && + llvm::isa(*(originAllocOp->getUsers().begin()))); + auto originFillOp = + llvm::dyn_cast(*(originAllocOp->getUsers().begin())); + + assert(llvm::isa(originFillOp->getParentOp())); + auto ifOp = llvm::dyn_cast(originFillOp->getParentOp()); + + auto newFillOp = ifOp.getThenBodyBuilder().create( + originFillOp.getLoc(), originFillOp.getInputs(), + ValueRange{newAllocOp}); + rewriter.eraseOp(originFillOp); + } + + // 5. Implement new subview, memref copy and bufferization back to tensor + SmallVector subviewStrides(srcType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector subviewOffsets = mstate.offsets; + SmallVector subviewSizes = mstate.dims; + // Just adjust last dimension size to double + std::optional originSubviewLastDim = + getConstantIntValue(subviewSizes.back()); + assert(originSubviewLastDim.has_value()); + subviewSizes.back() = + rewriter.getIndexAttr(originSubviewLastDim.value() * 2); + + auto argSubviewType = memref::SubViewOp::inferResultType( + srcType, subviewOffsets, subviewSizes, subviewStrides); + // alloca subview type doesn't carry layout attribute + auto allocSubviewType = memref::SubViewOp::inferResultType( + newAllocOp.getType(), subviewOffsets, subviewSizes, subviewStrides); + + memref::SubViewOp srcSubview = rewriter.create( + loc, llvm::cast(argSubviewType), newCastOp, subviewOffsets, + subviewSizes, subviewStrides); + memref::SubViewOp dstSubview = rewriter.create( + loc, llvm::cast(allocSubviewType), newAllocOp, + subviewOffsets, subviewSizes, subviewStrides); + rewriter.create(loc, srcSubview, dstSubview); + Value newTensor = rewriter.create( + loc, + RankedTensorType::get(srcType.getShape(), srcType.getElementType()), + newAllocOp, true /* restrict */, true /* writable */); + + // 6. Implement tensor extract_slice to represent deinterleave + // Here use `castOffset` to determine whether even index deinterleave or + // odd index. 
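+    // Illustrative sketch only (shapes made up): with a last dimension of 8,
+    // the even-index deinterleave roughly becomes
+    //   %even = tensor.extract_slice %t[0, 0] [4, 4] [1, 2]
+    //       : tensor<4x8xf32> to tensor<4x4xf32>
+    // while the odd-index variant uses offsets [0, 1].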
+ SmallVector extractOffsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector extractStrides(srcType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector extractSizes = llvm::to_vector( + llvm::map_range(srcType.getShape(), [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + switch (indexMode) { + case IndexMode::EVEN_MODE: + extractOffsets.back() = rewriter.getIndexAttr(0); + break; + case IndexMode::ODD_MODE: + extractOffsets.back() = rewriter.getIndexAttr(1); + break; + } + extractStrides.back() = rewriter.getIndexAttr(2); + extractSizes.back() = rewriter.getIndexAttr(srcType.getShape().back() / 2); + + Value deinterleaveSlice = rewriter.create( + loc, newTensor, extractOffsets, extractSizes, extractStrides); + + rewriter.replaceOp(op, deinterleaveSlice); + return success(); + } + return failure(); +} + +LogicalResult +InterleaveStatusOptimization(SmallVector materializeVec) { + OpBuilder builder(materializeVec[1]); + auto loc = materializeVec[1]->getLoc(); + + auto firstReinterpretCastOp = + llvm::dyn_cast( + materializeVec[0]) + .getDest() + .getDefiningOp(); + auto secondReinterpretCastOp = + llvm::dyn_cast( + materializeVec[1]) + .getDest() + .getDefiningOp(); + + assert(firstReinterpretCastOp && secondReinterpretCastOp); + + // Judge whether two `ReinterpretCastOp` shape satisfy interleave state + // a. both size are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedSizes(), + secondReinterpretCastOp.getConstifiedMixedSizes())) { + return failure(); + } + // b. both strides are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedStrides(), + secondReinterpretCastOp.getConstifiedMixedStrides())) { + return failure(); + } + // c. both offsets should satisfy tricky rule + auto firstOriginCastOffset = + firstReinterpretCastOp.getConstifiedMixedOffset(); + auto secondOriginCastOffset = + secondReinterpretCastOp.getConstifiedMixedOffset(); + std::pair indexModeRecord; + OpFoldResult newCastOffset; + if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ static_cast(secondIndexMode))) + return failure(); + newCastOffset = builder.getIndexAttr(0); + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ + static_cast(secondIndexMode)) || + (llvm::dyn_cast(firstCastOffset) != + llvm::dyn_cast(secondCastOffset))) + return failure(); + + if (firstIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(firstCastOffset); + } + if (secondIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(secondCastOffset); + } + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else { + return failure(); + } + + // Create new op + // 1. Get new destination memref type + auto dstType = expandInterleaveMemRefType(firstReinterpretCastOp.getType()); + + // 2. 
New tensor::EmptyOp + auto emptyTensor = builder.create(loc, dstType.getShape(), + dstType.getElementType()); + + // 3. New insert_slice from materialization source into new empty tensor + SmallVector insertOffsets(dstType.getRank(), + builder.getIndexAttr(0)); + SmallVector insertStrides(dstType.getRank(), + builder.getIndexAttr(1)); + SmallVector insertSizes = llvm::to_vector( + llvm::map_range(dstType.getShape(), [&](int64_t dim) -> OpFoldResult { + return builder.getIndexAttr(dim); + })); + insertStrides.back() = builder.getIndexAttr(2); + insertSizes.back() = builder.getIndexAttr(dstType.getShape().back() / 2); + if (indexModeRecord.first == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertFirst = builder.create( + loc, + llvm::dyn_cast( + materializeVec[0]) + .getSource(), + emptyTensor.getResult(), insertOffsets, insertSizes, insertStrides); + + if (indexModeRecord.second == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertSecond = builder.create( + loc, + llvm::dyn_cast( + materializeVec[1]) + .getSource(), + insertFirst.getResult(), insertOffsets, insertSizes, insertStrides); + + // 4. Reinterpret_cast block arg + auto newCastSize = firstReinterpretCastOp.getConstifiedMixedSizes(); + auto newCastStride = firstReinterpretCastOp.getConstifiedMixedStrides(); + newCastSize.back() = builder.getIndexAttr(dstType.getShape().back()); + newCastStride.back() = builder.getIndexAttr(1); + auto newCastOp = builder.create( + loc, dstType, firstReinterpretCastOp.getViewSource(), newCastOffset, + newCastSize, newCastStride); + + // 5. Create new bufferization::MaterializeInDestinationOp + auto newStoreOp = builder.create( + loc, insertSecond.getResult(), newCastOp.getResult()); + // Setting writable is necessary as dst is memref type + newStoreOp.setWritable(true); + + // 6. Erase origin materialization + materializeVec[0]->erase(); + materializeVec[1]->erase(); + + return success(); +} + +LogicalResult +InterleaveStatusWithMaskOptimization(SmallVector materializeVec) { + OpBuilder builder(materializeVec[1]); + + auto firstSubviewOpOfReCast = + llvm::dyn_cast( + materializeVec[0]) + .getDest() + .getDefiningOp(); + auto firstSrcExtractSlice = + llvm::dyn_cast( + materializeVec[0]) + .getSource() + .getDefiningOp(); + auto firstReinterpretCastOp = firstSubviewOpOfReCast.getSource() + .getDefiningOp(); + + auto secondSubviewOpOfReCast = + llvm::dyn_cast( + materializeVec[1]) + .getDest() + .getDefiningOp(); + auto secondSrcExtractSlice = + llvm::dyn_cast( + materializeVec[1]) + .getSource() + .getDefiningOp(); + auto secondReinterpretCastOp = + secondSubviewOpOfReCast.getSource() + .getDefiningOp(); + + // 1. Both source shapes of subview and extract_slice are equal + if (firstSubviewOpOfReCast.getSourceType().getShape() != + firstSrcExtractSlice.getSourceType().getShape()) + return failure(); + if (secondSubviewOpOfReCast.getSourceType().getShape() != + secondSrcExtractSlice.getSourceType().getShape()) + return failure(); + if (firstSubviewOpOfReCast.getSourceType().getShape() != + secondSubviewOpOfReCast.getSourceType().getShape()) + return failure(); + + // 2. 
both mask state are equal + std::function cmpFunc = + mlir::isEqualConstantIntOrValue; + if (!mlir::detail::sameOffsetsSizesAndStrides(firstSubviewOpOfReCast, + firstSrcExtractSlice, cmpFunc)) + return failure(); + if (!mlir::detail::sameOffsetsSizesAndStrides(secondSubviewOpOfReCast, + secondSrcExtractSlice, cmpFunc)) + return failure(); + if (!mlir::detail::sameOffsetsSizesAndStrides( + firstSubviewOpOfReCast, secondSubviewOpOfReCast, cmpFunc)) + return failure(); + + // 3. Still judge whether two `ReinterpretCastOp` shape satisfy request + // a. both size are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedSizes(), + secondReinterpretCastOp.getConstifiedMixedSizes())) + return failure(); + // b. both strides are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedStrides(), + secondReinterpretCastOp.getConstifiedMixedStrides())) + return failure(); + // c. both offsets should satisfy tricky rule + auto firstOriginCastOffset = + firstReinterpretCastOp.getConstifiedMixedOffset(); + auto secondOriginCastOffset = + secondReinterpretCastOp.getConstifiedMixedOffset(); + std::pair indexModeRecord; + OpFoldResult newCastOffset; + if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ static_cast(secondIndexMode))) + return failure(); + newCastOffset = builder.getIndexAttr(0); + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ + static_cast(secondIndexMode)) || + (llvm::dyn_cast(firstCastOffset) != + llvm::dyn_cast(secondCastOffset))) + return failure(); + + if (firstIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(firstCastOffset); + } + if (secondIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(secondCastOffset); + } + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else { + return failure(); + } + auto loc = materializeVec[1]->getLoc(); + + // Create new op + // 1. Get new destination memref type + auto dstType = expandInterleaveMemRefType(firstReinterpretCastOp.getType()); + + // 2. New tensor::EmptyOp + auto emptyTensor = builder.create(loc, dstType.getShape(), + dstType.getElementType()); + + // 3. 
New insert_slice from extract_slice source into new empty tensor + SmallVector insertOffsets(dstType.getRank(), + builder.getIndexAttr(0)); + SmallVector insertStrides(dstType.getRank(), + builder.getIndexAttr(1)); + SmallVector insertSizes = llvm::to_vector( + llvm::map_range(dstType.getShape(), [&](int64_t dim) -> OpFoldResult { + return builder.getIndexAttr(dim); + })); + insertStrides.back() = builder.getIndexAttr(2); + insertSizes.back() = builder.getIndexAttr(dstType.getShape().back() / 2); + if (indexModeRecord.first == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertFirst = builder.create( + loc, firstSrcExtractSlice.getSource(), emptyTensor.getResult(), + insertOffsets, insertSizes, insertStrides); + + if (indexModeRecord.second == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertSecond = builder.create( + loc, secondSrcExtractSlice.getSource(), insertFirst.getResult(), + insertOffsets, insertSizes, insertStrides); + + // 4. To enable store with mask, create new extract_slice + SmallVector extractOffsets = + firstSrcExtractSlice.getMixedOffsets(); + SmallVector extractStrides = + firstSrcExtractSlice.getMixedStrides(); + SmallVector extractSizes = firstSrcExtractSlice.getMixedSizes(); + assert(llvm::isa(extractSizes.back())); + extractSizes.back() = builder.getIndexAttr( + getConstantIntValue(extractSizes.back()).value() * 2); + auto newSrcExtractSlice = builder.create( + loc, insertSecond.getResult(), extractOffsets, extractSizes, + extractStrides); + + // 5. Reinterpret_cast block arg + auto newCastSize = firstReinterpretCastOp.getConstifiedMixedSizes(); + auto newCastStride = firstReinterpretCastOp.getConstifiedMixedStrides(); + newCastSize.back() = builder.getIndexAttr(dstType.getShape().back()); + newCastStride.back() = builder.getIndexAttr(1); + auto newCastOp = builder.create( + loc, dstType, firstReinterpretCastOp.getViewSource(), newCastOffset, + newCastSize, newCastStride); + + // 6. Create new memref::SubViewOp of above new reinterpret_cast + // Here could reuse shape info of new extract_slice + auto dstSubviewType = memref::SubViewOp::inferResultType( + dstType, extractOffsets, extractSizes, extractStrides); + auto newSubviewOpOfReCast = builder.create( + loc, llvm::cast(dstSubviewType), newCastOp, extractOffsets, + extractSizes, extractStrides); + + // 7. Create new bufferization::MaterializeInDestinationOp + auto newStoreOp = builder.create( + loc, newSrcExtractSlice.getResult(), newSubviewOpOfReCast.getResult()); + // Setting writable is necessary as dst is memref type + newStoreOp.setWritable(true); + + // 8. Erase origin operation + materializeVec[0]->erase(); + materializeVec[1]->erase(); + firstSubviewOpOfReCast->erase(); + firstSrcExtractSlice->erase(); + secondSubviewOpOfReCast->erase(); + secondSrcExtractSlice->erase(); + + return success(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp b/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp new file mode 100644 index 000000000..805e69025 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp @@ -0,0 +1,752 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "../../include/Utils/Utils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include + +#define DEBUG_TYPE "TritonNPU-Utils" + +namespace mlir { + +namespace ConverterUtils { + +Value getTransposedValue(Value source, const Location loc, + ConversionPatternRewriter &rewriter, + llvm::ArrayRef order) { + auto sourceType = cast(source.getType()); + auto sourceRank = sourceType.getRank(); + + SmallVector perm(order); + SmallVector originalShape(sourceType.getShape()); + SmallVector transposedShape(sourceRank); + for (size_t i = 0; i < sourceRank; i++) { + transposedShape[i] = originalShape[perm[i]]; + } + + Value transposeInit = rewriter.create( + loc, transposedShape, sourceType.getElementType()); + + Value transpose = + rewriter.create(loc, source, transposeInit, perm) + .getResults()[0]; + + return transpose; +} + +SmallVector getNParallelLoopsAttrs(unsigned n) { + return SmallVector(n, utils::IteratorType::parallel); +} + +Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter) { + SmallVector ops; + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = mlir::TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + rewriter, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +memref::SubViewOp makeSubViewOp(Value 
src, + llvm::SmallVectorImpl &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter) { + auto srcType = dyn_cast(src.getType()); + SmallVector offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), + rewriter.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, dyn_cast(dstType), + src, offsets, sizes, strides); +} + +void getShapeInfo(Value val, llvm::SmallVectorImpl &shapes, + ConversionPatternRewriter &rewriter) { + if (isa(val)) { + auto blockArg = dyn_cast(val); + auto blockOp = blockArg.getOwner()->getParentOp(); + if (isa(blockOp)) { + auto forOp = dyn_cast(blockOp); + auto operand = forOp.getTiedLoopInit(blockArg)->get(); + getShapeInfo(operand, shapes, rewriter); + } else { + emitError(val.getLoc()) + << "getShapeInfo() only support ReinterpretCastOp " + "and scf.for's block argument, but got : " + << val << "\n"; + } + return; + } + + if (isa(val.getType())) { + val = rewriter.getRemappedValue(val); + } + + if (!isa(val.getDefiningOp())) { + emitError(val.getLoc()) << "getShapeInfo() only support ReinterpretCastOp " + "and scf.for's block argument, but got : " + << val << "\n"; + return; + } + auto castOp = dyn_cast(val.getDefiningOp()); + auto tensorPtrAttr = castOp->getAttr("tensor_ptr_attr"); + if (tensorPtrAttr) { + shapes = castOp.getConstifiedMixedSizes(); + } else { + getShapeInfo(castOp.getSource(), shapes, rewriter); + } + return; +} + +SmallVector +getBoundarySizes(llvm::ArrayRef boundaryCheck, Value ptr, + Value adaptorPtr, const Location &loc, + ConversionPatternRewriter &rewriter) { + SmallVector parTensorShapes; + getShapeInfo(adaptorPtr, parTensorShapes, rewriter); + auto extractOp = + rewriter.create(loc, adaptorPtr); + + OpFoldResult baseOffset = extractOp.getConstifiedMixedOffset(); + SmallVector strides = extractOp.getConstifiedMixedStrides(); + + SmallVector boundarySizes = extractOp.getConstifiedMixedSizes(); + auto dims = boundarySizes.size(); + OpFoldResult currentStride = rewriter.getIndexAttr(1); + for (int i = dims - 1; i >= 0; i--) { + auto offset = divOpFoldResult(baseOffset, currentStride, loc, rewriter); + offset = remOpFoldResult(offset, parTensorShapes[i], loc, rewriter); + if (llvm::find(boundaryCheck, i) != boundaryCheck.end()) { + OpFoldResult subOfr = + subOpFoldResult(parTensorShapes[i], offset, loc, rewriter); + boundarySizes[i] = + minOpFoldResult(boundarySizes[i], subOfr, loc, rewriter); + } + currentStride = + mulOpFoldResult(currentStride, parTensorShapes[i], loc, rewriter); + } + return boundarySizes; +} + +SmallVector getBroadcastDims(RankedTensorType src, + RankedTensorType dst) { + SmallVector broadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); ++i) { + if (dstShape[i] != srcShape[i]) { + assert(srcShape[i] == 1 && + "Size of source broadcast dimension must be 1"); + broadcastDims.push_back(i); + } + } + assert(!broadcastDims.empty() && "Cannot identify broadcast dimension"); + return broadcastDims; +} + +// Dimensions of collapesd tensor is all unbroadcast dims +SmallVector getUnbroadcastDims(RankedTensorType src, + RankedTensorType dst) { + SmallVector unbroadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); ++i) { + if (dstShape[i] == srcShape[i]) { + unbroadcastDims.emplace_back(srcShape[i]); + } + } + return unbroadcastDims; +} + +} // namespace 
ConverterUtils + +namespace triton { + +mlir::Operation * +findFirstMatchingOperandDef(mlir::Operation *rootOp, + const std::function &condFn) { + LLVM_DEBUG(llvm::dbgs() << "[findFirstMatchingOperandDef] Current op: " + << *rootOp << "\n"); + mlir::Value lhs = nullptr; + mlir::Value rhs = nullptr; + if (auto op = dyn_cast(rootOp)) { + lhs = op.getPtr(); + rhs = op.getOffset(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getSrc(); + } else if (auto op = dyn_cast(rootOp)) { + } else { + rootOp->emitRemark("Backtracing encounters unsupported Operation"); + return nullptr; + } + // Backtrace operands + if (!lhs) { + return nullptr; + } + auto lhsDef = lhs.getDefiningOp(); + mlir::Operation *targetOp; + if (lhsDef) { + if (condFn(lhsDef)) { + targetOp = lhsDef; + } else { + targetOp = findFirstMatchingOperandDef(lhsDef, condFn); + } + if (targetOp) { + return targetOp; + } + } + if (!rhs) { + return nullptr; + } + auto rhsDef = rhs.getDefiningOp(); + if (rhsDef) { + if (condFn(rhsDef)) { + targetOp = rhsDef; + } else { + targetOp = findFirstMatchingOperandDef(rhsDef, condFn); + } + if (targetOp) { + return targetOp; + } + } + return nullptr; +} + +void traverseBackwardUpdateOperandChainIf( + Operation *op, std::function conditionFn, + std::function actionFn, + OpBuilder &builder) { + + if (!op) + return; + + if (conditionFn(op)) { + actionFn(builder, op); + } + + for (Value operand : op->getOperands()) { + // TODO: handle BlockArgument + if (Operation *defOp = operand.getDefiningOp()) { + traverseBackwardUpdateOperandChainIf(defOp, conditionFn, actionFn, + builder); + } + } +} + +// Note: rootOp will also be processed. +void traverseBackwardUpdateOperandChainIf( + Operation *rootOp, std::function conditionFn, + std::function actionFn) { + + OpBuilder builder(rootOp->getContext()); + + traverseBackwardUpdateOperandChainIf(rootOp, conditionFn, actionFn, builder); +} + +void traverseForwardUpdateUserChainIf( + Operation *op, std::function conditionFn, + std::function stopFn, + std::function actionFn, OpBuilder &builder, + llvm::SmallPtrSet &stopOps) { + + if (!op) { + return; + } + + if (stopFn(op)) { + stopOps.insert(op); + return; + } + + if (conditionFn(op)) { + actionFn(builder, op); + } + + for (auto res : op->getResults()) { + for (auto userOp : res.getUsers()) { + traverseForwardUpdateUserChainIf(userOp, conditionFn, stopFn, actionFn, + builder, stopOps); + } + } +} + +// Note: rootOp will also be processed. 
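+// This overload just seeds a fresh OpBuilder and delegates to the recursive
+// worker above. `stopOps` collects the operations on which stopFn fired,
+// e.g. the trailing loads of an indirect-load chain in UseAnalysis.cpp.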
+void traverseForwardUpdateUserChainIf( + Operation *rootOp, std::function conditionFn, + std::function stopFn, + std::function actionFn, + llvm::SmallPtrSet &stopOps) { + + OpBuilder builder(rootOp->getContext()); + + traverseForwardUpdateUserChainIf(rootOp, conditionFn, stopFn, actionFn, + builder, stopOps); +} + +bool isMetaUse(Operation *op) { return op->hasAttr("MetaUse"); } + +bool isMixUse(Operation *op) { return op->hasAttr("MixUse"); } + +IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op) { + auto ty = IndirectLoadInterfaceOpType::Undefined; + if (isMetaUse(op)) { + if (isa(op)) { + ty = IndirectLoadInterfaceOpType::Load; + } else if (isa(op)) { + ty = IndirectLoadInterfaceOpType::Calc; + } + } + return ty; +} + +bool opIsIndirectLoad(Operation *op) { + auto opType = getIndirectLoadInterfaceOpType(op); + return opType == IndirectLoadInterfaceOpType::Load; +} + +bool opIsIndirectCalc(Operation *op) { + auto opType = getIndirectLoadInterfaceOpType(op); + return opType == IndirectLoadInterfaceOpType::Calc; +} + +scf::ForOp createNestedLoops( + OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims, + ValueRange LBs, ValueRange UBs, ValueRange steps, SmallVector &ivs, + ValueRange initArgs, + function_ref &, ValueRange)> + bodyBuilder) { + + if (currentDim >= totalDims) { + bodyBuilder(builder, loc, ivs, initArgs); + return nullptr; + } + + auto loop = builder.create( + loc, LBs[currentDim], UBs[currentDim], steps[currentDim], initArgs, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange iterArgs) { + ivs.push_back(iv); + auto innerLoop = createNestedLoops(nestedBuilder, nestedLoc, + currentDim + 1, totalDims, LBs, UBs, + steps, ivs, iterArgs, bodyBuilder); + if (innerLoop) { + nestedBuilder.create(loc, innerLoop.getResults()); + } + }); + + return loop; +} + +ModuleOp getModuleOpFromOperation(Operation *op) { + Operation *parent = op; + while (parent != nullptr && !isa(parent)) { + parent = parent->getParentOp(); // 向上查找 + } + return cast(parent); // 如果没找到会抛出异常 +} + +} // namespace triton + +std::optional makeIntAttr(const OpFoldResult &ofr) { + if (isa(ofr) && isa(ofr.get())) + return dyn_cast(ofr.get()).getInt(); + return std::nullopt; +} + +bool hasConstantZero(const OpFoldResult &ofr) { + auto intAttr = makeIntAttr(ofr); + if (intAttr.has_value()) + return !intAttr.value(); + + auto val = dyn_cast(ofr); + assert(val && "Provided ofr must can be cast to Value"); + + auto ConstOp = val.getDefiningOp(); + if (!ConstOp) + return false; + + intAttr = makeIntAttr(ConstOp.getValue()); + return intAttr.has_value() && !intAttr.value(); +} + +Value opFoldResultToIndex(const OpFoldResult &ofr, const Location &loc, + OpBuilder &b) { + if (auto val = dyn_cast(ofr)) { + assert(val.getType().isIndex() && "Provided ofr shoule be type of Index"); + return val; + } + + auto intAttr = makeIntAttr(ofr); + if (intAttr.has_value()) { + return b.create(loc, b.getIndexAttr(intAttr.value())); + } + llvm_unreachable("Unexpected OpFoldResult state"); + return nullptr; +} + +SmallVector opFoldResultToIndex(ArrayRef ofrs, + const Location &loc, OpBuilder &b) { + return llvm::map_to_vector<4>(ofrs, [&](OpFoldResult ofr) -> Value { + return opFoldResultToIndex(ofr, loc, b); + }); +} + +Value createConstIntOp(const Location &loc, OpBuilder &b, int64_t value) { + return b.create(loc, b.getIndexAttr(value)).getResult(); +} + +// TODO: imply these function below +OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const 
Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsInt = makeIntAttr(rhs); + + if (!lhsInt && rhsInt && rhsInt.value() == 0) { + return lhs; + } + if (!rhsInt && lhsInt && lhsInt.value() == 0) { + return rhs; + } + + if (lhsInt && rhsInt) { + return b.getIndexAttr(lhsInt.value() + rhsInt.value()); + } + + auto lhsValue = dyn_cast(lhs); + if (lhsInt) { + lhsValue = createConstIntOp(loc, b, lhsInt.value()); + } else { + assert(isa(lhsValue.getType())); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsInt) { + rhsValue = createConstIntOp(loc, b, rhsInt.value()); + } else { + assert(isa(rhsValue.getType())); + } + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsInt = makeIntAttr(rhs); + + if (!lhsInt && rhsInt && rhsInt.value() == 0) { + return lhs; + } + + if (lhsInt && rhsInt) { + return b.getIndexAttr(lhsInt.value() - rhsInt.value()); + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) { + lhsValue = createConstIntOp(loc, b, lhsInt.value()); + } + if (rhsInt) { + rhsValue = createConstIntOp(loc, b, rhsInt.value()); + } + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsInt = makeIntAttr(rhs); + + if (lhsInt) { + if (lhsInt.value() == 0) { + return lhs; + } + if (lhsInt.value() == 1) { + return rhs; + } + } + if (rhsInt) { + if (rhsInt.value() == 0) { + return rhs; + } + if (rhsInt.value() == 1) { + return lhs; + } + } + + if (lhsInt && rhsInt) { + return b.getIndexAttr(lhsInt.value() * rhsInt.value()); + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) { + lhsValue = createConstIntOp(loc, b, lhsInt.value()); + } + if (rhsInt) { + rhsValue = createConstIntOp(loc, b, rhsInt.value()); + } + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const Value &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsConstFlag = false; + + auto rhsConstInt = std::numeric_limits::max(); + auto rhsOp = rhs.getDefiningOp(); + if (rhsOp) { + rhsConstFlag = true; + rhsConstInt = dyn_cast(rhsOp.getValue()).getInt(); + } + + if (lhsInt && rhsConstFlag) { + return b.getIndexAttr(lhsInt.value() * rhsConstInt); + } + + if (lhsInt) { + if (lhsInt.value() == 0) { + return lhs; + } + if (lhsInt.value() == 1) { + return rhs; + } + } + if (rhsConstFlag) { + if (rhsConstInt == 0) { + return rhsOp.getResult(); + } + if (rhsConstInt == 1) { + return lhs; + } + } + + if (lhsInt && !rhsConstFlag) { + auto lhsValue = createConstIntOp(loc, b, lhsInt.value()); + return b.create(loc, lhsValue, rhs).getResult(); + } + assert(!lhsInt); + return b.create(loc, lhs.get(), rhs).getResult(); +} + +OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsInt = makeIntAttr(rhs); + if (lhsInt) { + if (lhsInt.value() == 0) { + return lhs; + } + } + if (rhsInt) { + if (rhsInt.value() == 0) { + emitError(loc) << "cannot div 0!"; + return OpFoldResult(); + } + if (rhsInt.value() == 1) { + return lhs; + } + } + + if (lhsInt && rhsInt) { + return b.getIndexAttr(lhsInt.value() / rhsInt.value()); + } + + auto lhsValue = 
dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) { + lhsValue = createConstIntOp(loc, b, lhsInt.value()); + } + if (rhsInt) { + rhsValue = createConstIntOp(loc, b, rhsInt.value()); + } + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsInt = makeIntAttr(rhs); + if (lhsInt && lhsInt.value() == 0) { + return lhs; + } + if (lhsInt && rhsInt) { + return b.getIndexAttr(lhsInt.value() % rhsInt.value()); + } + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) { + lhsValue = createConstIntOp(loc, b, lhsInt.value()); + } + if (rhsInt) { + rhsValue = createConstIntOp(loc, b, rhsInt.value()); + } + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsInt = makeIntAttr(rhs); + if (lhsInt && rhsInt) { + return b.getIndexAttr(std::min(lhsInt.value(), rhsInt.value())); + } + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) { + lhsValue = createConstIntOp(loc, b, lhsInt.value()); + } + if (rhsInt) { + rhsValue = createConstIntOp(loc, b, rhsInt.value()); + } + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = makeIntAttr(lhs); + auto rhsInt = makeIntAttr(rhs); + if (lhsInt && rhsInt) { + return b.getIndexAttr(std::max(lhsInt.value(), rhsInt.value())); + } + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) { + lhsValue = createConstIntOp(loc, b, lhsInt.value()); + } + if (rhsInt) { + rhsValue = createConstIntOp(loc, b, rhsInt.value()); + } + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +LogicalResult +addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, + linalg::ReduceOp reduceOp) { + // To verify whether the operation of the reduceOp is ReduceWithIndex + // TODO: maybe a better way of judging? 
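+  // Current heuristic: inspect the first op of the combiner body; an arith
+  // float/int compare with a greater-than predicate (OGT/sgt) tags the reduce
+  // as "max_with_index", a less-than predicate (OLT/slt) as "min_with_index".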
+  auto ctx = reduceOp.getContext();
+  Block &body = reduceOp.getCombiner().front();
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+
+  auto yieldValue = yieldOp.getValues();
+  if (yieldValue.size() == 0) {
+    return failure();
+  }
+
+  auto opIter = reduceOp.getBody()->without_terminator().begin();
+  auto cmpMaskOp = dyn_cast<arith::CmpFOp>(*opIter);
+  const StringRef reduceRef = "reduce_mode";
+  if (cmpMaskOp) {
+    if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OGT) {
+      reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index"));
+    } else if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OLT) {
+      reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index"));
+    }
+  }
+
+  auto cmpMaskIOp = dyn_cast<arith::CmpIOp>(*opIter);
+  if (cmpMaskIOp) {
+    if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::sgt) {
+      reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index"));
+    } else if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::slt) {
+      reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index"));
+    }
+  }
+
+  return success();
+}
+
+} // namespace mlir
diff --git a/third_party/ascend/triton-adapter/tools/CMakeLists.txt b/third_party/ascend/triton-adapter/tools/CMakeLists.txt
new file mode 100644
index 000000000..628169551
--- /dev/null
+++ b/third_party/ascend/triton-adapter/tools/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(triton-adapter-opt)
diff --git a/third_party/ascend/triton-adapter/triton_adapter.cc b/third_party/ascend/triton-adapter/triton_adapter.cc
new file mode 100644
index 000000000..7fa5e82a5
--- /dev/null
+++ b/third_party/ascend/triton-adapter/triton_adapter.cc
@@ -0,0 +1,6 @@
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+// compilation goes to triton-adapter-opt, do nothing here
+void init_triton_triton_adapter(py::module &&m) {}
diff --git a/third_party/ascend/triton_ascend.cpp b/third_party/ascend/triton_ascend.cpp
new file mode 100644
index 000000000..9a83aa018
--- /dev/null
+++ b/third_party/ascend/triton_ascend.cpp
@@ -0,0 +1,11 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */ +#define PY_SSIZE_T_CLEAN +#include +namespace py = pybind11; + +// register huawei passes to triton +void init_triton_huawei(py::module &&m) { + // currently no extra modules needed to plug-in libtriton.so +} diff --git a/third_party/ascend/triton_patch/include/CMakeLists.txt b/third_party/ascend/triton_patch/include/CMakeLists.txt new file mode 100644 index 000000000..109c292fe --- /dev/null +++ b/third_party/ascend/triton_patch/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) diff --git a/third_party/ascend/triton_patch/include/triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/CMakeLists.txt new file mode 100644 index 000000000..0ca0f41c5 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt new file mode 100644 index 000000000..5e601271e --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..9984e2e01 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,34 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +# file(RELATIVE_PATH patch_rel_dir "${CMAKE_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}") +# string(REPLACE "triton_patch" "triton" triton_rel_dir "${patch_rel_dir}") +# set(triton_abs_dir "${CMAKE_SOURCE_DIR}/${triton_rel_dir}") +# message(STATUS "triton_abs_dir: ${triton_abs_dir}") +# message(${triton_abs_dir}) +set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") +message(${triton_abs_dir}) +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypeInterfaces.td) +mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) + +add_public_tablegen_target(Patched_TritonTableGen) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td 
b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td new file mode 100644 index 000000000..b59bc7c8f --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -0,0 +1,137 @@ +#ifndef TRITON_ATTR_DEFS +#define TRITON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Attributes for LoadOp and StoreOp +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + I32EnumAttrCase<"WB", 4, "wb">, + I32EnumAttrCase<"CS", 5, "cs">, + I32EnumAttrCase<"WT", 6, "wt">, + I32EnumAttrCase<"CV", 7, "cv">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_PaddingOptionAttr : I32EnumAttr< + "PaddingOption", "", + [ + I32EnumAttrCase<"PAD_ZERO", 1, "zero">, + // We can not set the string value to "NAN" because it is a keyword in C++ + I32EnumAttrCase<"PAD_NAN", 2, "nan"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin">, + I32EnumAttrCase<"XCHG", 10, "exch"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Program ID dimensions. +def TT_ProgramDim : I32EnumAttr< + "ProgramIDDim", "", + [ + I32EnumAttrCase<"X", 0, "x">, + I32EnumAttrCase<"Y", 1, "y">, + I32EnumAttrCase<"Z", 2, "z">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Rounding mode. +def TT_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "", + [ + I32EnumAttrCase<"RTZ", 0, "rtz">, + I32EnumAttrCase<"RTNE", 1, "rtne">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// PropagateNan. +def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// InputPrecision +def TT_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee">, + I32EnumAttrCase<"HF32", 3, "hf32">, + ]>{ + let cppNamespace = "::mlir::triton"; +} + +// Type for F8F6F4 kind of floats. 
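+// (E4M3/E5M2 are 8-bit formats, E2M3/E3M2 are 6-bit, E2M1 is 4-bit; ExMy means
+// x exponent bits and y mantissa bits plus one sign bit.)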
+def TT_F8F6F4TypeAttr : I32EnumAttr< + "F8F6F4Type", "", + [ + I32EnumAttrCase<"E4M3", 0, "e4m3">, + I32EnumAttrCase<"E5M2", 1, "e5m2">, + I32EnumAttrCase<"E2M3", 2, "e2m3">, + I32EnumAttrCase<"E3M2", 3, "e3m2">, + I32EnumAttrCase<"E2M1", 4, "e2m1"> + + ]>{ + let cppNamespace = "::mlir::triton"; +} + +#endif diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td new file mode 100644 index 000000000..9bca3da18 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td @@ -0,0 +1,1286 @@ +#ifndef TRITON_OPS +#define TRITON_OPS + +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" + + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Op Base +// +class TT_Op traits = []> : + Op { +} + +// +// Cast Ops +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins TT_I64Like:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_I64Like:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// arith.bitcast doesn't support pointers +def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast between types of the same bitwidth"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + // TODO: Add verifier +} + +def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8), and non-default rounding modes. 
+ + F8 <-> FP16, BF16, FP32, FP64 + }]; + + let arguments = ( + ins TT_FloatTensor:$src, + OptionalAttr:$rounding + ); + + let results = (outs TT_FloatTensor:$result); + + let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; + + let hasVerifier = 1; +} + +// +// Arithmetic Ops +// + +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. + }]; + + let arguments = ( + ins + TT_FloatLike:$x, + TT_FloatLike:$min, + TT_FloatLike:$max, + TT_PropagateNanAttr:$propagateNan + ); + + let results = (outs TT_FloatLike:$result); + + // List $propagateNan explicitly rather than relying on attr-dict to pick it + // up, because if it's inside attr-dict, its value will be printed as a + // number rather than as a meaningful string. + let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; +} + +// +// Math Ops +// + +def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise sqrt for floating point types"; + + let description = [{ + Precise sqrt for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x attr-dict `:` type($x)"; +} + +def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise div for floating point types"; + + let description = [{ + Precise div for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Most significant N bits of the 2N-bit product of two integers"; + + let description = [{ + Most significant N bits of the 2N-bit product of two integers. 
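+    For example, with two i32 operands the result holds bits [63:32] of the
+    64-bit product.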
+ }]; + + let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); + + let results = (outs TT_IntLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +// +// Pointer Arith Ops +// +def TT_AddPtrOp : TT_Op<"addptr", + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; +} + +def TT_AdvanceOp : TT_Op<"advance", + [Pure, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let summary = "Advance a tensor pointer by offsets"; + + let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); + + let results = (outs TT_TensorPtr:$result); + + let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let hasFolder = 1; +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [ + SameLoadStoreOperandsAndResultShape, + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">, + TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Load from a tensor of pointers or from a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + Optional:$other, + + DefaultValuedAttr{}">:$boundaryCheck, + OptionalAttr:$padding, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let results = (outs TT_Type:$result); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; + + // Specify `cacheModifier` and `evictionPolicy` explicitly in the + // assemblyFormat instead of as part of attr-dict so that they get printed + // as strings rather than opaque integers. + // + // Note there's no comma between `other` and `cacheModifier` and between + // `cacheModifier` and `evictionPolicy`. This is due to an apparent + // limitation in the MLIR custom-format parser. 
In oilist, the initial + // keywords of each clause have to be unique, so they can't be `,`. + // + // Even if we gave up on order-independence and used vanilla optional + // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will + // not match the string ", bar = 0" because after the initial comma (first + // token of the first optional clause) we expect to see "foo". + let assemblyFormat = [{ + $ptr (`,` $mask^)? (`,` $other^)? + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +def TT_StoreOp : TT_Op<"store", [ + SameLoadStoreOperandsShape, + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value type matches ptr type", "ptr", "value", + "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Store by a tensor of pointers or by a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + TT_Type:$value, + Optional:$mask, + DefaultValuedAttr{}">:$boundaryCheck, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, + // A tensor pointer with boundary check + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)> + ]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between mask, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $ptr `,` $value (`,` $mask^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +// +// Atomic Ops +// +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, + TT_Type:$val, Optional:$mask, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on + // attr-dict so they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? 
attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $sem and $scope rather than relying on attr-dict so + // they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +// +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +def TT_ReshapeOp : TT_Op<"reshape", [Pure, + SameOperandsAndResultElementType]> { + let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; + let description = [{ + reinterpret a tensor to a different shape. + + If allow_reorder is set the compiler is free to change the order of + elements to generate more efficient code. + + If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. + The compiler is still free to change it for better performance. + }]; + let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. 
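+
+    For instance, the shapes above should print, with the assembly format
+    declared below, as:
+
+    ```mlir
+    %b = tt.broadcast %a : tensor<1x32x1xf32> -> tensor<2x32x4xf32>
+    ```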
+ }]; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +// cat is not `pure` because it may reorder elements +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_JoinOp : TT_Op<"join", [ + NoMemoryEffect, SameTypeOperands, + DeclareOpInterfaceMethods, +]> { + let summary = "join two tensors along a new, minor dimension"; + let description = [{ + For example, if the two input tensors are 4x8xf32, returns a tensor of + shape 4x8x2xf32. + + Because Triton tensors always have a power-of-two number of elements, + the two input tensors must have the same shape. + }]; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_SplitOp : TT_Op<"split", [ + NoMemoryEffect, + DeclareOpInterfaceMethods, + TypesMatchWith<"outLHS and outRHS types match", + "outLHS", "outRHS", "$_self">, +]> { + let summary = "splits a tensor into two, along its last dimension"; + let description = [{ + The input must be a tensor whose last dimension has size 2. Returns two + tensors, src[..., 0] and src[..., 1]. + + For example, if the input shape is 4x8x2xf32, returns two tensors of + shape 4x8xf32. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; +} + +def TT_TransOp : TT_Op<"trans", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + + let summary = "rearrange the dimensions of a tensor"; + let description = [{ + For example, given a tensor x with shape [1,2,4], transpose(x) with + order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + + Although this op is called "trans", it implements both tl.trans() and + tl.permute(). ("permute" might be a better name, but it's called "trans" + because originally it only supported 2D tensors.) + + ## Implementation note on encodings: + + In the TritonGPU dialect (and probably others), an encoding is chosen for + this op's output so it's a nop from the perspective of code generation. + + For example, suppose tensor x has an encoding such that GPU thread [i,j,k] + has a register containing element [i,j,k] of the tensor. Now we transpose + x with order [2,1,0], i.e. we reverse the order of its dimensions. In + TritonGPU, we will choose a layout for the output of the transpose so that + GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the + same element it had before! All we've done is "rename" the element that + thread [i,j,k] has. + + The "real" transpose -- i.e. moving data between GPU threads -- occurs in + convertLayout ops that appear before and/or after the operation. + + We do this so that you can chain multiple data-movement ops (e.g. + transpose+reshape+concat) without going to shared memory after each one. 
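+
+    For example, the order=[2,0,1] case above should print, with the assembly
+    format declared below (order goes through attr-dict), roughly as:
+
+    ```mlir
+    %t = tt.trans %x {order = array<i32: 2, 0, 1>} : tensor<1x2x4xf32> -> tensor<4x1x2xf32>
+    ```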
+ }]; + + let arguments = ( + ins TT_TensorOrMemDesc:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorOrMemDesc:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + + +// +// DotScaled Op +// +def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot_scaled"; + + let description = [{ + $d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c. + Where scale(x, s) is a function that applies the scale per block following microscaling spec. + }]; + + let arguments = ( + ins + // inputs are integer types as they are packed types and we currently + // don't have a representation for those. + TT_IntTensor:$lhs, + TT_IntTensor:$rhs, + TT_FloatTensor:$c, + TT_IntTensor:$lhs_scale, + Optional:$rhs_scale, + TT_F8F6F4TypeAttr:$lhs_type, + TT_F8F6F4TypeAttr:$rhs_type + ); + + let results = (outs TT_FloatTensor:$d); + + // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file + let assemblyFormat = [{ + $lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? 
$c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict + `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d) + }]; +} + +// +// Reduce Op +// +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsShape, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Scan Op +// +def TT_ScanOp: TT_Op<"scan", + [Pure, + SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Associative scan using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ScanReturnOp: TT_Op<"scan.return", + [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for scan operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + +// +// External Elementwise op +// +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods, + ConditionallySpeculatable]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + +} + +// +// Make Range Op +// +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. 
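+  // Example: start = 0, end = 128 yields a tensor<128xi32> holding [0, 1, ..., 127].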
+ let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ + Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods +]> { + let summary = "inline assembly applying an elementwise operation to a group of packed elements."; + let description = [{ + Runs an inline asm block to generate one or more tensors. + + The asm block is given `packed_element` elements at a time. Exactly which + elems it receives is unspecified. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; + + let hasVerifier = 1; +} + +// +// Histogram Op +// +def TT_HistogramOp : TT_Op<"histogram", [Pure]> { + let summary = "return a histgram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + + let arguments = (ins TT_IntTensor:$src); + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// +// Gather Op +// +def TT_GatherOp : TT_Op<"gather", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "local gather operation"; + let description = [{ + Gather elements from the input tensor using the indices tensor along a + single specified axis. The output tensor has the same shape as the indices + tensor. The input and indices tensors must have the same number of + dimension, and each dimension of the indices tensor that is not the gather + dimension cannot be greater than the corresponding dimension in the input + tensor. + }]; + + let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` attr-dict `:` + functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +// +// Print Op +// +def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrite]>]> { + let arguments = ( + ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$args, + DenseI32ArrayAttr:$isSigned + ); + let summary = "Device-side print, as in CUDA for debugging"; + let description = [{ + `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict (`:` $args^ `:` type($args))? + }]; +} + +// +// Assert Op +// +def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "Device-side assert, as in CUDA for correctness checking"; + let description = [{ + `tt.assert` takes a condition tensor and a message string. + If the condition is false, the message is printed, and the program is aborted. 
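+
+    Example (written with the assembly format declared below):
+
+    ```mlir
+    tt.assert %cond, "expected a non-empty tensor" : tensor<128xi1>
+    ```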
+ }]; + let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; +} + +// +// Make Tensor Pointer Op +// +def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", + [Pure, + SameVariadicOperandSize, + TypesMatchWith<"infer pointer type from the result type", + "result", "base", + "getPointerType(getElementTypeOfTensorPointerType($_self), getAddressSpace($_self))">]> { + let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; + + let description = [{ + `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a + pointer to the block tensor, e.g. returns a type of `tt.ptr>`. + }]; + + // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorPtr:$result); + + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly + // Add additional `[]` to increase readability and split variadic lists + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins + "Value":$base, + "ValueRange":$shape, + "ValueRange":$strides, + "ValueRange":$offsets, + "ArrayRef":$tensorShape, + "ArrayRef":$order + )> + ]; +} + +// The following ops, including `call`, `func`, and `return` are copied and modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +// We could revert it back once MLIR has a better inliner interface. +// +// Function Ops +// +def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `tt.call` operation represents a direct call to a function that is + within the same symbol scope as the call. The operands and result types of + the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". 
+ + Example: + + ```mlir + %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); + } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + + // Required by CallOpInterface. + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. 
+ tt.func @abort() + tt.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + tt.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `tt.return` operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. + + Example: + + ```mlir + tt.func @foo() : (i32, f8) { + ... 
+ tt.return %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$srcs); + + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]>]; + + let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; + let hasVerifier = 1; +} + + +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ + MemoryEffects<[MemRead]>]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc_ptr)) `->` type($result) + }]; +} + +def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ + MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc_ptr)) `,` type($src) + }]; +} + +def TT_ExperimentalTensormapCreateOp: TT_Op< + "experimental_tensormap_create", + [ + MemoryEffects<[MemRead, MemWrite]>, + AttrSizedOperandSegments, + ] +> { + let summary = "Create a new TMA descriptor on device"; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_PtrType:$global_address, + Variadic:$box_dim, + Variadic:$global_dim, + Variadic:$global_stride, + Variadic:$element_stride, + ConfinedAttr]>:$elem_type, + ConfinedAttr]>:$interleave_layout, + ConfinedAttr]>:$swizzle_mode, + ConfinedAttr]>:$fill_mode + ); + let extraClassDeclaration = [{ + int32_t getRank() { + return getBoxDim().size(); + } + }]; + let assemblyFormat = [{ + $desc_ptr `,` $global_address `,` + `[` $box_dim `]` `,` + `[` $global_dim `]` `,` + `[` $global_stride `]` `,` + `[` $element_stride `]` + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< + "experimental_tensormap_fenceproxy_acquire", + [MemoryEffects<[MemWrite]>] +> { + let summary = "Acquire fence on a tensormap object"; + let arguments = (ins TT_PtrType:$desc_ptr); + let assemblyFormat = [{ + $desc_ptr attr-dict `:` qualified(type($desc_ptr)) + }]; +} + + +#endif // Triton_OPS diff --git a/third_party/ascend/triton_patch/lib/CMakeLists.txt b/third_party/ascend/triton_patch/lib/CMakeLists.txt new file mode 100644 index 000000000..0ca0f41c5 --- /dev/null +++ b/third_party/ascend/triton_patch/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git 
a/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..5e601271e --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..3b7c3746a --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(Patched_TritonIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 000000000..dc2417712 --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,139 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/DialectImplementation.h" + +#include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TritonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. 
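+    // (The inlined tt.return becomes an unconditional branch to the caller's
+    // continuation block, forwarding the return operands as branch operands.)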
+ OpBuilder builder(op); + builder.create(op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } +}; + +struct TensorModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; + +} // namespace + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. 
+ addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..87aca769f --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,1092 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +namespace mlir { +namespace triton { + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) 
-> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) 
-> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult TransOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = cast(operands[0].getType()); + auto order = properties.as()->order.asArrayRef(); + SmallVector retShape = applyPermutation(argTy.getShape(), order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, order, retEncoding) + .failed()) { + return failure(); + } + } + if (auto memDescTy = dyn_cast(argTy)) { + inferredReturnTypes.push_back(MemDescType::get( + retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), + memDescTy.getMutableMemory())); + } else { + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +LogicalResult TransOp::verify() { + // Check that the op's `order` attribute is a permutation of the right length. 
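+  // Illustrative example only (not taken from upstream): for a rank-3
+  // operand, order = [2, 0, 1] passes the checks below, while order = [0, 1]
+  // (wrong size) or order = [0, 0, 2] (not a permutation) is rejected.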
+ auto srcTy = getSrc().getType(); + + ArrayRef order = getOrder(); + if (order.size() != srcTy.getRank()) { + return emitError("order must have the same size as the rank of the " + "operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc && retEnc); + Dialect &dialect = retEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. 
+ if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + auto accTy = getC().getType(); + auto retEnc = accTy.getEncoding(); + if (!retEnc) + return emitError("miss encoding of C operand"); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = cast(getType()); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis, + SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +void ReduceOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = cast(operands[i].getType()); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + } + + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 
operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. 
Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + state.addAttribute("reverse", builder.getBoolAttr(reverse)); + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + if (!isa(value)) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) + .failed()) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, 
retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, + op.getLoc()) + .failed()) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = rewriter.create(op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = rewriter.create( + broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- +template +LogicalResult canonicalizeViewOrBroadcast(OpType op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // view(view) -> view + if (auto parentView = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), + parentView->getOperands(), + parentView->getAttrs()); + return success(); + } + + // view(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (!op.getAllowReorder() || op.getEfficientLayout()) + return failure(); + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither 
does."); + } + + if (srcEnc && !getAllowReorder()) { + Attribute inferredDstEnc; + if (cast(&srcEnc.getDialect()) + ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, + dstTy.getShape(), inferredDstEnc, + getLoc()) + .failed()) { + return emitError("This reshape is impossible without reordering, but " + "reordering is not allowed. Try choosing a different " + "encoding for the input tensor (or allow reordering)."); + } + if (inferredDstEnc != dstEnc) { + return emitError("Expected result encoding ") + << inferredDstEnc << " but was " << dstEnc; + } + } + + return success(); +} + +//-- FpToFpOp -- +LogicalResult FpToFpOp::verify() { + auto dstType = getType().getElementType(); + auto srcType = getSrc().getType().getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BroadcastOp -- +LogicalResult BroadcastOp::canonicalize(BroadcastOp op, + PatternRewriter &rewriter) { + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +LogicalResult BroadcastOp::verify() { + auto src = getSrc(); + auto srcTensorType = cast(src.getType()); + auto srcShape = srcTensorType.getShape(); + auto result = getResult(); + auto resultTensorType = cast(result.getType()); + auto resultShape = resultTensorType.getShape(); + if (srcShape.size() != resultShape.size()) { + return emitError("rank of source must be same as rank of result"); + } + for (int i = 0; i < srcShape.size(); i++) { + if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { + return emitError("Different dimensions at index ") + << i << " between source and result. " + << "Broadcast requires the source dimension to be 1."; + } + } + return success(); +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) { + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, 1); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +//-- AdvanceOp -- +OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { + // advance(ptr, 0, 0) -> ptr + SmallVector rawOffsets = getOffsets(); + auto offsets = getConstantIntValues(rawOffsets); + if (!offsets.has_value()) + return {}; + for (int64_t offset : offsets.value()) + if (offset != 0) + return {}; + return getPtr(); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. 
+//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. 
+ const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- +LogicalResult +JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 2); + assert(operands[0].getType() == operands[1].getType()); + assert(isa(operands[0].getType())); + assert(isa(operands[1].getType())); + + Value lhs = operands[0]; + Value rhs = operands[1]; + auto srcTy = cast(lhs.getType()); + + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferJoinOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 1); + assert(isa(operands[0].getType())); + + Value src = operands[0]; + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? 
tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +// -- ExperimentalTensormapCreateOp -- +LogicalResult ExperimentalTensormapCreateOp::verify() { + auto rank = getBoxDim().size(); + if (getGlobalDim().size() != rank) { + return emitError("Rank mismatch for global dim. Got") + << getGlobalDim().size() << " but expected " << rank; + } + if (getGlobalStride().size() + 1 != rank) { + return emitError("Rank mismatch for global stride. Got") + << getGlobalStride().size() << " but expected " << rank - 1; + } + if (getElementStride().size() != rank) { + return emitError("Rank mismatch for element stride. Got") + << getElementStride().size() << " but expected " << rank; + } + return success(); +} + +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + if (getAxis() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } + for (int dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getAxis()) + continue; + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim << " must match the corresponding input dimension"; + } + } + + return success(); +} + +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. 
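+  // For example (illustrative, not from the source): gathering from an f32
+  // source using integer indices of shape 8x4 infers an 8x4 result of f32,
+  // reusing the indices' encoding.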
+ inferredReturnTypes.push_back( + RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), + indicesType.getEncoding())); + return success(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..b43a9b56c --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,239 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", 
but " << *op + << " has more than that"; + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto module = op->getParentOfType(); + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (!rankedTy) + return success(); + + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + auto layoutRank = blocked.getThreadsPerWarp().size(); + if (layoutRank != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << layoutRank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = + ttg::TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the module specifies " + << moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! 
+ operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 000000000..6e41e70a8 --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,197 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +static constexpr llvm::StringRef kMutableMemory = 
"mutable"; + +Type MemDescType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) + return Type(); + + // Parse the element type. + Type elementType; + if (parser.parseType(elementType)) + return Type(); + + Attribute encoding; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseAttribute(encoding)) + return Type(); + } + bool mutableMemory = false; + Attribute memorySpace; + if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseOptionalKeyword(kMutableMemory))) { + if (parser.parseAttribute(memorySpace)) + return Type(); + } else { + mutableMemory = true; + } + } + if (mutableMemory == false && succeeded(parser.parseOptionalComma())) { + if (parser.parseOptionalKeyword(kMutableMemory)) + return Type(); + mutableMemory = true; + } + if (parser.parseGreater()) + return Type(); + return MemDescType::get(parser.getContext(), dimensions, elementType, + encoding, memorySpace, mutableMemory); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + for (auto dim : getShape()) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + printer << ">"; +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i1Type, + tensorTy.getEncoding()); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto shape = tensorTy.getShape(); + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i32Type, + tensorTy.getEncoding()); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + auto shape = tensorTy.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerTypeToElement(Type type) { + Type elementType = getElementTypeOrSelf(type); + PointerType ptrType = PointerType::get(elementType, 1); + return ptrType; +} + +// upstream Triton only uses address space 1 for Pointer Type +Type getPointerType(Type type, int addressSpace) { + return PointerType::get(type, addressSpace); +} + +int getAddressSpace(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getAddressSpace(); + return 1; +} + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return 
isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/ascend/triton_patch/python/src/ir.cc b/third_party/ascend/triton_patch/python/src/ir.cc new file mode 100644 index 000000000..637e15c42 --- /dev/null +++ b/third_party/ascend/triton_patch/python/src/ir.cc @@ -0,0 +1,1771 @@ +#include +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/SourceMgr.h" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to 
create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .value("CV", CacheModifier::CV) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .value("HF32", InputPrecision::HF32) + .export_values(); + + py::enum_(m, "F8F6F4TY", py::module_local()) + .value("E4M3", F8F6F4Type::E4M3) + .value("E5M2", F8F6F4Type::E5M2) + .value("E2M3", F8F6F4Type::E2M3) + .value("E3M2", F8F6F4Type::E3M2) + .value("E2M1", F8F6F4Type::E2M1) + .export_values(); + + 
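+  // Illustrative usage from Python (assuming these bindings are exposed as a
+  // module named `ir`, which is an assumption, not something this file
+  // establishes):
+  //   ir.CACHE_MODIFIER.CG, ir.EVICTION_POLICY.EVICT_LAST, ir.ROUNDING_MODE.RTNE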
py::class_(m, "context", py::module_local()) + .def(py::init<>()) + .def("printOpOnDiagnostic", + [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) + .def("printStackTraceOnDiagnostic", + [](MLIRContext &self, bool v) { + self.printStackTraceOnDiagnostic(v); + }) + .def("disable_multithreading", + [](MLIRContext &self) { self.disableMultithreading(); }); + + py::class_(m, "source_mgr_diag", + py::module_local()) + .def(py::init()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + mlir::LLVM::registerInlinerInterface(registry); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + mlir::LLVM::registerInlinerInterface(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if 
(self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self->print(os, printingFlags); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }); + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + 
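As a sketch of how the OpState and operation bindings in this block are typically exercised from Python once a module exists (not code from this patch; it assumes the same triton._C.libtriton.ir import as above and uses parse_mlir_module and module.walk, which are bound just below; the file path is purely illustrative):

    from triton._C.libtriton import ir

    ctx = ir.context()
    ir.load_dialects(ctx)
    mod = ir.parse_mlir_module("kernel.ttir", ctx)   # illustrative input file

    def dump_op(op):
        # operation.get_name / get_num_operands / get_num_results as bound here
        print(op.get_name(), op.get_num_operands(), op.get_num_results())

    mod.walk(dump_op)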
.def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_bool_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::bool_(ret.getValue()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), + /*fileName=*/fileName, + /*op=*/self, /*flags=*/{}); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + if (arg_no >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + 
self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", + [](FuncOp &self) -> void { + // Check if the result of tl.advance is used + self.walk([&](AdvanceOp op) { + if (op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. " + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); + }) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value 
{ + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = 
module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, int start, int end) -> Value { + auto retType = RankedTensorType::get( + {end - start}, self.getBuilder().getI32Type()); + return self.create(retType, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + 
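Taken together, the function, block, and control-flow bindings registered above are enough to assemble a small module entirely from Python, which is how the patched code_generator.py later in this patch drives the builder. A minimal sketch under the same import assumption; the kernel name is made up:

    from triton._C.libtriton import ir

    ctx = ir.context()
    ir.load_dialects(ctx)
    builder = ir.builder(ctx)
    module = builder.create_module()

    fn_ty = builder.get_function_ty([], [])        # () -> ()
    fn = builder.get_or_insert_function(module, "empty_kernel", fn_ty, "public", False)
    module.push_back(fn)

    entry = fn.add_entry_block()
    builder.set_insertion_point_to_start(entry)
    builder.ret([])                                # return with no operands

    print(module.str())                            # dump the textual IR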
.def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_to_index", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getIndexType(), input); + }) + .def("create_index_to_si", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getI64Type(), input); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value 
&lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> 
Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, 
CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value desc_ptr, + std::vector &indices, Type type, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + return self.create( + type, desc_ptr, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value desc_ptr, Value value, + std::vector &indices) -> void { + self.create(desc_ptr, value, + indices); + }) + .def("create_tensormap_create", + [](TritonOpBuilder &self, Value desc_ptr, Value global_address, + std::vector box_dim, std::vector global_dim, + std::vector global_stride, + std::vector element_stride, int32_t elem_type, + int32_t interleave_layout, int32_t swizzle_mode, + int32_t fill_mode) { + self.create( + desc_ptr, global_address, box_dim, global_dim, global_stride, + element_stride, elem_type, interleave_layout, swizzle_mode, + fill_mode); + }) + .def("create_tensormap_fenceproxy_acquire", + [](TritonOpBuilder &self, Value desc_ptr) { + self.create(desc_ptr); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + auto argType = + cast(arg.getType()).getElementType(); + return self.create( + RankedTensorType::get(shape, argType), arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + retShape.insert(retShape.begin() + axis, 1); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. 
Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + .def("create_slice", + [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(retTy, ful, offsets, + sizes, strides); + return ret; + }) + .def("create_insert", + [](TritonOpBuilder &self, Value &ful, Value &sub, + std::vector &offs_vec, std::vector &sizs_vec, + std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(sub, ful, offsets, + sizes, strides); + return ret; + }) + // Implements tl.trans and tl.permute. 
+ .def("create_trans", + [](TritonOpBuilder &self, Value &arg, + std::vector &order) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + auto retShape = applyPermutation(argType.getShape(), order); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, order); + }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + auto argType = arg.getType(); + auto ret = self.createOrFold( + RankedTensorType::get(shape, argType), arg); + return ret; + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_dot_scaled", + [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, + F8F6F4Type lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, F8F6F4Type rhs_format, + mlir::Value &c) -> mlir::Value { + return self.create( + c.getType(), lhs, rhs, c, lhs_scale, + rhs_scale.value_or(Value()), lhs_format, rhs_format); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return 
self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_tanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values, + const std::vector &isSigned) -> void { + auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)); + self.create(prefixAttr, hex, values, isSigned); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + self.create(condition, messageAttr); + }) + .def("create_assume", + [](TritonOpBuilder &self, Value &condition) { + self.create(condition); + }) + .def("create_poison", + [](TritonOpBuilder &self, Type &type) -> Value { + return 
self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins) -> Value { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + }) + .def("create_gather", + [](TritonOpBuilder &self, Value src, Value indices, int axis) + -> Value { return self.create(src, indices, axis); }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) { + auto *context = self.getContext(); + bool haveDiagnostics = + ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + std::string funcToDump; + if (!haveDump) { + funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP"); + if (!funcToDump.empty()) + haveDump = true; + } + if (haveDiagnostics || haveDump) { + context->disableMultithreading(); + } + if (haveDiagnostics) { + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->getDiagEngine().registerHandler([](Diagnostic &diag) { + llvm::outs() << diag << "\n"; + return success(); + }); + } + if (haveDump) { + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto printAlways = [funcToDump](Pass *, Operation *op) -> bool { + if (funcToDump.empty()) + return true; + if (auto mod = dyn_cast(op)) { + return mod.lookupSymbol(funcToDump); + } + if (auto func = dyn_cast(op)) { + return SymbolTable::getSymbolName(func).getValue() == + funcToDump; + } + + return false; + }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), + printingFlags); + } + }) + .def("run", [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. 
+ storage.push_back(str.str()); + return storage.back().c_str(); + }); + + ::llvm::DebugFlag = true; + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/ascend/triton_patch/python/triton_patch/__init__.py b/third_party/ascend/triton_patch/python/triton_patch/__init__.py new file mode 100644 index 000000000..b2a5aa214 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/__init__.py @@ -0,0 +1,5 @@ +# import triton +# from .compiler.errors import MLIRCompilationError as ascend_MLIRCompilationError +# triton.compiler.errors.MLIRCompilationError = ascend_MLIRCompilationError +# from .language._utils import validate_block_shape as ascend_validate_block_shape +# triton.language._utils.validate_block_shape = ascend_validate_block_shape diff --git a/third_party/ascend/triton_patch/python/triton_patch/compiler/code_generator.py b/third_party/ascend/triton_patch/python/triton_patch/compiler/code_generator.py new file mode 100644 index 000000000..d9cb07d64 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/compiler/code_generator.py @@ -0,0 +1,1303 @@ +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from triton import language +from triton._C.libtriton import ir +from triton.language import constexpr, tensor, str_to_ty +from triton.language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value +from triton.runtime.jit import _normalize_ty, get_jit_fn_file_line +# ideally we wouldn't need any runtime component +from triton.runtime import JITFunction +from triton.compiler.errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + raise TypeError(f'Unsupported type {ty}') + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = 
f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_value(o: Any) -> bool: + return isinstance(o, _value) + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + return any(self.visit(s) for s in body) + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) is ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... 
+ return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], v.__name__) + else: + self.gscope[k] = v + + self.lscope = {} + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. 
+ self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if any([ + val is absent, name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITFunction), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + isinstance(val, language.dtype), # + self._is_constexpr_global(name), # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + self.visiting_arg_default_value, # + os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1" + ]): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. 
+ # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.semantic.to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + # A return op must always terminate the basic block, so we create a dead + # basic block in case there are any ops after the return. + post_ret_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(post_ret_block) + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults[::-1]): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i in range(len(arg_names)): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + + # Mark this argument as a pass-by-value TMA descriptor (nvidia) + if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): + self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) + + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = 
self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + + # finalize function + assert not self.builder.get_insertion_block().has_terminator() + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + self.builder.ret([ + self.builder.create_poison(ty.to_ir(self.builder)) + for ty in self.prototype.ret_types + if self.ret_type is not None + ]) + self.fn.finalize() + + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + if isinstance(node, ast.AnnAssign): + _names += [self.visit(node.target)] + else: + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): + value = language.semantic.to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return 
getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in sorted(then_defs.keys() & else_defs.keys()): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'Mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create basic-block after conditional + endif_block = self.builder.create_block() + # then terminator + 
self.builder.set_insertion_point_to_end(then_block) + assert not then_block.has_terminator(), f"{then_block}" + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + assert not else_block.has_terminator(), f"{else_block}" + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + for ty in ir_ret_types: + endif_block.add_argument(ty) + + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if contains_return: + if self.scf_stack: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + self.visit_if_top_level(cond, node) + else: + self.visit_if_scf(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + + active_block = node.body if cond else node.orelse + self.visit_compound_statement(active_block) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.semantic.to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + 
self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) is ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) is ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_builder=self.builder) + try: + return getattr(operand, fn)() + except AttributeError: + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def _verify_loop_carried_variable(self, name, loop_val, live_val): + assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the 
loop' + assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type' + assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ + f'Loop-carried variable {name} has initial type {live_val.type} '\ + f'but is re-assigned to {loop_val.type} in loop! '\ + f'Please make sure that the type stays consistent.' + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + loop_val = loop_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + # these are loop-carried values + names.append(name) + ret_types.append(loop_val.type) + init_args.append(live_val) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... 
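+            # Sketch of the mapping (details depend on the loop-carried values): a
+            # kernel loop such as
+            #     i = 0
+            #     while i < n:
+            #         i = i + 1
+            # lowers to an scf.while whose "before" region evaluates `i < n` and emits
+            # scf.condition, and whose "after" region (built next) holds the body.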
+ self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + loop_unroll_factor = None + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + loop_unroll_factor = iterator.loop_unroll_factor + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.semantic.to_tensor(lb, self.builder) + ub = language.semantic.to_tensor(ub, self.builder) + step = language.semantic.to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + 
iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_poison(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + loop_val = self.local_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + names.append(name) + init_args.append(live_val) + yields.append(loop_val) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if loop_unroll_factor is not None: + for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. 
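+            # The dry run above only served to discover the loop-carried names, e.g.
+            #     acc = 0.0
+            #     for k in range(0, K, BLOCK):
+            #         acc = acc + 1.0
+            # makes `acc` an iter-arg of the scf.for (names here are hypothetical);
+            # the body is now revisited for real against the block arguments below.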
+ self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = {} + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = get_jit_fn_file_line(fn) + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, + module_map=self.builder.module_map) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. 
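+                # i.e. the error is re-raised at self.cur_node (the call site in the
+                # caller), while `from e` keeps the callee's own traceback chained below.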
+ raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = {"_builder": self.builder} + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + return fn(*args, **extra_kwargs, **kws) + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. 
+ raise CompilationError(self.jit_fn.src, node, repr(e)) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs) and node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. 
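+                # `from None` suppresses the implicit exception context, so the
+                # user-facing message points at the kernel source captured in
+                # self.jit_fn.src rather than at interpreter internals.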
+ raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisibility_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = attrs.get_constants() + for k in new_constants: + if k in tys and tys[k] == "i1" and new_constants[k] == 1: + new_constants[k] = True + + new_attrs = attrs.filter_out_constants() + fn_attrs = new_attrs.get_fn_attrs() + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = get_jit_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) + generator.visit(fn.parse()) + + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/ascend/triton_patch/python/triton_patch/compiler/compiler.py b/third_party/ascend/triton_patch/python/triton_patch/compiler/compiler.py new file mode 100644 
index 000000000..e368a4b23 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/compiler/compiler.py @@ -0,0 +1,447 @@ +from __future__ import annotations +import hashlib +import json +from triton._C.libtriton import get_cache_invalidating_env_vars, ir +from pathlib import Path +import re +import functools +import os + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + tma = re.search(r'tt.nv_tma_desc = 1', x) + if tma is not None: + return 'nvTmaDesc' + x = re.sub(r' {[^}]+}', '', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + from triton.backends.compiler import AttrsDescriptor + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") + if self.constants is None: + self.constants = {} + else: + for k in self.constants.keys(): + if not isinstance(k, str): + raise TypeError("Constants keys must be string") + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. 
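+        # e.g. constants such as {0: 16, "BLOCK": 64} cannot be sorted directly
+        # (TypeError: '<' not supported between instances of 'str' and 'int');
+        # sorting the stringified pairs still gives a deterministic cache key.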
+ sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + from triton.compiler.code_generator import ast_to_ttir + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + from triton import __version__ + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + TRITON_PATH = os.path.dirname(TRITON_PATH) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1": + return + + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. 
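+    # (Matching below is by path suffix, so a frame from e.g. "my_project/ast.py"
+    # would be dropped as well.)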
+ BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + from triton.backends.compiler import GPUTarget + from triton.runtime.cache import get_cache_manager, get_dump_manager, get_override_manager + from triton.runtime.driver import driver + from .errors import MLIRCompilationError + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') + if (compile_speed_opt): + ttir_path = f"{file_name}.ttir" + if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): + # Already compile once but failed. So directly return + raise Exception("already failed once") + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. 
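+    # e.g. (hypothetical usage) compile("kernel.ttir") parses the file via IRSource
+    # and enters the pipeline at the stage *after* "ttir", so hand-written TTIR is
+    # not re-run through the ttir-producing stage.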
+ if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() + try: + module = src.make_ir(options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise + use_ir_loc = os.environ.get("USE_IR_LOC", None) + for ext, compile_ir in list(stages.items())[first_stage:]: + try: + next_module = compile_ir(module, metadata) + except Exception as e: + if (ext == "ttadapter"): + stage_name = "ConvertTritonIRToLinalgIR" + elif (ext == "npubin"): + stage_name = "ConvertLinalgRToBinary" + else: + stage_name = "MLIRCompile" + raise MLIRCompilationError(stage_name, e.stderr.decode('utf-8')) + ir_filename = f"{file_name}.{ext}" + if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # Compilation completed, disabling multithreading in context. + # This is needed to safely finalize threads pool inside context: if current process forks before + # python GC deletes context object, thread pool in child process will be invalid, which could + # lead to child crash or hang. + context.disable_multithreading() + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + from triton.backends import backends + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). 
There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class AsmDict(dict): + + def __missing__(self, key): + from triton.tools.disasm import get_sass + if key == "sass": + value = get_sass(self["cubin"]) + else: + raise KeyError("Unknown key: '%s'" % key) + + self[key] = value + return value + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + from triton.backends.compiler import GPUTarget + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = AsmDict({ + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + }) + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + from triton.runtime.errors import OutOfResources + from triton.runtime.driver import driver + if self.module is not None: + return + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + # on NPU, get_device_properties in fact does not use the device param + # but we still need to preserve it because triton defines the API + device = driver.active.get_current_device() + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + 
arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + from triton.runtime.driver import driver + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/third_party/ascend/triton_patch/python/triton_patch/compiler/errors.py b/third_party/ascend/triton_patch/python/triton_patch/compiler/errors.py new file mode 100644 index 000000000..23b6bdbf4 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/compiler/errors.py @@ -0,0 +1,72 @@ +import ast +from typing import Optional +from triton.errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass + + +class MLIRCompilationError(TritonError): + + def __init__(self, stage_name: Optional[str], message: Optional[str] = None): + self.stage_name = stage_name + self.message = f"\n" \ + f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \ + f"[{self.stage_name}] encounters error:\n" \ + f"{self.filter_message(message)}" \ + f"{self.format_line_delim('[ERROR][Triton][END]')}" + + def __str__(self): + return self.message + + def filter_message(self, message): + # Content starting from "Stack dump without symbol names" means nothing to the users + return message.split("Stack dump without symbol names")[0] + + def format_line_delim(self, keyword): + return f"///------------------{keyword}------------------\n" diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/__init__.py b/third_party/ascend/triton_patch/python/triton_patch/language/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/third_party/ascend/triton_patch/python/triton_patch/language/_utils.py b/third_party/ascend/triton_patch/python/triton_patch/language/_utils.py new file mode 100644 index 000000000..f83b18855 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/language/_utils.py @@ -0,0 +1,15 @@ +from typing import List + +TRITON_MAX_TENSOR_NUMEL = 1048576 + + +def validate_block_shape(shape: List[int]): + numel = 1 + for i, d in enumerate(shape): + if not isinstance(d, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + numel *= d + + if numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + return numel diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/core.py b/third_party/ascend/triton_patch/python/triton_patch/language/core.py new file mode 100644 index 000000000..a5cdf3e43 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/language/core.py @@ -0,0 +1,229 @@ +import os +from typing import List +from triton.language.core import _tensor_member_fn, builtin, _constexpr_to_value, tensor, constexpr +from triton.language.core import dtype as real_dtype +from triton.language import semantic as real_semantic +from triton._C.libtriton import ir +from triton.language.core import float32 +# from triton.language.core import _unwrap_if_constexpr, _unwrap_shape +from . import semantic +# from ._utils import validate_block_shape + +# class dtype(real_dtype): + +# def to_ir(self, builder: ir.builder) -> ir.type: +# if self.name in ("uint8", "uint16", "uint32", "uint64"): +# raise ValueError(f"type {self} not supported in this architecture for now.") + +# if self.name.startswith("fp8"): +# if self.name not in builder.options.supported_fp8_dtypes: +# raise ValueError(f'type {self} not supported in this architecture. 
' +# f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') +# if self.name in builder.options.deprecated_fp8_dtypes: +# warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release") + +# if self.name == 'void': +# return builder.get_void_ty() +# elif self.name == 'int1': +# return builder.get_int1_ty() +# elif self.name in ('int8', 'uint8'): +# return builder.get_int8_ty() +# elif self.name in ('int16', 'uint16'): +# return builder.get_int16_ty() +# elif self.name in ('int32', 'uint32'): +# return builder.get_int32_ty() +# elif self.name in ('int64', 'uint64'): +# return builder.get_int64_ty() +# elif self.name == 'fp8e5': +# return builder.get_fp8e5_ty() +# elif self.name == 'fp8e5b16': +# return builder.get_fp8e5b16_ty() +# elif self.name == 'fp8e4nv': +# return builder.get_fp8e4nv_ty() +# elif self.name == 'fp8e4b8': +# return builder.get_fp8e4b8_ty() +# elif self.name == 'fp8e4b15': +# return builder.get_fp8e4b15_ty() +# elif self.name == 'fp16': +# return builder.get_half_ty() +# elif self.name == 'bf16': +# return builder.get_bf16_ty() +# elif self.name == 'fp32': +# return builder.get_float_ty() +# elif self.name == 'fp64': +# return builder.get_double_ty() +# raise ValueError(f'fail to convert {self} to ir type') + +# class pointer_type(dtype): + +# def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False): +# element_ty = _unwrap_if_constexpr(element_ty) +# if not isinstance(element_ty, dtype): +# raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') +# self.element_ty = element_ty +# self.address_space = address_space +# self.const = const +# self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' + +# def to_ir(self, builder: ir.builder): +# return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) + +# def __str__(self): +# return self.name + +# def __repr__(self): +# return self.__str__() + +# def is_ptr(self): +# return True + +# def is_const(self): +# return self.const + +# def __eq__(self, other: pointer_type) -> bool: +# if not isinstance(other, pointer_type): +# return False +# return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const + +# def __ne__(self, other: pointer_type) -> bool: +# return not self.__eq__(other) + +# @property +# def scalar(self): +# return self + +# class block_type(dtype): + +# def __init__(self, element_ty: dtype, shape: List): +# self.element_ty = element_ty + +# # Note that block_type's shape is a list of int +# # while tensor's shape is a list of constexpr. + +# # shape can be empty ([]) when an input is a 0D tensor. 
+# self.shape = _unwrap_shape(shape) +# if not self.shape: +# raise TypeError('0d block_type is forbidden') + +# self.numel = validate_block_shape(self.shape) +# self.name = f'<{self.shape}, {self.element_ty}>' + +# def to_ir(self, builder: ir.builder) -> ir.block_type: +# return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + +# def __str__(self): +# return self.name + +# def __repr__(self): +# return self.__str__() + +# def is_block(self): +# return True + +# def get_block_shapes(self) -> List[int]: +# return self.shape + +# def __eq__(self, other: block_type) -> bool: +# if not isinstance(other, block_type): +# return False +# return self.element_ty == other.element_ty and self.shape == other.shape + +# def __ne__(self, other: block_type) -> bool: +# return not self.__eq__(other) + +# @property +# def scalar(self): +# return self.element_ty + +# class function_type(dtype): + +# def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: +# self.ret_types = ret_types +# self.param_types = param_types + +# def __str__(self): +# return f'fn ({self.param_types}) -> {self.ret_types}' + +# def to_ir(self, builder: ir.builder): +# ir_param_types = [ty.to_ir(builder) for ty in self.param_types] +# ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] +# return builder.get_function_ty(ir_param_types, ret_types) + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + assert not allow_tf32, "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + else: + assert (input_precision not in [ + "tf32", "tf32x3" + ]), "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." + input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +@_tensor_member_fn +@builtin +def gather(src, index, axis, _builder=None): + """Gather from a tensor along a given dimension. + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.gather(src, index, axis, _builder) + + +@_tensor_member_fn +@builtin +def insert(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to receive tensor. + :type ful: Tensor + :param sub: The tensor to be inserted. 
+ :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert (len(ful.shape) > 0) + assert (len(ful.shape) == len(sub.shape)) + new_offsets = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + out = semantic.insert(ful, sub, new_offsets, sizes, strides, _builder) + return out + + +@_tensor_member_fn +@builtin +def subview(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. + :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert (len(ful.shape) > 0) + new_offsets = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + sub = semantic.subview(ful, new_offsets, sizes, strides, _builder) + return sub diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/math.py b/third_party/ascend/triton_patch/python/triton_patch/language/math.py new file mode 100644 index 000000000..d381508fd --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/language/math.py @@ -0,0 +1,140 @@ +from triton.language import core +from triton.language.math import _check_dtype, _add_math_1arg_docstr, _add_math_2arg_docstr +from triton.language import semantic + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin 
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") +def div_rn(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def tanh(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_tanh(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/semantic.py b/third_party/ascend/triton_patch/python/triton_patch/language/semantic.py new file mode 100644 index 000000000..1d2db8c83 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/language/semantic.py @@ -0,0 +1,270 @@ +from typing import List, Optional, Union +import numbers +import triton.language as tl +from triton._C.libtriton import ir +from triton.language.semantic import wrap_tensor, _str_to_rounding_mode, not_equal, _str_to_dot_input_precision, binary_op_type_checking_impl, integer_promote_impl + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + # if (range & (range - 1)) != 0: + # raise ValueError("arange's range must be a power of 2") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return 
tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + raise ValueError("[fp8, fp64] is unsupported on Ascend for now." + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), 
input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + # All combinations of supported fp8 x fp8 are permitted + pass + else: + assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported lhs dtype {lhs.dtype}" + assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" + + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." 
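+    # The backend reports its smallest supported dot shape as a
+    # (min_M, min_N, min_K) tuple; the assertion below checks lhs (.., M, K) and
+    # rhs (.., K, N) against it. For example, a hypothetical backend returning
+    # (16, 16, 32) would reject an [8, 32] x [32, 16] matmul because M == 8 < 16
+    # (illustrative values only).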
+ min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) + assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ + and rhs.shape[-1].value >= min_dot_size[1], \ + f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + K = lhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): + if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): + raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") + + if max_num_imprecise_acc is not None: + tl.static_print("max_num_imprecise_acc is not supported on Ascend yet. Thus it is ignored.") + max_num_imprecise_acc = 0 + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +# Use Union instead of |. Becase python 3.9 does not support |. 
+# It will reports error: TypeError: unsupported operand type(s) for |: 'type' and 'ABCMeta' +def floordiv(input: Union[tl.tensor, numbers.Number], other: Union[tl.tensor, numbers.Number], + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_bool() or other_scalar_ty.is_bool(): + raise TypeError(f"unexpected type {input_scalar_ty}") + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + assert index.dtype.is_int(), "index must be an integer tensor" + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = builder.create_gather(src.handle, index.handle, axis) + return wrap_tensor(gather, src.type.scalar, index.type.shape) + + +def insert(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tl.tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_insert(ful.handle, sub.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + + +def subview(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tl.tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_slice(ful.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) diff --git a/third_party/ascend/triton_patch/python/triton_patch/language/standard.py b/third_party/ascend/triton_patch/python/triton_patch/language/standard.py new file mode 100644 index 000000000..83e318119 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/language/standard.py @@ -0,0 +1,18 @@ +from triton.language import core +from triton.runtime.jit import jit + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_print("tl.flip is unsupported for now. 
Use libdevice.flip instead.") + core.static_assert(False) + return x diff --git a/third_party/ascend/triton_patch/python/triton_patch/patch.py b/third_party/ascend/triton_patch/python/triton_patch/patch.py new file mode 100644 index 000000000..4eb332414 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/patch.py @@ -0,0 +1,11 @@ +import sys +import os +from importlib.util import spec_from_file_location, module_from_spec + +triton_root = os.path.dirname(__file__) +if triton_root not in sys.path: + sys.path.append(triton_root) +triton_patch_init_path = os.path.join(triton_root, "triton_patch/__init__.py") +spec = spec_from_file_location("triton_patch", triton_patch_init_path) +module = module_from_spec(spec) +spec.loader.exec_module(module) diff --git a/third_party/ascend/triton_patch/python/triton_patch/runtime/autotuner.py b/third_party/ascend/triton_patch/python/triton_patch/runtime/autotuner.py new file mode 100644 index 000000000..2c41bcc46 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/runtime/autotuner.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +import builtins +import os +import time +import inspect +from typing import Dict + +from .jit import KernelInterface + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=None, + rep=None, + use_cuda_graph=False, + do_bench=None, + ): + from triton.runtime.driver import driver + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
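+        Example (illustrative; my_perf_model and my_prune_fn are placeholder callables):
+            prune_configs_by={'perf_model': my_perf_model, 'top_k': 8, 'early_config_prune': my_prune_fn}
+        Note that this implementation reads the 'early_config_prune' key for the optional pruning function.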
+ """ + if not configs: + self.configs = [ + Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) + ] + + else: + self.configs = configs + self.keys = key + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_to_zero = [] + if reset_to_zero is not None: + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] + if restore_value is not None: + self.restore_value = list(restore_value) + + # Hook to reset or restore for required tensors + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False + if pre_hook: + self.pre_hook = pre_hook + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): + + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() + if not reset_only: + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: + + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + + self.num_warmups = warmup + self.num_reps = rep + self.use_cuda_graph = use_cuda_graph + + # If we got explicitly called via the old interface, raise a warning + # and proceed with the old behavior. + if warmup is not None or rep is not None or use_cuda_graph: + import warnings + warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " + "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, + stacklevel=1) + if use_cuda_graph: + from triton.testing import do_bench_cudagraph + self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + kernel_call, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + import triton.testing + self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + kernel_call, + warmup=warmup if warmup is not None else 25, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + if do_bench is None: + self.do_bench = driver.active.get_benchmarker() + else: + self.do_bench = do_bench + + def _bench(self, *args, config, **meta): + from triton.runtime.errors import OutOfResources + from triton.compiler.errors import CompileTimeAssertionFailure + from ..compiler.errors import MLIRCompilationError + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
+ " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(full_nargs, exception=None) + + try: + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError) as e: + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. 
+ :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.num_buffers_warp_spec = num_buffers_warp_spec + self.num_consumer_groups = num_consumer_groups + self.reg_dec_producer = reg_dec_producer + self.reg_inc_consumer = reg_inc_consumer + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("num_buffers_warp_spec", self.num_buffers_warp_spec), + ("num_consumer_groups", self.num_consumer_groups), + ("reg_dec_producer", self.reg_dec_producer), + ("reg_inc_consumer", self.reg_inc_consumer), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}") + res.append(f"num_consumer_groups: {self.num_consumer_groups}") + res.append(f"reg_dec_producer: {self.reg_dec_producer}") + res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=None, rep=None, use_cuda_graph=False, do_bench=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 
+ :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). + :type warmup: int + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). + :type rep: int + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. 
+ :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/ascend/triton_patch/python/triton_patch/runtime/jit.py b/third_party/ascend/triton_patch/python/triton_patch/runtime/jit.py new file mode 100644 index 000000000..1b79635e3 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/runtime/jit.py @@ -0,0 +1,952 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from types import ModuleType + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def _is_triton_builtin(self, node, func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + def _update_hash(self, func): + if isinstance(func, JITFunction): + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & func.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = func.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." 
+ ) + self.used_global_vals.update(func.used_global_vals) + # update hash + func_key = func.cache_key + func_key += str(getattr(func, "noinline", False)) + self.hasher.update(func_key.encode("utf-8")) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) is not ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + self._update_hash(val) + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + ret = getattr(lhs, node.attr) + self._update_hash(ret) + return ret + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. 
You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, + do_not_specialize_on_alignment: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v, align): + + if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. + if align and (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif hasattr(arg, "tma_desc_cpu_ptr"): + return "nvTmaDesc" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. 
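+        For example (illustrative kernel and argument names):
+            add_kernel[(triton.cdiv(n, BLOCK), )](x, y, out, n, BLOCK=BLOCK)
+        first captures the grid tuple, then forwards the call to self.run with warmup=False.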
+ """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams, backend): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + if not kp.do_not_specialize_on_alignment: + specialisations.append('compute_spec_key(%s, align=True)' % name) + else: + specialisations.append('compute_spec_key(%s, align=False)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = backend.compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + 
"uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + # Hook to signal that a kernel is done compiling and inspect compiled function. + # cache_hook will always be called before compilation and compiled_hook after. + compiled_hook = None + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + is_warmup, + before, + ): + hook = JITFunction.cache_hook if before else JITFunction.compiled_hook + if hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + 'is_warmup': is_warmup, + } + + return hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=is_warmup, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self, backend): + """ + Precompute as much as possible. 
+ """ + from ..compiler.compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params, backend) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + from triton.runtime.driver import driver + kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" + + # parse options + from ..compiler.compiler import make_backend + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + target = driver.active.get_current_target() + backend = make_backend(target) + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder(backend) + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + kernel = self.cache[device].get(key, None) + + if kernel is None: + # Kernel is not cached; we have to compile. + options = backend.parse_options(kwargs) + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) + constant_params = configs[0].get_constants() + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or (p.num in constant_params) or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) + + # Check that used global values have not changed. 
+ not_present = object() + for (name, _), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, + noinline=None, repr=None, launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = i in do_not_specialize or param.name in do_not_specialize + dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment + self.params.append(KernelParam(i, param, dns, dns_oa)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. 
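As the launch path above shows, `grid` may be either a plain tuple or a callable that receives the dict of bound arguments (constexprs included) right before launch. A minimal usage sketch, assuming a hypothetical element-wise kernel and an available device:

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src, dst, n, BLOCK: tl.constexpr):  # hypothetical kernel, not from this patch
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)

# Static grid: a tuple of up to three dimensions.
#   copy_kernel[(triton.cdiv(n, 1024),)](src, dst, n, BLOCK=1024)
# Callable grid: run() calls it with the bound-argument dict before launching.
#   copy_kernel[lambda args: (triton.cdiv(args["n"], args["BLOCK"]),)](src, dst, n, BLOCK=1024)
```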
+ self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler.compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor + import json + import triton.language as tl + from triton.runtime.driver import driver + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... 
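The `preload` method above is the counterpart of `serialize_specialization_data`: the JSON captured at compile time can be fed back later to rebuild a cache entry without re-tracing. A hedged sketch using the class-level `cache_hook`; how the patched `JITFunction` class is ultimately exposed for import is an assumption here.

```python
captured = {}

def capture_spec(*_, **kwargs):
    # `compile` carries the kwargs assembled in _call_hook, including the JSON
    # produced by serialize_specialization_data.
    captured["spec"] = kwargs["compile"]["specialization_data"]
    return False  # a falsy return lets the actual compilation proceed

# Assumes JITFunction is imported from wherever this patched module is wired in.
JITFunction.cache_hook = capture_spec
# ... launch the kernel once so it compiles and the hook fires ...

# Later (for example in another process), skip re-tracing:
#   my_kernel.preload(captured["spec"])
```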
+ + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from triton.runtime.interpreter import InterpretedFunction + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, i): + return self.base.stride(i) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def clone(self): + return TensorWrapper(self.base.clone(), self.dtype) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. 
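`MockTensor` above allows a kernel to be compiled ahead of time without allocating real device memory: `warmup` wraps plain `torch.dtype` objects into fake tensors whose `data_ptr()` is 0 and runs the normal compile path with `warmup=True`. A minimal sketch, reusing the hypothetical `copy_kernel` from the earlier example; a working driver is still required, since `run()` queries the active device and target even when warming up.

```python
import torch

# Compile, but do not launch, copy_kernel for fp32 pointers and a 32-bit length.
compiled = copy_kernel.warmup(
    torch.float32,   # src: wrapped into MockTensor(torch.float32)
    torch.float32,   # dst
    1024,            # n: specialized like a normal i32 argument
    BLOCK=256,       # constexpr
    grid=(1,),       # grid is required but nothing is launched
)
```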
+ return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line diff --git a/third_party/ascend/triton_patch/python/triton_patch/testing.py b/third_party/ascend/triton_patch/python/triton_patch/testing.py new file mode 100644 index 000000000..9d68310d9 --- /dev/null +++ b/third_party/ascend/triton_patch/python/triton_patch/testing.py @@ -0,0 +1,570 @@ +import functools +import os +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +def _summarize_statistics(times, quantiles, return_mode): + import torch + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times.tolist() + return getattr(torch, return_mode)(times).item() + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". + :type return_mode: str + """ + import torch + assert return_mode in ["min", "max", "mean", "median", "all"] + + with torch.cuda.stream(torch.cuda.Stream()): + # warmup + fn() + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, + # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 + # cache flush). 
+ start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for _ in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + import torch + from triton import runtime + + enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() in ('npu') + if torch.npu.is_available() and enable_bench_npu: + return do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) + + di = runtime.driver.active.get_device_interface() + + fn() + di.synchronize() + + cache = runtime.driver.active.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. 
So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + return _summarize_statistics(times, quantiles, return_mode) + + +def collect_files(base_dir): + import pandas as pd + for root, dirs, files in os.walk(base_dir): + for file in files: + if file != 'op_statistic.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] + if not triton_rows.empty: + return triton_rows['Avg Time(us)'].values[0] + return float('inf') + return float('inf') + + +def do_bench_npu(fn, warmup=5, active=30): + import torch + import torch_npu + import hashlib + from datetime import datetime + + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) + skip_first = 1 + wait = 0 + repeat = 1 + total = skip_first + (wait + warmup + active) * repeat + md5_hash = hashlib.md5(datetime.now().strftime('%Y-%m-%d').encode('utf-8')).hexdigest() + torch_path = "./profile_result/" + md5_hash + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.NPU], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, + skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), record_shapes=False, + profile_memory=False, with_stack=False, with_flops=False, with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + + for i in range(total): + fn() + prof.step() + stream.synchronize() + + time = collect_files(torch_path) + + import shutil + import os + if os.path.exists(torch_path): + shutil.rmtree(torch_path) + # TODO: use logging + # print("avg time = ", time, type(time)) + return time + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + """ + Asserts that two inputs are close within a certain tolerance. + + :param x: The first input. + :type x: scala, list, numpy.ndarray, or torch.Tensor + :param y: The second input. + :type y: scala, list, numpy.ndarray, or torch.Tensor + :param atol: The absolute tolerance. Default value is 1e-2. + :type atol: float, optional + :param rtol: The relative tolerance. Default value is 0. + :type rtol: float, optional + :param err_msg: The error message to use if the assertion fails. + :type err_msg: str + """ + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. 
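A short usage sketch of the benchmarking helpers above; sizes and the device string are illustrative and assume `torch_npu` is installed. Setting `TRITON_BENCH_METHOD=npu` reroutes `do_bench` to the profiler-based `do_bench_npu` path when `torch.npu` is available; otherwise the event-timing path shown above is used.

```python
import torch
import torch_npu  # noqa: F401  (assumed available; registers the "npu" device)
from triton.testing import do_bench  # resolves to this module once post_install has patched sys.modules

x = torch.randn(4096, 4096, device="npu")
y = torch.randn(4096, 4096, device="npu")

ms_mean = do_bench(lambda: x @ y)                               # default return_mode="mean"
p20, p50, p80 = do_bench(lambda: x @ y, quantiles=[0.2, 0.5, 0.8])
```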
+ rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. + :type styles: list[tuple[str, str]] + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. 
+ if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. 
+ :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from triton.runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + import triton.language as tl + from triton.runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + 
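A minimal sketch of driving the `Benchmark`/`perf_report` API defined above; shapes, dtypes, and the output directory are arbitrary, and `pandas` plus `matplotlib` are needed once `.run()` is called.

```python
import torch
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[2048, 4096, 8192],
        line_arg="dtype",
        line_vals=[torch.float16, torch.float32],
        line_names=["fp16", "fp32"],
        plot_name="matmul-dtype",
        args={},
        ylabel="ms",
    ))
def bench_matmul(N, dtype):
    a = torch.randn(N, N, device="npu", dtype=dtype)
    b = torch.randn(N, N, device="npu", dtype=dtype)
    return triton.testing.do_bench(lambda: a @ b)

# bench_matmul.run(print_data=True, save_path="./bench_out")  # CSV/PNG written by Mark._run
```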
from triton.runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/ascend/utils.py b/third_party/ascend/utils.py new file mode 100644 index 000000000..1b92e492f --- /dev/null +++ b/third_party/ascend/utils.py @@ -0,0 +1,152 @@ +import os +import shutil + + +def insert_at_file_start(filepath, import_lines): + import tempfile + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_file.write(import_lines + '\n\n') + with open(filepath, 'r') as original_file: + tmp_file.write(original_file.read()) + backup_path = filepath + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def append_at_file_end(filepath, import_lines): + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with open(filepath, 'a', encoding='utf-8') as f: + f.write('\n' + import_lines) + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def post_install(): + import site + install_dir = site.getsitepackages() + install_dir = os.path.join(install_dir, "triton") + init_path = os.path.join(install_dir, "__init__.py") + patched_content = f""" +import sys +from .triton_patch.language import _utils as ascend_utils +sys.modules['triton.language._utils'] = ascend_utils +from .triton_patch.compiler import compiler as ascend_compiler +sys.modules['triton.compiler.compiler'] = ascend_compiler +from .triton_patch.compiler import code_generator as ascend_code_generator +sys.modules['triton.compiler.code_generator'] = ascend_code_generator +from .triton_patch.compiler import errors as ascend_errors +sys.modules['triton.compiler.errors'] = ascend_errors +from .triton_patch.runtime import autotuner as ascend_autotuner +sys.modules['triton.runtime.autotuner'] = ascend_autotuner +from .triton_patch import testing as ascend_testing +sys.modules['triton.testing'] = ascend_testing +""" + insert_at_file_start(init_path, patched_content) + + content_to_append = f""" +from .triton_patch.language.core import dot, gather, insert, subview +from .triton_patch.language.standard import flip +from .triton_patch.language.math import umulhi, exp, exp2, log, log2, cos, sin, sqrt, sqrt_rn, rsqrt, div_rn, erf, 
tanh, floor, ceil +from . import language + +language.dot = dot +language.flip = flip +language.gather = gather +language.insert = insert +language.subview = subview + +# from .triton_patch.language.core import dtype, pointer_type, block_type, function_type +# language.core.dtype = dtype +# language.core.pointer_type = pointer_type +# language.core.block_type = block_type +# language.core.function_type = function_type + +from .triton_patch.language.semantic import arange, floordiv +language.semantic.arange = arange +language.semantic.floordiv = floordiv + +language.umulhi = umulhi +language.exp = exp +language.exp2 = exp2 +language.log = log +language.log2 = log2 +language.cos = cos +language.sin = sin +language.sqrt = sqrt +language.sqrt_rn = sqrt_rn +language.rsqrt = rsqrt +language.div_rn = div_rn +language.erf = erf +language.tanh = tanh +language.floor = floor +language.ceil = ceil +language.math.umulhi = umulhi +language.math.exp = exp +language.math.exp2 = exp2 +language.math.log = log +language.math.log2 = log2 +language.math.cos = cos +language.math.sin = sin +language.math.sqrt = sqrt +language.math.sqrt_rn = sqrt_rn +language.math.rsqrt = rsqrt +language.math.div_rn = div_rn +language.math.erf = erf +language.math.tanh = tanh +language.math.floor = floor +language.math.ceil = ceil +""" + append_at_file_end(init_path, content_to_append) + + +def get_ascend_patch_packages(backends): + packages = [] + # packages += get_language_extra_packages() + packages += [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + return packages + + +def get_ascend_patch_package_dir(backends): + package_dir = {} + # language_extra_list = get_language_extra_packages() + # for extra_full in language_extra_list: + # extra_name = extra_full.replace("triton/language/extra/", "") + # package_dir[extra_full] = f"{triton_root_rel_dir}/language/extra/{extra_name}" + # + triton_patch_root_rel_dir = "triton_patch/python/triton_patch" + package_dir["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dir["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dir["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dir["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dir From 265c0c0605984de3ad3af5296372c09f147d8b11 Mon Sep 17 00:00:00 2001 From: i3wanna2 <2535184404@qq.com> Date: Fri, 23 May 2025 09:11:45 +0000 Subject: [PATCH 02/14] [add triton-adapter-opt] --- .../tools/triton-adapter-opt/CMakeLists.txt | 18 +++++++++++++ .../triton-adapter-opt/triton-adapter-opt.cpp | 25 +++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100755 third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt create mode 100644 third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp diff --git a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt new file mode 100755 index 000000000..696b46f41 --- /dev/null +++ b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt @@ -0,0 +1,18 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + +add_llvm_executable(triton-adapter-opt triton-adapter-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this? 
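The `post_install` helper above relies on Python consulting `sys.modules` before touching the package on disk, so registering the patched submodules at the top of `triton/__init__.py` is enough to redirect every existing import site. A stripped-down illustration of the mechanism; the module object here is a stand-in, not the real patched compiler.

```python
import sys
import types

# Stand-in for one of the aliases registered by the patched triton/__init__.py,
# e.g. sys.modules['triton.compiler.compiler'] = ascend_compiler.
patched = types.ModuleType("triton.compiler.compiler")
patched.origin = "ascend triton_patch"
sys.modules["triton.compiler.compiler"] = patched

# From now on, `from triton.compiler.compiler import ...` in unmodified call sites
# resolves to the patched module, because the sys.modules entry wins over the
# on-disk triton/compiler/compiler.py.
assert sys.modules["triton.compiler.compiler"].origin == "ascend triton_patch"
```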
+llvm_update_compile_flags(triton-adapter-opt) +target_link_libraries(triton-adapter-opt PRIVATE TritonToLinalg + TritonTransforms + ${dialect_libs} + ${conversion_libs} + TritonGPUTransforms + MLIROptLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-adapter-opt) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp new file mode 100644 index 000000000..ba9c185e2 --- /dev/null +++ b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp @@ -0,0 +1,25 @@ +#include "../../include/TritonToLinalg/Passes.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + mlir::triton::registerTritonToLinalgPass(); + + registry.insert< + mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, + mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, + mlir::linalg::LinalgDialect, mlir::func::FuncDialect, + mlir::tensor::TensorDialect, mlir::memref::MemRefDialect, + mlir::bufferization::BufferizationDialect, mlir::gpu::GPUDialect>(); + + return mlir::asMainReturnCode( + mlir::MlirOptMain(argc, argv, "Triton-Adapter test driver\n", registry)); +} From 95ff6ce847d20e785a49712937acdbfc87d89112 Mon Sep 17 00:00:00 2001 From: i3wanna2 <2535184404@qq.com> Date: Fri, 23 May 2025 09:42:25 +0000 Subject: [PATCH 03/14] update CMakeLists.txt --- CMakeLists.txt | 331 +----------------- .../tools/triton-adapter-opt/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 331 deletions(-) mode change 100755 => 100644 third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 031cfa206..8421086a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -155,7 +155,7 @@ function(add_triton_object name) target_link_libraries(${name} PUBLIC ${patched_link_libs}) endif() else() - add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) + #add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) if(ARG_DEPENDS) add_dependencies(${name} ${ARG_DEPENDS}) endif() @@ -479,332 +479,3 @@ endif() if(TRITON_BUILD_UT) add_subdirectory(unittest) endif() - - -# cmake_minimum_required(VERSION 3.18) - -# if(POLICY CMP0116) -# # Introduced in cmake 3.20 -# # https://cmake.org/cmake/help/latest/policy/CMP0116.html -# cmake_policy(SET CMP0116 OLD) -# endif() - -# include(ExternalProject) - -# set(CMAKE_CXX_STANDARD 17) - -# set(CMAKE_INCLUDE_CURRENT_DIR ON) - -# project(triton) -# include(CTest) - -# set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") -# set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch") -# set(PATCHED_TRITON_LIBRARIES -# "TritonIR" -# ) -# set(PATCHED_TRITON_DEPENDS -# "TritonTableGen" -# ) - -# if(NOT WIN32) -# list(APPEND CMAKE_MODULE_PATH "${TRITON_ROOT_DIR}/cmake") -# endif() - -# # Options -# option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" OFF) -# option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON) -# option(TRITON_BUILD_PROTON "Build the Triton Proton 
profiler" OFF) -# option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" OFF) -# set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") - -# # Ensure Python3 vars are set correctly -# # used conditionally in this file and by lit tests - -# # Customized release build type with assertions: TritonRelBuildWithAsserts -# set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -# set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -# set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") -# set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") - -# # Default build type -# if(NOT CMAKE_BUILD_TYPE) -# message(STATUS "Default build type: Release") -# set(CMAKE_BUILD_TYPE "Release") -# endif() - -# if(NOT WIN32) -# find_library(TERMINFO_LIBRARY tinfo) -# endif() - -# # Compiler flags -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") - -# # ######### -# # LLVM -# # ######### -# if(NOT MLIR_DIR) -# set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir) -# endif() - -# # MLIR -# find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR}) - -# list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -# list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") - -# include(TableGen) # required by AddMLIR -# include(AddLLVM) -# include(AddMLIR) - -# # Utilities -# function(add_triton_object name) -# cmake_parse_arguments(ARG "" "" "DEPENDS;LINK_LIBS" ${ARGN}) -# add_library(${name} OBJECT) -# target_sources(${name} -# PRIVATE ${ARG_UNPARSED_ARGUMENTS} -# INTERFACE $ -# ) - -# set(patched_depends "") -# foreach(dep ${ARG_DEPENDS}) -# list(FIND PATCHED_TRITON_DEPENDS "${dep}" index) -# if(index GREATER_EQUAL 0) -# list(APPEND patched_depends "Patched_${dep}") -# message(STATUS "Replace ${dep} by Patched_${dep} as a dependent of ${name}") -# else() -# list(APPEND patched_depends ${dep}) -# endif() -# endforeach() -# if(patched_depends) -# add_dependencies(${name} ${patched_depends}) -# endif() - -# set(patched_link_libs "") -# foreach(lib ${ARG_LINK_LIBS}) -# list(FIND PATCHED_TRITON_LIBRARIES "${lib}" index) -# if(index GREATER_EQUAL 0) -# list(APPEND patched_link_libs "Patched_${lib}") -# message(STATUS "Replace ${lib} by Patched_${lib} to be linked by ${name}") -# else() -# list(APPEND patched_link_libs ${lib}) -# endif() -# endforeach() -# if(patched_link_libs) -# target_link_libraries(${name} PUBLIC ${patched_link_libs}) -# endif() - -# endfunction(add_triton_object) - -# set_property(GLOBAL PROPERTY TRITON_LIBS "") -# function(add_triton_library name) -# list(FIND PATCHED_TRITON_LIBRARIES "${name}" index) -# if(index GREATER_EQUAL 0) -# message(STATUS "Adding Patched_${name} as a lib, instead of ${name}") -# return() -# endif() -# set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name}) -# add_triton_object(${name} ${ARGN}) -# llvm_update_compile_flags(${name}) -# endfunction() - -# set_property(GLOBAL PROPERTY TRITON_PLUGINS "") -# function(add_triton_plugin name) -# set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name}) -# add_triton_object(${name} ${ARGN}) -# endfunction() - -# function(remove_component_from_property property_name component_to_remove) -# get_property(prop_value GLOBAL PROPERTY ${property_name}) -# string(REPLACE ";" ";" prop_list "${prop_value}") -# list(REMOVE_ITEM prop_list "${component_to_remove}") -# string(REPLACE ";" ";" modified_prop "${prop_list}") -# set_property(GLOBAL PROPERTY ${property_name} "${modified_prop}") -# message(STATUS "Removed '${component_to_remove}' from ${property_name}") -# endfunction() - -# # Disable warnings that show up in external 
code (gtest;pybind11) -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") - - -# include_directories(${TRITON_ROOT_DIR}) -# include_directories(${MLIR_INCLUDE_DIRS}) -# include_directories(${LLVM_INCLUDE_DIRS}) -# include_directories(${PATCHED_TRITON_ROOT_DIR}/include) -# include_directories(${PROJECT_BINARY_DIR}/third_party/ascend/triton_patch/include) # Tablegen'd files -# include_directories(${TRITON_ROOT_DIR}/include) -# include_directories(${PROJECT_BINARY_DIR}/triton/include) # Tablegen'd files -# include_directories(${PROJECT_SOURCE_DIR}/include) -# include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files -# # link_directories(${LLVM_LIBRARY_DIR}) -# add_subdirectory(${TRITON_ROOT_DIR}/include) -# add_subdirectory(${TRITON_ROOT_DIR}/lib) -# # remove_component_from_property(TRITON_LIBS "TritonIR") -# add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/include) -# add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib) - -# # find_package(PythonLibs REQUIRED) -# set(TRITON_SOURCE_DIR "${TRITON_ROOT_DIR}") -# set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") - -# # TODO: Figure out which target is sufficient to fix errors; triton is -# # apparently not enough. Currently set linking libstdc++fs for all targets -# # to support some old version GCC compilers like 8.3.0. -# if (NOT WIN32 AND NOT APPLE) -# link_libraries(stdc++fs) -# endif() - - -# # ----- - -# # ------ -# if(TRITON_BUILD_PYTHON_MODULE) -# message(STATUS "Adding Python module") -# set(PYTHON_SRC_PATH ${TRITON_ROOT_DIR}/python/src) -# set(PATCHED_PYTHON_SRC_PATH ${PATCHED_TRITON_ROOT_DIR}/python/src) -# include_directories(${PYTHON_SRC_PATH}) - -# if(PYTHON_INCLUDE_DIRS) -# # We have PYTHON_INCLUDE_DIRS set--this is what we expect when building -# # using pip install. -# include_directories(${PYTHON_INCLUDE_DIRS}) -# include_directories(${PYBIND11_INCLUDE_DIR}) -# else() -# # Otherwise, we might be building from top CMakeLists.txt directly. -# # Try to find Python and pybind11 packages. 
-# find_package(Python3 REQUIRED COMPONENTS Development Interpreter) -# find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") -# include_directories(${Python3_INCLUDE_DIRS}) -# include_directories(${pybind11_INCLUDE_DIR}) -# link_directories(${Python3_LIBRARY_DIRS}) -# link_libraries(${Python3_LIBRARIES}) -# add_link_options(${Python3_LINK_OPTIONS}) -# endif() - -# if (DEFINED TRITON_PLUGIN_DIRS) -# foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS}) -# # Read the plugin name under dir/backend/name.conf -# cmake_path(APPEND PLUGIN_DIR "backend" "name.conf" OUTPUT_VARIABLE PLUGIN_NAME_PATH) -# file(READ ${PLUGIN_NAME_PATH} PLUGIN_NAME) -# string(STRIP ${PLUGIN_NAME} PLUGIN_NAME) - -# list(APPEND TRITON_PLUGIN_NAMES ${PLUGIN_NAME}) - -# # Include the plugin as part of the build, placing the build output under -# # ${TRITON_BINARY_DIR}/third_party/${PLUGIN_NAME} -# # cmake_path(APPEND TRITON_BINARY_DIR "third_party" ${PLUGIN_NAME} OUTPUT_VARIABLE PLUGIN_DIR_BUILD_OUTPUT) -# message(STATUS "Building plugin '${PLUGIN_NAME}' from ${PLUGIN_DIR} with output ${PLUGIN_DIR_BUILD_OUTPUT}") -# add_subdirectory(${PLUGIN_DIR} ${PLUGIN_DIR_BUILD_OUTPUT}) -# endforeach() -# endif() - -# foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) -# add_subdirectory(third_party/${CODEGEN_BACKEND}) -# endforeach() - -# get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) -# get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) -# set(TRITON_LIBRARIES -# ${triton_libs} -# ${triton_plugins} - -# # mlir -# MLIRAMDGPUDialect -# MLIRNVVMDialect -# MLIRNVVMToLLVMIRTranslation -# MLIRGPUToNVVMTransforms -# MLIRGPUToGPURuntimeTransforms -# MLIRGPUTransforms -# MLIRIR -# MLIRControlFlowToLLVM -# MLIRBytecodeWriter -# MLIRPass -# MLIRTransforms -# MLIRLLVMDialect -# MLIRSupport -# MLIRTargetLLVMIRExport -# MLIRMathToLLVM -# MLIRROCDLToLLVMIRTranslation -# MLIRGPUDialect -# MLIRSCFToControlFlow -# MLIRIndexToLLVM -# MLIRGPUToROCDLTransforms -# MLIRUBToLLVM - -# # LLVM -# LLVMPasses -# LLVMNVPTXCodeGen -# # LLVMNVPTXAsmPrinter -# LLVMAMDGPUCodeGen -# LLVMAMDGPUAsmParser - -# ) -# if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 -# CMAKE_SYSTEM_PROCESSOR MATCHES "arm64" OR # macOS arm64 -# CMAKE_OSX_ARCHITECTURES MATCHES "arm64") # also macOS arm64 -# list(APPEND TRITON_LIBRARIES -# LLVMAArch64CodeGen -# LLVMAArch64AsmParser -# ) -# elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") -# list(APPEND TRITON_LIBRARIES -# LLVMX86CodeGen -# LLVMX86AsmParser -# ) -# elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le") -# list(APPEND TRITON_LIBRARIES -# LLVMPowerPCAsmParser -# LLVMPowerPCCodeGen -# ) -# else() -# message(FATAL_ERROR "LLVM codegen/ASM parser libs: This HW architecture (${CMAKE_SYSTEM_PROCESSOR}) is not configured in cmake lib dependencies.") -# endif() - -# # Define triton library -# string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS}) - -# if (DEFINED TRITON_PLUGIN_NAMES) -# string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_BACKENDS_TUPLE} ${TRITON_PLUGIN_NAMES}) -# endif() - -# message(STATUS "Triton backends tuple: ${TRITON_BACKENDS_TUPLE}") - -# set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") -# add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) -# add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc -# ${PATCHED_PYTHON_SRC_PATH}/ir.cc -# ${PYTHON_SRC_PATH}/passes.cc -# ${PYTHON_SRC_PATH}/interpreter.cc -# ${PYTHON_SRC_PATH}/llvm.cc) -# # Link triton with its dependencies -# target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) -# if(WIN32) -# 
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) -# else() -# target_link_libraries(triton PRIVATE z) -# endif() -# target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) -# endif() - -# if (UNIX AND NOT APPLE) -# set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") -# endif() - -# if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) -# set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") - -# # Check if the platform is MacOS -# if(APPLE) -# set(PYTHON_LDFLAGS "-undefined dynamic_lookup") -# endif() - -# target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS}) -# endif() - -# if(NOT TRITON_BUILD_PYTHON_MODULE) -# foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) -# add_subdirectory(third_party/${CODEGEN_BACKEND}) -# endforeach() -# endif() - -# add_subdirectory(${TRITON_ROOT_DIR}/third_party/f2reduce) diff --git a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt old mode 100755 new mode 100644 index 696b46f41..37fea14db --- a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt +++ b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt @@ -15,4 +15,4 @@ target_link_libraries(triton-adapter-opt PRIVATE TritonToLinalg MLIRTransforms ) -mlir_check_all_link_libraries(triton-adapter-opt) \ No newline at end of file +mlir_check_all_link_libraries(triton-adapter-opt) From 0ca4a0d3344b922108f3e31efe0f2a4499e74c82 Mon Sep 17 00:00:00 2001 From: i3wanna2 <2535184404@qq.com> Date: Mon, 26 May 2025 08:00:20 +0000 Subject: [PATCH 04/14] [Bug fix] fix import bug --- CMakeLists.txt | 7 ++++--- third_party/ascend/CMakeLists.txt | 2 +- third_party/ascend/backend/compiler.py | 2 +- third_party/ascend/backend/driver.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8421086a6..ae52b3a85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,7 +57,7 @@ endif() # Options option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) -option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) +option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON) if(FLAGTREE_BACKEND) option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" OFF) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" OFF) @@ -66,7 +66,6 @@ else() option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) endif() set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") - # Ensure Python3 vars are set correctly # used conditionally in this file and by lit tests @@ -423,6 +422,7 @@ if(TRITON_BUILD_PYTHON_MODULE) set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) + if(FLAGTREE_BACKEND STREQUAL "cambricon") add_library(triton SHARED) else() @@ -431,7 +431,8 @@ if(TRITON_BUILD_PYTHON_MODULE) else() set(IR_SRC ${PYTHON_SRC_PATH}/ir.cc) endif() - add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc + add_library(triton SHARED + ${PYTHON_SRC_PATH}/main.cc ${IR_SRC} ${PYTHON_SRC_PATH}/passes.cc ${PYTHON_SRC_PATH}/interpreter.cc diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt index 3c1ff4337..3321942c7 100644 --- a/third_party/ascend/CMakeLists.txt +++ b/third_party/ascend/CMakeLists.txt @@ -7,6 +7,6 @@ add_custom_target(COPY_TRITON_ADAPTER_OPT) add_custom_command(TARGET COPY_TRITON_ADAPTER_OPT POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $ - 
${TRITON_ROOT_DIR}/python/triton/backends/huawei/triton-adapter-opt + ${TRITON_ROOT_DIR}/python/triton/backends/ascend/triton-adapter-opt DEPENDS triton-adapter-opt) add_dependencies(TritonHUAWEI COPY_TRITON_ADAPTER_OPT) diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py index 226f0e78f..448a5ea3d 100644 --- a/third_party/ascend/backend/compiler.py +++ b/third_party/ascend/backend/compiler.py @@ -14,7 +14,7 @@ import ctypes from typing import Optional -from triton.backends.huawei.utils import downgrade_llir, _get_llvm_path, _get_mlir_path, _get_triton_adapter_opt_path, \ +from triton.backends.ascend.utils import downgrade_llir, _get_llvm_path, _get_mlir_path, _get_triton_adapter_opt_path, \ _get_kernel_target, _get_npucompiler_path, _is_ascend_sanitizer_enabled diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index 5cb59d411..30b807c7a 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -9,7 +9,7 @@ from triton.runtime.cache import get_cache_manager, get_dump_manager from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget -from triton.backends.huawei.utils import _build_npu_ext, _check_cxx11_abi +from triton.backends.ascend.utils import _build_npu_ext, _check_cxx11_abi class NPUUtils(object): From 19d7852aca277332706b6bd00e0841f4cbefdc0e Mon Sep 17 00:00:00 2001 From: i3wanna2 <2535184404@qq.com> Date: Mon, 26 May 2025 08:12:51 +0000 Subject: [PATCH 05/14] update --- third_party/ascend/triton_ascend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/ascend/triton_ascend.cpp b/third_party/ascend/triton_ascend.cpp index 9a83aa018..bb3cc70bd 100644 --- a/third_party/ascend/triton_ascend.cpp +++ b/third_party/ascend/triton_ascend.cpp @@ -6,6 +6,6 @@ namespace py = pybind11; // register huawei passes to triton -void init_triton_huawei(py::module &&m) { +void init_triton_ascend(py::module &&m) { // currently no extra modules needed to plug-in libtriton.so } From 89e93529365d8842084efd669817739782b5b821 Mon Sep 17 00:00:00 2001 From: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com> Date: Mon, 26 May 2025 16:22:17 +0800 Subject: [PATCH 06/14] Update README.md --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 8662caebc..e21a9d8a8 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,17 @@ export FLAGTREE_BACKEND=mthreads python3 -m pip install . --no-build-isolation -v ``` +```shell +# ascend +TRITON_BUILD_WITH_CLANG_LLD=true +LLVM_SYSPATH=yourpath/llvm-install +TRITON_BUILD_PROTON=OFF +TRITON_APPEND_CMAKE_ARGS="-DTRITON_BUILD_UT=OFF" +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=ascend +python3 -m pip install . --no-build-isolation -v +``` + To build with default backends (nvidia, amd, triton_shared): ```shell # manually download LLVM From 0812ea8f126470e8af994b5e119a2d18023d89b9 Mon Sep 17 00:00:00 2001 From: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com> Date: Mon, 26 May 2025 16:23:27 +0800 Subject: [PATCH 07/14] Update README.md --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e21a9d8a8..447d0f7e5 100644 --- a/README.md +++ b/README.md @@ -54,10 +54,8 @@ python3 -m pip install . 
--no-build-isolation -v ```shell # ascend -TRITON_BUILD_WITH_CLANG_LLD=true -LLVM_SYSPATH=yourpath/llvm-install -TRITON_BUILD_PROTON=OFF -TRITON_APPEND_CMAKE_ARGS="-DTRITON_BUILD_UT=OFF" +export TRITON_BUILD_WITH_CLANG_LLD=true +export LLVM_SYSPATH=yourpath/llvm-install cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=ascend python3 -m pip install . --no-build-isolation -v From 58a489804cb12e4f79c622c859539ae72de29a56 Mon Sep 17 00:00:00 2001 From: i3wanna2 <2535184404@qq.com> Date: Mon, 26 May 2025 09:06:49 +0000 Subject: [PATCH 08/14] update --- third_party/ascend/backend/compiler.py | 2 +- third_party/ascend/backend/cpu_driver.py | 2 +- third_party/ascend/backend/driver.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py index 448a5ea3d..dc13b8dad 100644 --- a/third_party/ascend/backend/compiler.py +++ b/third_party/ascend/backend/compiler.py @@ -271,7 +271,7 @@ def parse_options(self, opts) -> Any: return options def pack_metadata(self, metadata): - from triton.backends.huawei.utils import TRITON_PROFILER_REGISTERED + from triton.backends.ascend.utils import TRITON_PROFILER_REGISTERED # collect necessary metadata to launch kernels # TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 could set unique name. # Get this name as the kernel_name to CANN runtime. diff --git a/third_party/ascend/backend/cpu_driver.py b/third_party/ascend/backend/cpu_driver.py index 6bee24282..ec86ec01c 100644 --- a/third_party/ascend/backend/cpu_driver.py +++ b/third_party/ascend/backend/cpu_driver.py @@ -5,7 +5,7 @@ import sysconfig import subprocess import importlib -from triton.backends.huawei.utils import _get_llvm_path +from triton.backends.ascend.utils import _get_llvm_path # TODO: temporarily fake CPUUtils class diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index 30b807c7a..f00a1e3f2 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -86,7 +86,7 @@ def __init__(self, src, metadata): def __call__(self, *args, **kwargs): profiler_registered = self.launch(*args, **kwargs) import triton - triton.backends.huawei.utils.TRITON_PROFILER_REGISTERED = True if profiler_registered == 1 else False + triton.backends.ascend.utils.TRITON_PROFILER_REGISTERED = True if profiler_registered == 1 else False class NPUDriver(DriverBase): @@ -100,7 +100,7 @@ def __init__(self): def is_active(cls): def test_npucompiler(): - from triton.backends.huawei.utils import _get_bisheng_path + from triton.backends.ascend.utils import _get_bisheng_path npucompiler = _get_bisheng_path() targets = subprocess.check_output([npucompiler, "-print-targets"]).decode().strip().split() return "hiipu64" in targets From 21a64eb0459106d8106c6aeee895411d269c73e3 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Wed, 28 May 2025 11:06:50 +0800 Subject: [PATCH 09/14] [BUILD] Fix build ascend --- CMakeLists.txt | 3 +++ python/setup_helper.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ae52b3a85..41aeea456 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,9 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads") set(CMAKE_C_COMPILER clang) set(CMAKE_CXX_COMPILER clang++) set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND}) +elseif(FLAGTREE_BACKEND STREQUAL "ascend") + set(CMAKE_C_COMPILER clang) + set(CMAKE_CXX_COMPILER clang++) endif() set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}") if(FLAGTREE_PLUGIN) diff --git 
a/python/setup_helper.py b/python/setup_helper.py index 106f0ce56..187da8277 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -9,6 +9,7 @@ from pathlib import Path import hashlib from dataclasses import dataclass +from distutils.sysconfig import get_python_lib use_triton_shared = False necessary_third_party = ["triton_shared"] @@ -257,7 +258,11 @@ class CommonUtils: @staticmethod def unlink(): cur_path = os.path.dirname(__file__) - backends_dir_path = Path(cur_path) / "triton" / "backends" + if "editable_wheel" in sys.argv: + installation_dir = cur_path + else: + installation_dir = get_python_lib() + backends_dir_path = Path(installation_dir) / "triton" / "backends" if not os.path.exists(backends_dir_path): return for name in os.listdir(backends_dir_path): From 9242c6b687d581f2b7f59f4cda99ef5882b2be59 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Wed, 28 May 2025 11:12:30 +0800 Subject: [PATCH 10/14] [DOC] Fix build ascend --- README.md | 11 ++++++++--- README_cn.md | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 447d0f7e5..9a062e590 100644 --- a/README.md +++ b/README.md @@ -51,12 +51,17 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=mthreads python3 -m pip install . --no-build-isolation -v ``` - ```shell # ascend -export TRITON_BUILD_WITH_CLANG_LLD=true -export LLVM_SYSPATH=yourpath/llvm-install +# manually download LLVM +cd ${YOUR_LLVM_DOWNLOAD_DIR} +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz +tar -zxvf llvm-b5cc222d-ubuntu-x64.tar.gz cd ${YOUR_CODE_DIR}/flagtree/python +export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-x64 +export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include +export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib +export LLVM_SYSPATH=$LLVM_BUILD_DIR export FLAGTREE_BACKEND=ascend python3 -m pip install . --no-build-isolation -v ``` diff --git a/README_cn.md b/README_cn.md index d90aff183..88c76cf15 100644 --- a/README_cn.md +++ b/README_cn.md @@ -51,6 +51,20 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=mthreads python3 -m pip install . --no-build-isolation -v ``` +```shell +# ascend +# manually download LLVM +cd ${YOUR_LLVM_DOWNLOAD_DIR} +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz +tar -zxvf llvm-b5cc222d-ubuntu-x64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-x64 +export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include +export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib +export LLVM_SYSPATH=$LLVM_BUILD_DIR +export FLAGTREE_BACKEND=ascend +python3 -m pip install . 
--no-build-isolation -v +``` 使用默认的编译命令,可以编译安装 nvidia、amd、triton_shared 后端: ```shell From 7f493c24b152c502263d4c7bd1f1da41447e921e Mon Sep 17 00:00:00 2001 From: i3wanna2 <2535184404@qq.com> Date: Wed, 28 May 2025 03:22:16 +0000 Subject: [PATCH 11/14] [Bug fix] fix product so-package in editable mode --- python/setup_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup_helper.py b/python/setup_helper.py index 187da8277..db4e24e2f 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -359,7 +359,7 @@ def handle_flagtree_backend(): if flagtree_backend: print(f"flagtree_backend is {flagtree_backend}") extend_backends.append(flagtree_backend) - if "editable_wheel" in sys.argv: + if "editable_wheel" in sys.argv and flagtree_backend != "ascend": ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" if use_triton_shared and not flagtree_backend: default_backends.append("triton_shared") From 663db8ee39b5465703fe53a7db52a6afebc321d9 Mon Sep 17 00:00:00 2001 From: StrongSpoon Date: Thu, 29 May 2025 06:08:30 +0000 Subject: [PATCH 12/14] [Cache Tools] cache ascend tools --- python/setup_helper.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/setup_helper.py b/python/setup_helper.py index db4e24e2f..820345fe1 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -357,7 +357,7 @@ def git_clone(lib, lib_path): def handle_flagtree_backend(): global ext_sourcedir if flagtree_backend: - print(f"flagtree_backend is {flagtree_backend}") + print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m") extend_backends.append(flagtree_backend) if "editable_wheel" in sys.argv and flagtree_backend != "ascend": ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" @@ -426,3 +426,12 @@ def check_env(env_val): pre_hock=lambda: check_env('LLVM_BUILD_DIR'), post_hock=set_llvm_env, ) + +# ascend +cache.store( + file="ascend-llvm-b5cc222d-ubuntu-x64.tar.gz", + condition=("ascend" == flagtree_backend), + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) From ba4f9e485bdce42414bd901ee5d7e9cdee90f6e2 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Wed, 4 Jun 2025 06:40:42 +0000 Subject: [PATCH 13/14] [BUILD] [DOC] Fix build ascend --- README.md | 8 +++++++- README_cn.md | 10 ++++++++-- python/requirements.txt | 1 + 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9a062e590..4525d1c1f 100644 --- a/README.md +++ b/README.md @@ -55,10 +55,16 @@ python3 -m pip install . 
--no-build-isolation -v # ascend # manually download LLVM cd ${YOUR_LLVM_DOWNLOAD_DIR} +# if the output of `uname -a` is x64 or x86_64 wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz tar -zxvf llvm-b5cc222d-ubuntu-x64.tar.gz -cd ${YOUR_CODE_DIR}/flagtree/python export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-x64 +# if the output of `uname -a` is aarch64 +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz +tar -zxvf llvm-b5cc222d-ubuntu-arm64.tar.gz +export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-arm64 +# build +cd ${YOUR_CODE_DIR}/flagtree/python export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib export LLVM_SYSPATH=$LLVM_BUILD_DIR diff --git a/README_cn.md b/README_cn.md index 88c76cf15..f18b162b0 100644 --- a/README_cn.md +++ b/README_cn.md @@ -53,12 +53,18 @@ python3 -m pip install . --no-build-isolation -v ``` ```shell # ascend -# manually download LLVM +# 自行下载 LLVM cd ${YOUR_LLVM_DOWNLOAD_DIR} +# 如果 `uname -a` 的输出是 x64 或 x86_64 wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz tar -zxvf llvm-b5cc222d-ubuntu-x64.tar.gz -cd ${YOUR_CODE_DIR}/flagtree/python export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-x64 +# 如果 `uname -a` 的输出是 aarch64 +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz +tar -zxvf llvm-b5cc222d-ubuntu-arm64.tar.gz +export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-arm64 +# 编译安装 +cd ${YOUR_CODE_DIR}/flagtree/python export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib export LLVM_SYSPATH=$LLVM_BUILD_DIR diff --git a/python/requirements.txt b/python/requirements.txt index 0ff83f18b..4f7fffe43 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -4,3 +4,4 @@ wheel GitPython pytest scipy +pybind11 From 3878dc0059a4c2c4759b9d142d478ce728f72f48 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Wed, 4 Jun 2025 15:06:17 +0800 Subject: [PATCH 14/14] [BUILD] Fix build ascend --- CMakeLists.txt | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 41aeea456..8509a2c62 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,6 @@ endif() project(triton) include(CTest) - if (FLAGTREE_BACKEND STREQUAL "ascend") set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch") @@ -60,7 +59,7 @@ endif() # Options option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) -option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON) +option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) if(FLAGTREE_BACKEND) option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" OFF) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" OFF) @@ -69,6 +68,7 @@ else() option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) endif() set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") + # Ensure Python3 vars are set correctly # used conditionally in this file and by lit tests @@ -92,7 +92,7 @@ if(TRITON_BUILD_UT) include(AddTritonUnitTest) endif() -#Compiler flags +# Compiler flags set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include) if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}") 
include_directories(${BACKEND_INCLUDE_DIR}) @@ -165,7 +165,6 @@ function(add_triton_object name) target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) endif() endif() - endfunction(add_triton_object) set_property(GLOBAL PROPERTY TRITON_LIBS "") @@ -201,7 +200,6 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") endif() - include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) @@ -231,7 +229,6 @@ if (FLAGTREE_BACKEND STREQUAL "ascend") add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib) endif() - # find_package(PythonLibs REQUIRED) set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") @@ -425,18 +422,16 @@ if(TRITON_BUILD_PYTHON_MODULE) set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) - if(FLAGTREE_BACKEND STREQUAL "cambricon") add_library(triton SHARED) else() if(FLAGTREE_BACKEND STREQUAL "ascend") - set(IR_SRC ${PATCHED_PYTHON_SRC_PATH}/ir.cc) + set(PYTHON_IR_SRC_PATH ${PATCHED_PYTHON_SRC_PATH}) else() - set(IR_SRC ${PYTHON_SRC_PATH}/ir.cc) + set(PYTHON_IR_SRC_PATH ${PYTHON_SRC_PATH}) endif() - add_library(triton SHARED - ${PYTHON_SRC_PATH}/main.cc - ${IR_SRC} + add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc + ${PYTHON_IR_SRC_PATH}/ir.cc ${PYTHON_SRC_PATH}/passes.cc ${PYTHON_SRC_PATH}/interpreter.cc ${PYTHON_SRC_PATH}/llvm.cc)
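
The last hunk above swaps the per-file IR_SRC variable for a per-directory PYTHON_IR_SRC_PATH, so the ascend build compiles ir.cc from the patched triton_patch tree while main.cc, passes.cc, interpreter.cc and llvm.cc still come from the stock binding sources. A minimal, standalone CMake sketch of that selection pattern follows; the directory variables and the demo target name are illustrative stand-ins, not the repository's exact layout.

```cmake
# Sketch only: per-backend selection of which ir.cc is linked into the
# Python binding library. Paths and the target name are hypothetical.
cmake_minimum_required(VERSION 3.18)
project(ir_src_selection_sketch CXX)

set(FLAGTREE_BACKEND "$ENV{FLAGTREE_BACKEND}")
set(STOCK_SRC_DIR    "${CMAKE_CURRENT_SOURCE_DIR}/python/src")
set(PATCHED_SRC_DIR  "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch/python/src")

# Only ir.cc is redirected; every other binding source stays in the stock tree.
if(FLAGTREE_BACKEND STREQUAL "ascend")
  set(PYTHON_IR_SRC_PATH "${PATCHED_SRC_DIR}")
else()
  set(PYTHON_IR_SRC_PATH "${STOCK_SRC_DIR}")
endif()

add_library(triton_sketch SHARED
  "${STOCK_SRC_DIR}/main.cc"
  "${PYTHON_IR_SRC_PATH}/ir.cc")
```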