From 80be602d1f92564f1f580619cc12d945bebdee3f Mon Sep 17 00:00:00 2001 From: ArinaJJH Date: Thu, 29 May 2025 15:01:05 +0800 Subject: [PATCH 1/6] [feat] Support metax plugin --- CMakeLists.txt | 34 + third_party/metax/CMakeLists.txt | 24 + third_party/metax/LICENSE | 24 + third_party/metax/backend/__init__.py | 0 third_party/metax/backend/compiler.py | 347 +++ third_party/metax/backend/driver.c | 170 + third_party/metax/backend/driver.py | 328 ++ third_party/metax/bin/CMakeLists.txt | 86 + .../metax/bin/RegisterTritonDialects.h | 30 + third_party/metax/bin/triton-llvm-opt.cpp | 121 + third_party/metax/bin/triton-lsp.cpp | 11 + third_party/metax/bin/triton-opt.cpp | 11 + third_party/metax/bin/triton-reduce.cpp | 11 + third_party/metax/include/CMakeLists.txt | 1 + .../metax/include/triton/CMakeLists.txt | 1 + .../include/triton/Target/CMakeLists.txt | 1 + .../triton/Target/LLVMIR/CMakeLists.txt | 3 + .../include/triton/Target/LLVMIR/Passes.h | 17 + .../include/triton/Target/LLVMIR/Passes.td | 15 + .../metax/include/triton/Tools/LinearLayout.h | 532 ++++ .../metax/include/triton/Tools/StrUtil.h | 54 + .../metax/include/triton/Tools/Sys/GetEnv.hpp | 81 + third_party/metax/lib/CMakeLists.txt | 1 + third_party/metax/lib/Target/CMakeLists.txt | 1 + .../metax/lib/Target/LLVMIR/CMakeLists.txt | 30 + .../metax/lib/Target/LLVMIR/LLVMDIScope.cpp | 161 + .../Target/LLVMIR/LLVMIRBreakPhiStruct.cpp | 60 + .../metax/lib/Target/LLVMIR/LLVMPasses.h | 16 + third_party/metax/python/src/interpreter.cc | 435 +++ third_party/metax/python/src/ir.cc | 1646 ++++++++++ third_party/metax/python/src/llvm.cc | 405 +++ third_party/metax/python/src/main.cc | 50 + third_party/metax/python/src/passes.cc | 90 + third_party/metax/python/src/passes.h | 40 + .../python/triton/_C/include/CMakeLists.txt | 1 + .../triton/_C/include/triton/CMakeLists.txt | 1 + .../_C/include/triton/Target/CMakeLists.txt | 1 + .../triton/Target/LLVMIR/CMakeLists.txt | 3 + .../_C/include/triton/Target/LLVMIR/Passes.h | 17 + .../_C/include/triton/Target/LLVMIR/Passes.td | 15 + .../_C/include/triton/Tools/LinearLayout.h | 532 ++++ .../triton/_C/include/triton/Tools/StrUtil.h | 54 + .../_C/include/triton/Tools/Sys/GetEnv.hpp | 81 + third_party/metax/python/triton/__init__.py | 73 + .../metax/python/triton/backends/__init__.py | 50 + .../metax/python/triton/backends/compiler.py | 76 + .../metax/python/triton/backends/driver.py | 34 + .../metax/python/triton/compiler/__init__.py | 4 + .../python/triton/compiler/code_generator.py | 1302 ++++++++ .../metax/python/triton/compiler/compiler.py | 412 +++ .../metax/python/triton/compiler/errors.py | 51 + .../python/triton/compiler/make_launcher.py | 0 third_party/metax/python/triton/errors.py | 5 + .../metax/python/triton/language/__init__.py | 284 ++ .../metax/python/triton/language/core.py | 2726 +++++++++++++++++ .../python/triton/language/extra/__init__.py | 4 + .../triton/language/extra/cuda/__init__.py | 8 + .../triton/language/extra/cuda/libdevice.py | 1629 ++++++++++ .../triton/language/extra/cuda/utils.py | 109 + .../triton/language/extra/hip/__init__.py | 3 + .../triton/language/extra/hip/libdevice.py | 468 +++ .../python/triton/language/extra/libdevice.py | 1214 ++++++++ .../metax/python/triton/language/math.py | 250 ++ .../metax/python/triton/language/random.py | 207 ++ .../metax/python/triton/language/semantic.py | 1624 ++++++++++ .../metax/python/triton/language/standard.py | 441 +++ .../metax/python/triton/ops/__init__.py | 7 + .../python/triton/ops/blocksparse/__init__.py | 7 + 
.../python/triton/ops/blocksparse/matmul.py | 432 +++ .../python/triton/ops/blocksparse/softmax.py | 228 ++ .../metax/python/triton/ops/cross_entropy.py | 96 + .../python/triton/ops/flash_attention.py | 466 +++ third_party/metax/python/triton/ops/matmul.py | 219 ++ .../python/triton/ops/matmul_perf_model.py | 171 ++ .../metax/python/triton/runtime/__init__.py | 23 + .../metax/python/triton/runtime/autotuner.py | 420 +++ .../metax/python/triton/runtime/build.py | 83 + .../metax/python/triton/runtime/cache.py | 281 ++ .../metax/python/triton/runtime/driver.py | 60 + .../metax/python/triton/runtime/errors.py | 26 + .../python/triton/runtime/interpreter.py | 1127 +++++++ .../metax/python/triton/runtime/jit.py | 956 ++++++ third_party/metax/python/triton/testing.py | 558 ++++ .../metax/python/triton/tools/__init__.py | 0 .../metax/python/triton/tools/build_extern.py | 365 +++ .../metax/python/triton/tools/compile.c | 67 + .../metax/python/triton/tools/compile.h | 14 + .../metax/python/triton/tools/compile.py | 145 + .../metax/python/triton/tools/disasm.py | 142 + third_party/metax/python/triton/tools/link.py | 322 ++ 90 files changed, 22731 insertions(+) create mode 100644 third_party/metax/CMakeLists.txt create mode 100644 third_party/metax/LICENSE create mode 100644 third_party/metax/backend/__init__.py create mode 100644 third_party/metax/backend/compiler.py create mode 100644 third_party/metax/backend/driver.c create mode 100644 third_party/metax/backend/driver.py create mode 100644 third_party/metax/bin/CMakeLists.txt create mode 100644 third_party/metax/bin/RegisterTritonDialects.h create mode 100644 third_party/metax/bin/triton-llvm-opt.cpp create mode 100644 third_party/metax/bin/triton-lsp.cpp create mode 100644 third_party/metax/bin/triton-opt.cpp create mode 100644 third_party/metax/bin/triton-reduce.cpp create mode 100644 third_party/metax/include/CMakeLists.txt create mode 100644 third_party/metax/include/triton/CMakeLists.txt create mode 100644 third_party/metax/include/triton/Target/CMakeLists.txt create mode 100644 third_party/metax/include/triton/Target/LLVMIR/CMakeLists.txt create mode 100644 third_party/metax/include/triton/Target/LLVMIR/Passes.h create mode 100644 third_party/metax/include/triton/Target/LLVMIR/Passes.td create mode 100644 third_party/metax/include/triton/Tools/LinearLayout.h create mode 100644 third_party/metax/include/triton/Tools/StrUtil.h create mode 100644 third_party/metax/include/triton/Tools/Sys/GetEnv.hpp create mode 100644 third_party/metax/lib/CMakeLists.txt create mode 100644 third_party/metax/lib/Target/CMakeLists.txt create mode 100644 third_party/metax/lib/Target/LLVMIR/CMakeLists.txt create mode 100644 third_party/metax/lib/Target/LLVMIR/LLVMDIScope.cpp create mode 100644 third_party/metax/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp create mode 100644 third_party/metax/lib/Target/LLVMIR/LLVMPasses.h create mode 100644 third_party/metax/python/src/interpreter.cc create mode 100644 third_party/metax/python/src/ir.cc create mode 100644 third_party/metax/python/src/llvm.cc create mode 100644 third_party/metax/python/src/main.cc create mode 100644 third_party/metax/python/src/passes.cc create mode 100644 third_party/metax/python/src/passes.h create mode 100644 third_party/metax/python/triton/_C/include/CMakeLists.txt create mode 100644 third_party/metax/python/triton/_C/include/triton/CMakeLists.txt create mode 100644 third_party/metax/python/triton/_C/include/triton/Target/CMakeLists.txt create mode 100644 
third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/CMakeLists.txt create mode 100644 third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.h create mode 100644 third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.td create mode 100644 third_party/metax/python/triton/_C/include/triton/Tools/LinearLayout.h create mode 100644 third_party/metax/python/triton/_C/include/triton/Tools/StrUtil.h create mode 100644 third_party/metax/python/triton/_C/include/triton/Tools/Sys/GetEnv.hpp create mode 100644 third_party/metax/python/triton/__init__.py create mode 100644 third_party/metax/python/triton/backends/__init__.py create mode 100644 third_party/metax/python/triton/backends/compiler.py create mode 100644 third_party/metax/python/triton/backends/driver.py create mode 100644 third_party/metax/python/triton/compiler/__init__.py create mode 100644 third_party/metax/python/triton/compiler/code_generator.py create mode 100644 third_party/metax/python/triton/compiler/compiler.py create mode 100644 third_party/metax/python/triton/compiler/errors.py create mode 100644 third_party/metax/python/triton/compiler/make_launcher.py create mode 100644 third_party/metax/python/triton/errors.py create mode 100644 third_party/metax/python/triton/language/__init__.py create mode 100644 third_party/metax/python/triton/language/core.py create mode 100644 third_party/metax/python/triton/language/extra/__init__.py create mode 100644 third_party/metax/python/triton/language/extra/cuda/__init__.py create mode 100644 third_party/metax/python/triton/language/extra/cuda/libdevice.py create mode 100644 third_party/metax/python/triton/language/extra/cuda/utils.py create mode 100644 third_party/metax/python/triton/language/extra/hip/__init__.py create mode 100644 third_party/metax/python/triton/language/extra/hip/libdevice.py create mode 100644 third_party/metax/python/triton/language/extra/libdevice.py create mode 100644 third_party/metax/python/triton/language/math.py create mode 100644 third_party/metax/python/triton/language/random.py create mode 100644 third_party/metax/python/triton/language/semantic.py create mode 100644 third_party/metax/python/triton/language/standard.py create mode 100644 third_party/metax/python/triton/ops/__init__.py create mode 100644 third_party/metax/python/triton/ops/blocksparse/__init__.py create mode 100644 third_party/metax/python/triton/ops/blocksparse/matmul.py create mode 100644 third_party/metax/python/triton/ops/blocksparse/softmax.py create mode 100644 third_party/metax/python/triton/ops/cross_entropy.py create mode 100644 third_party/metax/python/triton/ops/flash_attention.py create mode 100644 third_party/metax/python/triton/ops/matmul.py create mode 100644 third_party/metax/python/triton/ops/matmul_perf_model.py create mode 100644 third_party/metax/python/triton/runtime/__init__.py create mode 100644 third_party/metax/python/triton/runtime/autotuner.py create mode 100644 third_party/metax/python/triton/runtime/build.py create mode 100644 third_party/metax/python/triton/runtime/cache.py create mode 100644 third_party/metax/python/triton/runtime/driver.py create mode 100644 third_party/metax/python/triton/runtime/errors.py create mode 100644 third_party/metax/python/triton/runtime/interpreter.py create mode 100644 third_party/metax/python/triton/runtime/jit.py create mode 100644 third_party/metax/python/triton/testing.py create mode 100644 third_party/metax/python/triton/tools/__init__.py create mode 100644 
third_party/metax/python/triton/tools/build_extern.py create mode 100644 third_party/metax/python/triton/tools/compile.c create mode 100644 third_party/metax/python/triton/tools/compile.h create mode 100644 third_party/metax/python/triton/tools/compile.py create mode 100644 third_party/metax/python/triton/tools/disasm.py create mode 100644 third_party/metax/python/triton/tools/link.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e35abd8b..92cbecf11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,9 @@ else() endif() set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +if(FLAGTREE_BACKEND STREQUAL "metax") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_MACA -DUSE_MACA_OPAQUE_PTR -DUSE_BUILTIN -Wno-unused-result -Wno-attributes") +endif() # ######### # LLVM @@ -322,6 +325,33 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMXCNCodeGen LLVMXCNAsmParser ) + elseif(FLAGTREE_BACKEND STREQUAL "metax") + set(TRITON_LIBRARIES + ${triton_libs} + ${triton_plugins} + + # mlir + MLIRMACADialect + MLIRGPUToMACATransforms + MLIRGPUToGPURuntimeTransforms + MLIRGPUTransforms + MLIRIR + MLIRControlFlowToLLVM + MLIRBytecodeWriter + MLIRPass + MLIRTransforms + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + MLIRMathToLLVM + MLIRGPUDialect + MLIRSCFToControlFlow + MLIRIndexToLLVM + + # LLVM + LLVMPasses + LLVMNVPTXCodeGen + ) endif() if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 @@ -352,6 +382,10 @@ if(TRITON_BUILD_PYTHON_MODULE) add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) if(FLAGTREE_BACKEND STREQUAL "cambricon") add_library(triton SHARED) + elseif(FLAGTREE_BACKEND STREQUAL "metax") + add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc + ${PYTHON_SRC_PATH}/interpreter.cc + ${PYTHON_SRC_PATH}/llvm.cc) else() add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/ir.cc diff --git a/third_party/metax/CMakeLists.txt b/third_party/metax/CMakeLists.txt new file mode 100644 index 000000000..55efc4cc6 --- /dev/null +++ b/third_party/metax/CMakeLists.txt @@ -0,0 +1,24 @@ +add_subdirectory(include) +add_subdirectory(lib) + +if(TRITON_BUILD_PYTHON_MODULE) + if(FLAGTREE_PLUGIN) + add_subdirectory(plugin) + add_triton_plugin(TritonMetax + SHARED_LIB metaxTritonPlugin + ) + else() + find_library(metaxTritonPluginLib + NAMES + metaxTritonPlugin.so + PATHS + ${CMAKE_CURRENT_SOURCE_DIR} + REQUIRED + ) + add_triton_plugin(TritonMetax + SHARED_LIB ${metaxTritonPluginLib} + ) + endif() +endif() + +add_subdirectory(bin) \ No newline at end of file diff --git a/third_party/metax/LICENSE b/third_party/metax/LICENSE new file mode 100644 index 000000000..33d9e8945 --- /dev/null +++ b/third_party/metax/LICENSE @@ -0,0 +1,24 @@ +/* +* Copyright 2018-2020 Philippe Tillet +* Copyright 2020-2022 OpenAI +* Copyright 2025 MetaX Integrated Circuits (Shanghai) Co., Ltd. + +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, +* subject to the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. 
+* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ diff --git a/third_party/metax/backend/__init__.py b/third_party/metax/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/metax/backend/compiler.py b/third_party/metax/backend/compiler.py new file mode 100644 index 000000000..3e5a13808 --- /dev/null +++ b/third_party/metax/backend/compiler.py @@ -0,0 +1,347 @@ +''' Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. ''' +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, llvm, metax + +from dataclasses import dataclass +import functools +from typing import Any, Tuple, Optional +import hashlib +import re +import tempfile +import signal +import os +import subprocess +from pathlib import Path + + +@functools.lru_cache() +def _path_to_binary(binary: str): + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(os.path.dirname(__file__), "bin", binary), + ] + + for bin in paths: + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return bin, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def get_ptxas_version(): + version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]).decode("utf-8") + return version + + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the current CUDA driver. + ''' + assert isinstance(cuda_version, str) + major, minor = map(int, cuda_version.split('.')) + if major == 12: + return 80 + minor + if major == 11: + return 70 + minor + if major == 10: + return 63 + minor + raise RuntimeError("Triton only supports CUDA 10.0 or higher") + + +@functools.lru_cache(None) +def file_hash(path): + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +def maca_get_kernel_name(src: str) -> str: + ''' + Get the kernel name from LLVM IR. + This kernel name is required when launching the kernel. + ''' + assert src + import re + for line in src.split('\n'): + line = line.strip() + if line.startswith('define metaxgpu_kernel void @'): + return re.match(r"define metaxgpu_kernel void @(.+?)\(", line).groups()[0] + +def parse_option(string): + return [item for item in string.split(';') if item] + + +@dataclass(frozen=True) +class MACAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + # maxnreg corresponds to the ptx parameter .maxnreg, which controls the + # maximum number of 32-bit registers used by one thread.
+ maxnreg: Optional[int] = None + cluster_dims: tuple = (1, 1, 1) + ptx_version: int = None + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + allow_fp8e4b15: bool = False + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + debug: bool = False + backend_name: str = 'maca' + # MACA: new args + pipeline: str = "basic" + scenario: str = "" + extra_options: str = "" + pipeline_load_num: int = -1 + + def __post_init__(self): + default_libdir = os.getenv("MACA_PATH") + '/lib' + ext_default_libdir = Path(__file__).parent / 'lib' + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get('libdevice', None): + # ext_maca_mathlib.bc + env_ext_libdevice_path = os.getenv("TRITON_EXT_LIBDEVICE_PATH", None) + ext_libdevice_path = env_ext_libdevice_path if env_ext_libdevice_path is not None else str(ext_default_libdir) + '/ext_maca_mathlib.bc' + assert os.path.exists(ext_libdevice_path), "ext_maca_mathlib.bc does not exist, please check!" + extern_libs['ext_libdevice'] = ext_libdevice_path + # maca_kernellib.bc + env_kernel_libdevice_path = os.getenv("TRITON_KERNEL_LIBDEVICE_PATH", None) + kernel_libdevice_path = env_kernel_libdevice_path if env_kernel_libdevice_path is not None else default_libdir + '/maca_kernellib.bc' + extern_libs['kernel_libdevice'] = kernel_libdevice_path + # maca_mathlib.bc + env_libdevice_path = os.getenv("TRITON_LIBDEVICE_PATH", None) + libdevice_path = env_libdevice_path if env_libdevice_path is not None else default_libdir + '/maca_mathlib.bc' + extern_libs['libdevice'] = libdevice_path + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and self.num_warps <= 16 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2, greater than 0, and less than or equal to 16" + + def hash(self): + hash_dict = dict(self.__dict__) + # hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class MACABackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'maca' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.capability = target.arch + assert isinstance(self.capability, int) + self.binary_ext = "mcfatbin" + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in MACAOptions.__dataclass_fields__.keys() if k in opts} + # USE_MACA: support allow_fp8e4nv(i.e.
float8_e4m3fn) + args["allow_fp8e4nv"] = True + # args["allow_fp8e4nv"] = False + args["allow_fp8e4b15"] = False + args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 + return MACAOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self): + import triton.language.extra.cuda as cuda + codegen_fns = { + "convert_custom_types": + cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70 + } + return codegen_fns + + def load_dialects(self, ctx): + metax.load_dialects(ctx) + + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + assert opt.pipeline_load_num >= -1, "invalid pipeline_load_num value!" + scenarios = parse_option(opt.scenario) + disable_prefetch = "unprefetch" in scenarios + fullstage = "fullstage" in scenarios + store_coalesce = "storeCoalesce" in scenarios + mla = "mla" in scenarios + single_shm = "singleshm" in scenarios + use_opt_maca_mma = True + use_opt_maca_mma = (opt.pipeline != "" and not os.getenv("TRITON_DISABLE_MACA_OPT_MMA")) + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 64, opt.num_ctas) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + if capability // 10 >= 8: + passes.ttgpuir.add_f32_dot_tc(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + + if opt.pipeline == "cpasync" : + disable_prefetch = True + metax.passes.ttgpuir.add_accelerate_matmul(pm, opt.num_stages, disable_prefetch, store_coalesce, "c500") + passes.ttgpuir.add_remove_layout_conversions(pm) + if store_coalesce: + metax.passes.ttgpuir.add_tritonmetaxgpu_change_layout_from_repn_to_elemn_pass(pm) + metax.passes.ttgpuir.add_tritonmetaxgpu_optimize_cstore_pass(pm, opt.num_stages) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + passes.common.add_cse(pm) + if capability // 10 >= 8: + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + if use_opt_maca_mma: + if opt.pipeline == "basic": + if mla and single_shm: + # only mla=True and single_shm=True + metax.passes.ttgpuir.add_pipeline_maca(pm, opt.num_stages, opt.pipeline_load_num, fullstage, True) + else: + metax.passes.ttgpuir.add_pipeline_maca(pm, opt.num_stages, opt.pipeline_load_num, fullstage, False) + elif opt.pipeline == "cpasync" and not mla: + metax.passes.ttgpuir.add_pipeline_async_tn(pm, opt.num_stages) + metax.passes.ttgpuir.add_pipeline_async_tt(pm, opt.num_stages) + metax.passes.ttgpuir.add_pipeline_async_base(pm, opt.num_stages, fullstage) + elif mla and opt.num_stages == 2 and opt.pipeline == "cpasync": + metax.passes.ttgpuir.add_pipeline_async_multidot_mla(pm, opt.num_stages, fullstage, opt.pipeline_load_num) + else: + print("no available pipeline for maca") + else: + passes.ttgpuir.add_pipeline(pm, opt.num_stages) + if
use_opt_maca_mma and opt.pipeline == "basic" and "unprefetch" not in scenarios: + metax.passes.ttgpuir.add_prefetch_maca(pm) + elif not use_opt_maca_mma: + passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + if os.getenv("TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP"): + metax.passes.ttgpuir.add_tritonmetaxgpu_move_dot_operands_out_loop_pass(pm) + if os.getenv("TRITON_ENABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT"): + metax.passes.ttgpuir.add_tritonmetaxgpu_merge_equal_shared_layout_pass(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + + @staticmethod + def make_mlir(src, metadata, options, capability): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + mod = src + + # TritonGPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + passes.ttgpuir.add_allocate_shared_memory(pm) + metax.passes.ttgpuir.add_to_llvmir(pm, capability) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + + # Get some metadata + metadata["shared"] = src.get_int_attr("triton_gpu.shared") + ret = str(mod) + return ret + + @staticmethod + def make_llir(src, metadata, options, capability): + mlir_opt_path = os.path.dirname(__file__) + "/bin/mlir-opt" + opted_mlir = metax.mlir_opt(src, mlir_opt_path) + mlir_translate_path = os.path.dirname(__file__) + "/bin/mlir-translate" + maca_path = os.environ.get('MACA_PATH') + assert maca_path, "Not found MACA_PATH" + llir = metax.translate_mlir_to_llir(opted_mlir, maca_path) + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llir = metax.link_extern_libs(llir, paths, maca_path) + metadata["name"] = maca_get_kernel_name(llir) + return llir + + + @staticmethod + def make_mcfatbin(src, metadata, opt, capability): + scenarios = parse_option(opt.scenario) + opt_mxcc = os.environ.get("TRITON_COMPILER_OPT_PATH") + mxcc_arch = os.environ.get('MACA_PATH') + "/mxgpu_llvm/bin/mxcc" + if opt_mxcc: + mxcc_arch = opt_mxcc + "/mxgpu_llvm/bin/mxcc" + if mxcc_arch is None: + raise RuntimeError('mxcc_arch is None (not specified)') + compile_options = "" + if (opt.pipeline == "basic" or opt.pipeline == "basic-prefetch") and "mla" not in scenarios: + compile_options = " -mllvm -metaxgpu-sched-regpressure=false -mllvm -metaxgpu-PostRA-Scheduler=false -mllvm -metaxgpu-mma-sched=true " + if "fullstage" in scenarios: + compile_options += " -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup " + else: + compile_options += " -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-sched-mma-maxnum=3 " + if "roll" not in scenarios: + compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " " + elif opt.pipeline == "cpasync" and "mla" not in scenarios: + compile_options = " -mllvm -metaxgpu-sched-regpressure=true -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true \ + 
-mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true -mllvm -metaxgpu-shl-add-combine=false \ + -mllvm -misched-postra=true -mllvm -enable-post-misched=true " + if os.getenv("TRITON_ENABLE_MACA_COMPILER_INT8_OPT"): + compile_options += " -mllvm -metaxgpu-slp-vectorize-i8=true" + if "unroll" in scenarios: + compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " " + if opt.extra_options != "": + compile_options = opt.extra_options + return metax.translate_llvmir_to_mcfatbin(src, mxcc_arch, os.environ.get('MACA_PATH'), compile_options) + + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) + stages["mlir"] = lambda src, metadata: self.make_mlir(src, metadata, options, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) + stages["mcfatbin"] = lambda src, metadata: self.make_mcfatbin(src, metadata, options, self.capability) + + @functools.lru_cache() + def hash(self): + mxcc_arch = os.environ.get('MACA_PATH') + "/mxgpu_llvm/bin/mxcc" + if mxcc_arch is None: + raise RuntimeError('mxcc_arch is None (not specified)') + version = subprocess.check_output([mxcc_arch, "--version"]).decode("utf-8").split('\n', 1)[0] + return f'{version}-{self.capability}' diff --git a/third_party/metax/backend/driver.c b/third_party/metax/backend/driver.c new file mode 100644 index 000000000..8b417b529 --- /dev/null +++ b/third_party/metax/backend/driver.c @@ -0,0 +1,170 @@ +/* Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. */ +#include +#include +#include +#define PY_SSIZE_T_CLEAN +#include +#include + +// Raises a Python exception and returns false if code is not MC_SUCCESS. +static bool gpuAssert(mcError_t code, const char *file, int line) { + if (code == mcSuccess) + return true; + + const char *prefix = "Triton Error [MACA]: "; + const char *str = mcGetErrorString(code); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define MACA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old CUDA driver versions. 
+#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + MCdevice device; + mcDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem = 64 * 1024; // 64KB, no CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN + int max_num_regs; + int multiprocessor_count; + int warp_size = 64; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &max_num_regs, mcDeviceAttributeMaxSharedMemoryPerBlock, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &multiprocessor_count, mcDeviceAttributeMultiProcessorCount, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &sm_clock_rate, mcDeviceAttributeClockRate, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &mem_clock_rate, mcDeviceAttributeMemoryClockRate, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &mem_bus_width, mcDeviceAttributeMemoryBusWidth, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + mcFunction_t fun; + mcModule_t mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + MCcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxGetCurrent(&pctx)); + if (!pctx) { + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + mcDevicePrimaryCtxRetain(&pctx, device)); + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxSetCurrent(pctx)); + } + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcModuleLoadData(&mod, data)); + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + mcFuncGetAttribute(&n_regs, MC_FUNC_ATTRIBUTE_NUM_REGS, fun)); + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + mcFuncGetAttribute(&n_spills, MC_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills); +} + +static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { + long size; + if (!PyArg_ParseTuple(args, "l", &size)) { + return NULL; + } + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); + return NULL; + } + + Py_INCREF(Py_None); + return Py_None; +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided mcfatbin into MACA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, + "Python interface for
cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " + "controls how many bytes can be streamed from kernels before data starts " + "being dropped. This inherits all the limitations of this call; in " + "particular it's an error to change this value after launching any kernel " + "that calls printf()."}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "maca_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_maca_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/metax/backend/driver.py b/third_party/metax/backend/driver.py new file mode 100644 index 000000000..000f3302a --- /dev/null +++ b/third_party/metax/backend/driver.py @@ -0,0 +1,328 @@ +''' Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. ''' +import functools +import os +import hashlib +import subprocess +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dir = [os.path.join(dirname, "include")] +libdevice_dir = os.path.join(dirname, "lib") +# libraries = ['cuda'] +libraries = [] + +@functools.lru_cache() +def maca_home_dirs(): + return os.getenv("MACA_PATH") + +@functools.lru_cache() +def libmaca_dirs(): + maca_path = maca_home_dirs() + return ["{}/lib/".format(maca_path)] + +maca_lib_dir = libmaca_dirs() +maca_include_dir = [os.path.join(maca_home_dirs(), "include")] + + +@functools.lru_cache() +def library_dirs(): + return [libdevice_dir, *libmaca_dirs()] + + +def compile_module_from_src(src, name): + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + # TODO(MACA): fix it + so = _build(name, src_path, tmpdir, library_dirs(), maca_include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Utils +# ------------------------ + + +class MacaUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(MacaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "maca_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + # self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters + self.set_printf_fifo_size = mod.set_printf_fifo_size + # self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor + # self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "mcDeviceptr_t" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", 
+ "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiKKOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # generate glue code + params = [i for i in signature.keys() if i not in constants] + src = f""" +#include +#include +#include +#include + +static inline void gpuAssert(mcError_t code, const char *file, int line) +{{ + if (code != mcSuccess) + {{ + const char* prefix = "Triton Error [MACA]: "; + const char* str = mcGetErrorString(code); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define MACA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, mcStream_t stream, mcFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + if (gridX*gridY*gridZ > 0) {{ + assert(num_ctas == 1); + MACA_CHECK(mcModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} +}} + +typedef struct _DevicePtrInfo {{ + mcDeviceptr_t dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (mcDeviceptr_t)PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (mcDeviceptr_t)PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + int status = mcPointerGetAttribute(&dev_ptr, mcPointerAttributeDevicePointer, ptr_info.dev_ptr); + if (status == mcErrorInvalidValue) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + ptr_info.dev_ptr = (mcDeviceptr_t)dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! 
+ return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (mcStream_t)_stream, (mcFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class MacaLauncher(object): + + def __init__(self, src, metadata): + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class MacaDriver(GPUDriver): + + def __init__(self): + self.utils = MacaUtils() # TODO: make static + self.launcher_cls = MacaLauncher + super().__init__() + + def get_current_target(self): + device = 
self.get_current_device() + capability = self.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + warp_size = 64 + return GPUTarget("maca", capability, warp_size) + + @staticmethod + def is_active(): + import torch + return torch.cuda.is_available() and (torch.version.hip is None) diff --git a/third_party/metax/bin/CMakeLists.txt b/third_party/metax/bin/CMakeLists.txt new file mode 100644 index 000000000..b3d7981ae --- /dev/null +++ b/third_party/metax/bin/CMakeLists.txt @@ -0,0 +1,86 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this? +llvm_update_compile_flags(triton-opt) +target_link_libraries(triton-opt PRIVATE + TritonLLVMIR + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIROptLib + MLIRPass + MLIRTransforms +) +set_target_properties(triton-opt PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +mlir_check_all_link_libraries(triton-opt) + +add_llvm_executable(triton-reduce triton-reduce.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-reduce) + +llvm_update_compile_flags(triton-reduce) +target_link_libraries(triton-reduce PRIVATE + TritonLLVMIR + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIRReduceLib + MLIRPass + MLIRTransforms +) +set_target_properties(triton-reduce PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +mlir_check_all_link_libraries(triton-reduce) + +add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-lsp) + +llvm_update_compile_flags(triton-lsp) +target_link_libraries(triton-lsp PRIVATE + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRTransforms +) +set_target_properties(triton-lsp PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +mlir_check_all_link_libraries(triton-lsp) + +include_directories(${MLIR_INCLUDE_DIRS}) +include_directories(${LLVM_INCLUDE_DIRS}) +add_llvm_executable(triton-llvm-opt + triton-llvm-opt.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMAnalysis + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) +set_target_properties(triton-llvm-opt PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) \ No newline at end of file diff --git a/third_party/metax/bin/RegisterTritonDialects.h b/third_party/metax/bin/RegisterTritonDialects.h new file mode 100644 index 000000000..d002ce597 --- /dev/null +++ b/third_party/metax/bin/RegisterTritonDialects.h @@ -0,0 +1,30 @@ +/* 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
*/ +#pragma once + +#include "triton/Target/LLVMIR/Passes.h" +#include "mlir/InitAllPasses.h" +#include "python/src/plugin.h" + +using BackendRegisterFunc = void (*)(); +BackendRegisterFunc load_backend_register_func(const char *backend_name, const char *func_name) { + void *symbol = load_backend_symbol(backend_name, func_name); + return reinterpret_cast<BackendRegisterFunc>(symbol); +} + +using DialectRegisterFunc = void (*)(mlir::DialectRegistry*); +DialectRegisterFunc load_dialect_register_func(const char *backend_name, const char *func_name) { + void *symbol = load_backend_symbol(backend_name, func_name); + return reinterpret_cast<DialectRegisterFunc>(symbol); +} + +inline void registerTritonDialects(mlir::DialectRegistry &registry) { + mlir::registerAllPasses(); + + auto registerAllTritonPasses = load_backend_register_func("metax", "registerAllTritonPasses"); + registerAllTritonPasses(); + auto registerConvertTritonGPUToLLVMPass = load_backend_register_func("metax", "registerConvertTritonGPUToLLVMPass"); + registerConvertTritonGPUToLLVMPass(); + + auto registerDialect = load_dialect_register_func("metax", "registerDialect"); + registerDialect(&registry); +} diff --git a/third_party/metax/bin/triton-llvm-opt.cpp b/third_party/metax/bin/triton-llvm-opt.cpp new file mode 100644 index 000000000..1ec804cb5 --- /dev/null +++ b/third_party/metax/bin/triton-llvm-opt.cpp @@ -0,0 +1,121 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. +#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include <optional> + +using namespace llvm; + +static cl::opt<std::string> InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt<std::string> OutputFilename("o", + cl::desc("Override output filename"), + cl::value_desc("filename")); + +static cl::opt<std::string> ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt<std::string> + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt<bool> + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function<Error(Module *)> makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc ->
.bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional<std::string> { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr<Module> M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr<ToolOutputFile> Out; + // Default to standard output. + if (OutputFilename.empty()) + OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + Out->keep(); + return 0; +} diff --git a/third_party/metax/bin/triton-lsp.cpp b/third_party/metax/bin/triton-lsp.cpp new file mode 100644 index 000000000..b185b0374 --- /dev/null +++ b/third_party/metax/bin/triton-lsp.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/third_party/metax/bin/triton-opt.cpp b/third_party/metax/bin/triton-opt.cpp new file mode 100644 index 000000000..2d2570771 --- /dev/null +++ b/third_party/metax/bin/triton-opt.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); +} diff --git a/third_party/metax/bin/triton-reduce.cpp b/third_party/metax/bin/triton-reduce.cpp new file mode 100644 index 000000000..8235f8fc8 --- /dev/null +++ b/third_party/metax/bin/triton-reduce.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-reduce/MlirReduceMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::mlirReduceMain(argc, argv, context)); +} diff --git a/third_party/metax/include/CMakeLists.txt b/third_party/metax/include/CMakeLists.txt new file mode 100644 index 000000000..72181b98f --- /dev/null +++ b/third_party/metax/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) \ No newline at end of file diff --git a/third_party/metax/include/triton/CMakeLists.txt b/third_party/metax/include/triton/CMakeLists.txt new file mode 100644 index 000000000..7369bfadf --- /dev/null +++ b/third_party/metax/include/triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Target) \ No newline at end of file diff --git a/third_party/metax/include/triton/Target/CMakeLists.txt
b/third_party/metax/include/triton/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/metax/include/triton/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/metax/include/triton/Target/LLVMIR/CMakeLists.txt b/third_party/metax/include/triton/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..1f6c1b351 --- /dev/null +++ b/third_party/metax/include/triton/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR) +add_public_tablegen_target(LLVMIRIncGen) diff --git a/third_party/metax/include/triton/Target/LLVMIR/Passes.h b/third_party/metax/include/triton/Target/LLVMIR/Passes.h new file mode 100644 index 000000000..27ecb5c3d --- /dev/null +++ b/third_party/metax/include/triton/Target/LLVMIR/Passes.h @@ -0,0 +1,17 @@ +#ifndef TRITON_TARGET_LLVM_IR_PASSES_H +#define TRITON_TARGET_LLVM_IR_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Create a pass to add DIScope +std::unique_ptr createLLVMDIScopePass(); + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "triton/Target/LLVMIR/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_TARGET_LLVM_IR_PASSES_H diff --git a/third_party/metax/include/triton/Target/LLVMIR/Passes.td b/third_party/metax/include/triton/Target/LLVMIR/Passes.td new file mode 100644 index 000000000..999b0b889 --- /dev/null +++ b/third_party/metax/include/triton/Target/LLVMIR/Passes.td @@ -0,0 +1,15 @@ +#ifndef TRITON_TARGET_LLVMIR_PASSES +#define TRITON_TARGET_LLVMIR_PASSES + +include "mlir/Pass/PassBase.td" + +def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> { + let summary = "Materialize LLVM line info"; + let description = [{ + This pass materializes line mapping information for LLVM IR dialect operations. + }]; + + let constructor = "mlir::createLLVMDIScopePass()"; +} + +#endif diff --git a/third_party/metax/include/triton/Tools/LinearLayout.h b/third_party/metax/include/triton/Tools/LinearLayout.h new file mode 100644 index 000000000..fb2680241 --- /dev/null +++ b/third_party/metax/include/triton/Tools/LinearLayout.h @@ -0,0 +1,532 @@ +#ifndef TRITON_TOOLS_LINEARLAYOUT_H +#define TRITON_TOOLS_LINEARLAYOUT_H + +#include +#include +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton { + +// # High-level overview of linear layouts +// +// The idea for linear layouts is due to Adam P. Goucher. +// +// In Triton, a linear layout (LL) is a function that maps from a "hardware +// location" to a "logical tensor index". +// +// For example, suppose we have a 2D tensor T stored in GPU registers. T's +// layout is the function that, given a "hardware location" tuple of (thread-id, +// warp-id), returns an index (x,y) into T. In other words, if L(t,w) = (x,y) +// is our linear layout func, then a register in thread t in warp w contains the +// value T[x,y]. +// +// The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. +// We only need to specify the value of L(t,w) at certain special points +// (namely, the values L(t,0) and L(0,w) where t and w are powers of 2), and +// from those we can compute all the other values of L. +// +// Here's an example LL where we have 4 warps and 4 threads per warp, and the +// tensor T has shape 4x4. 
We define the function L by choosing the values of +// L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. +// +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? +// +// You only need to specify these four values to define the whole linear layout. +// These special values are called the "basis vectors" or "bases" of the layout. +// We complete the table by xor'ing together the bases, according to the +// following rule. (I write "⊕" for xor.) +// +// L(t1 ⊕ t2, w1 ⊕ w2) = L(t1, w1) ⊕ L(t2, w2) (linearity rule). +// +// The linearity rule plus our four choices allows us to fill in the whole +// table. Here's how we might compute some of the values. +// +// L(0,0) = L(1 ⊕ 1, 0 ⊕ 0) = L(1,0) ⊕ L(1,0) = (1,1) ⊕ (1,1) = (0,0) +// L(0,3) = L(0 ⊕ 0, 2 ⊕ 1) = L(0,2) ⊕ L(0,1) = (0,2) ⊕ (0,1) = (0,3) +// L(3,0) = L(2 ⊕ 1, 0 ⊕ 0) = L(2,0) ⊕ L(1,0) = (2,2) ⊕ (1,1) = (3,3) +// L(3,3) = L(3 ⊕ 0, 0 ⊕ 3) = L(3,0) ⊕ L(0,3) = (3,3) ⊕ (0,3) = (3,0). +// +// (Notice it's a consequence of the linearity rule that L(0,0) = (0,0), no +// matter what values we chose for the table.) +// +// The whole table looks like this. +// +// t/w 0 1 2 3 +// 0 (0,0) (0,1) (0,2) (0,3) +// L(t,w) = 1 (1,1) (1,0) (1,3) (1,2) +// 2 (2,2) (2,3) (2,0) (2,1) +// 3 (3,3) (3,2) (3,1) (3,0). +// +// Careful readers will recognize this as a classic "swizzled" layout where +// (t, w) -> (t, w ⊕ t). To go from this formula to an LL, you only need to +// compute the results at input points (0,1), (0,2), (1,0), and (2,0). + +// Indeed the whole point of LLs is that they allow us to specify transposed and +// swizzled layouts as a "general case". Instead of a layout class for +// registers in a thread, and another layout for registers in a thread but in +// MMAv2 order, and so on, all of these can be represented by different LLs. +// This gets rid of special cases and lets us write more general code. +// +// In this example, L was a 2D -> 2D function, but LLs are general MD -> ND +// functions. In practice, a GPU register layout usually has input dims (reg, +// thread-id, warp-id, block-id), where reg represents the fact that one thread +// may store values for the tensor in multiple registers. +// +// To summarize, a linear layout is a function from tuples of integers to tuples +// of integers. We specify some key values of the function, and then we can +// compute all the other values using the linearity rule. +// +// Here are the key things you can do with linear layout objects. +// +// 1. Given an LL, construct a new LL by modifying it or combining it with +// another LL. +// +// 2. "Apply" an LL, i.e. use it to map an input index to an output index. +// A function for this that uses LLVM-dialect MLIR as its input and output +// lives in TritonGPUToLLVM.h. +// +// 3. Convert an existing Triton layout (e.g. BlockedLayoutAttr) to an LL. +// These functions live in TritonGPU/LinearLayoutConversions.h. During +// TTGIR -> LLVM codegen, we convert Triton layouts to linear layouts and +// then apply them. In the future, we intend to remove the Triton layouts +// entirely. +// +// # Examples of linear layouts +// +// 1. The 1D identity layout. This maps L(x) = x. +// +// Recall that our bases are the values of L(x) where x is a power of two. +// So for e.g. an 8-element layout, we have L(1) = 1, L(2) = 2, L(4) = 4, and +// therefore our bases are [1, 2, 4]. +// +// 2. The 1D zeros layout. This maps L(x) = 0. 
+// +// For an 8-element layout, we have L(1) = L(2) = L(4) = 0, so our bases are +// [0, 0, 0]. +// +// 3. A 2D -> 2D identity layout. Our basis vectors are the values of L(x,0) +// and L(0,y) where x and y are powers of two. The bases are +// +// - L(0,1) = (0,1) +// - L(0,2) = (0,2) +// - L(1,0) = (1,0) +// - L(2,0) = (2,0). +// +// 4. A 2D -> 2D transpose layout. For a 4x4 layout, we have: +// +// - L(0,1) = (1,0) +// - L(0,2) = (2,0) +// - L(1,0) = (0,1) +// - L(2,0) = (0,2). +// +// 5. A 1D -> 1D "transpose" layout. Consider the 16-element layout that maps +// +// x = 0 1 2 3 4 5 6 7 8 9 A B C D E F +// L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F. +// +// The bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2]. You can also think +// of this as a rearrangement of the 1D identity layout [1, 2, 4, 8]. +// +// 6. A 2D -> 1D broadcasted layout. L(x,y) = x. For a 4x4 -> 4 layout, our +// bases are +// +// - L(0,1) = 0 +// - L(0,2) = 0 +// - L(1,0) = 1 +// - L(2,0) = 2. +// +// # Implementation notes +// +// ## Dimension order +// +// An LL's input and output dimensions have an order. This order only affects +// the reshapeIns/Outs operations, where the layout is logically flattened +// according to the dimension order and then chopped up again. +// +// ## Surjectivity +// +// We require that all output values are covered by some input value, i.e. the +// function L is surjective. But multiple input values can map to the same +// output value. This represents the idea that the same logical tensor element +// can be stored in multiple places in the hardware. +// +// ## Why map hardware loc -> tensor index and not the other way around? +// +// In Triton, a linear layout usually tells us which logical tensor value is +// stored at a particular place in the hardware. For example, an LL might map +// the tuple (thread-id, warp-id, block-id) to a 2D index into a tensor, (x,y), +// meaning that the register at (t,w,b) has value tensor[x,y]. Or it might map +// from a shared memory (offset, block) to a tensor index. +// +// It might seem more natural to go the other way around, from tensor index to +// place in the hardware. But a particular tensor[x,y] value might be stored in +// more than one place in the hardware, so if we went in this direction, the +// layout would no longer be a proper function. This would complicate +// everything else. +// +// # Optional mathematical background: Linear functions over GF(2) +// +// (You shouldn't need to understand this math to use linear layouts, but it +// helps with the implementation.) +// +// One way to define a linear function is to say it's any function F that can be +// written as +// +// L(a) = a1 * B1 + a2 * B2 + ... + aM * BM, +// +// where +// +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for +// example, ai might be a real number), and +// - each Bj is a vector [b1j, b1j, ..., bNj] of N scalars in 𝔽. +// +// We can also write this as a matrix-vector product Ba, where +// +// - a is the column vector [a1, ..., aM] and +// +// - B is the matrix formed by concatenating the column vectors B1, ..., BM: +// +// | ↑ ↑ ↑ | +// B = | B1, B2, ..., BM| +// | ↓ ↓ ↓ | +// +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| +// = | ↓ ↓ ↓ | +// |bN1, bN2, ..., bNM|. +// +// Usually when we do linear algebra, the field 𝔽 from which `ai` and `bij` are +// drawn is the real or complex numbers. But in linear layouts, we let 𝔽 be a +// different field: GF(2). +// +// GF(2) is the two-element field of bits. 
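As a concrete aside on the "xor the bases" rule spelled out above, the following stand-alone sketch applies a 1D layout given only its bases L(1), L(2), L(4), ... by xor'ing the bases selected by the set bits of the input. It is not part of this header; applyLinear1D and its bases encoding are invented purely for illustration.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // bases[i] holds L(1 << i); every other L(x) is the xor of the bases
    // picked out by the set bits of x (linearity), so L(0) is always 0.
    static int32_t applyLinear1D(const std::vector<int32_t> &bases, int32_t x) {
      int32_t out = 0;
      for (size_t i = 0; i < bases.size(); ++i)
        if (x & (1 << i))
          out ^= bases[i];
      return out;
    }

    int main() {
      // Example 5 above: the 16-element "transpose" layout with bases
      // [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2].
      std::vector<int32_t> bases = {4, 8, 1, 2};
      assert(applyLinear1D(bases, 0x6) == 0x9); // matches the table: L(6) = 9
      assert(applyLinear1D(bases, 0x0) == 0x0); // linearity forces L(0) = 0
      return 0;
    }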
To define a field, I need to give +// you the set of elements and also addition and multiplication operations. For +// GF(2) the elements are simply {0,1}. We define addition as xor, and +// multiplication as binary `and`. +// +// Here's an example of a 4x4 matrix-vector multiply where the elements are in +// GF(2). I'm using ⊕ to represent GF(2)'s addition operation (i.e xor) and × +// to represent multiplication (i.e. binary `and`). +// +// | 1 0 0 0 | | 0 | | 1 | | 0 | | 0 | | 0 | +// | 0 1 1 0 | | 1 | = | 0 | × 0 ⊕ | 1 | × 1 ⊕ | 1 | × 1 ⊕ | 0 | × 0 +// | 0 0 1 1 | | 1 | | 0 | | 0 | | 1 | | 1 | +// | 0 0 1 1 | | 0 | | 0 | | 0 | | 1 | | 1 | +// +// | 0 | | 0 | +// = | 1 | ⊕ | 1 | +// | 0 | | 1 | +// | 0 | | 1 | +// +// | 0 | +// = | 0 |. +// | 1 | +// | 1 | +// +// This works, but it's cumbersome. It's more compact to think of the vector +// `a` as an M-bit integer, and each column Bi of the matrix B as an N-bit +// integer. Here's the same matrix-vector product written this way. +// +// = | 1 2 14 12 | × 6 +// = | 1 2 14 12 | × 0b0110 +// = (1 × 0) ⊕ (2 × 1) ⊕ (14 × 1) ⊕ (12 × 0) +// = 2 ⊕ 14 +// = 12. +// +// And we confirm that our answer of 12 is equal to the binary value 0b1100 we +// got before. +// +// Notice that the function F(a) is fully specified by the matrix B, and that +// the four columns of B tell us the values of F at power-of-two values for `a`, +// namely F(1), F(2), F(4), and F(8). In other words, we specify four results +// of F(x) (we call these the function's "basis vectors" or its "bases") and we +// can then compute any other value by xor'ing together subsets of the bases. +// +// In the case of a 1D -> 1D layout, the implementation of an LL is +// straightforward from the mathematical description. If the LL is +// higher-dimensional, we can "stack" the bit vectors to create 1D vectors. +// For example, if we have a 2D LL and we're given input tuple (0b0011, 0b1100), +// we can treat this like a 1D input 0b0011'1100 and then do the regular 1D LL +// computation. Similarly we can "unstack" the output from 1D to ND. +// +// The linearity rule presented earlier is perhaps misleading at this point. In +// the 1D view of things, we really only need +// +// L(x ⊕ y) = L(x) ⊕ L(y) (1D linearity rule), +// +// which is part of the definition of L being a linear function. The new 1D +// linearity rule plus stacking/unstacking is equivalent to the earlier +// N-dimensional linearity rule. +// +// That's all we need in order to define linear layouts mathematically! +// +// # Comaprison to Nvidia CuTe +// +// (Note, I'm not an expert on CuTe; this is my best understanding.) +// +// CuTe is a programmatic layout system that's part of Nvidia CUTLASS; see +// https://github.com/NVIDIA/cutlass/blob/629f465/media/docs/cute/00_quickstart.md +// +// LLs and CuTe solve similar problems. Before CuTe, CUTLASS v2 had many +// handcrafted layouts, "RowMajor", "VoltaTensorOpMultiplicandCongruous", etc, +// see https://www.youtube.com/watch?v=QLdUML5MCfE&t=574s. Each of these was a +// special case. CUTLASS v3 introduced CuTe layouts, which are programmable and +// subsume all of these special cases. The CUTLASS folks say this simplified +// CUTLASS, in the same way that we hope LLs will simplify Triton. +// +// Like CuTe layouts, LLs are also programmable and composible. But there are +// also some differences. +// +// - Dimensions in LLs are named; CuTe dimensions are numbered. +// - CuTe layouts can be nested; LLs cannot be. 
(Nesting doesn't give CuTe +// layouts additional power; any nested layout can be flattened.) +// - CuTe layouts support non-power-of-two shapes; LLs do not. In particular +// this means that LLs cannot represent padded layouts. +// - In CuTe, swizzling is a separate step applied after specifying a layout. +// In LLs, swizzling is part of the layout itself. +// - The structure of LLs allows us to programmatically search for layouts that +// satisfy certain requirements, for example a shared layout that doesn't +// have bank conflicts when read into a particular register layout. CuTe +// expects a human to choose the layout using their brain. +// - CuTe emits code that is in the critical path of your CPU and GPU programs, +// therefore it needs to be fast. It uses C++ template magic to specialize +// on known-sized dimensions, and so on. LLs themselves do not need to be +// fast; only the emitted `apply` code is on the critical path. +// - CuTe requires a CUDA compiler such as nvcc; LLs do not. +// +class LinearLayout { +private: + // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are + // computed by xor'ing bases together, using the linearity rule. In addition: + // + // - Each inDim has the same set of outDims, in the same order. + // - The order of dims is minor-to-major, although this only affects reshape. + llvm::MapVector /*size=getNumOutDims()*/> + /*size=getInDimSizeLog2(inDim)*/> + bases; + + llvm::SetVector outDimNames; + +public: + using BasesT = decltype(bases); + + // The 0-dimensional layout that maps everything to 0. This is useful as a + // starting point when doing something like + // + // LinearLayout ret = LinearLayout::empty(); + // for (...) ret *= ...; + // return ret; + static LinearLayout empty() { return LinearLayout(BasesT{}, {}); } + + // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x + // for x in [0, size). + static LinearLayout identity1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 + // for x in [0, size). + static LinearLayout zeros1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a LinearLayout from a list of bases. These are interpreted + // according to the rules written for the member variable `bases`. + explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + + // Construct a LinearLayout from an explicit list of bases. (This constructor + // is needed because llvm::MapVector does not have a constructor that accepts + // an initializer_list.) + // + // For example, given these bases + // + // L(in1=1, in2=0) = (out1=0, out2=1) + // L(in1=2, in2=0) = (out1=0, out2=2) + // L(in1=0, in2=1) = (out1=0, out2=4) + // L(in1=0, in2=2) = (out1=0, out2=8) + // L(in1=0, in2=4) = (out1=1, out2=1) + // + // we can use this constructor to build an equivalent LL: + // + // LinearLayout({ + // {"in1", {/*L(in1=1)=*/{0,1}, /*L(in1=2)=*/{0,2}}}, + // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, + // }, + // {"out1", "out2"}) + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames); + + const BasesT &getBases() const { return bases; } + + // Get the pos'th basis vector for the inDim -> outDim mapping. + // getBasis(inDim, pos) = L(0, ..., inDim = 2^pos, ..., 0). 
+ ArrayRef getBasis(StringAttr inDim, int32_t pos) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + assert(pos < it->second.size()); + return it->second[pos]; + } + + int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { + return getBasis(inDim, pos)[getOutDimIndex(outDim)]; + ; + } + + // These are in minor-to-major order, although if you don't flatten the dims + // (e.g. by reshaping) then the order doesn't really affect anything. + auto getInDimNames() const { return llvm::make_first_range(bases); } + ArrayRef getOutDimNames() const { + return outDimNames.getArrayRef(); + } + + // Gets the position that this outDim occupies in getOutDimNames(). Asserts + // if the dim is not present. + int32_t getOutDimIndex(StringAttr outDim) const; + + bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } + bool hasOutDim(StringAttr outDim) const { + return outDimNames.contains(outDim); + } + + int32_t getNumInDims() const { return bases.size(); } + int32_t getNumOutDims() const { return outDimNames.size(); } + + // Asserts if the dimension is not present. + int32_t getInDimSizeLog2(StringAttr inDim) const; + int32_t getInDimSize(StringAttr inDim) const { + return 1 << getInDimSizeLog2(inDim); + } + + // getOutDimSize(dim) == s means that there exists an input value that will + // produce each output value in [0,s). + // + // For example, if our bases are + // + // L(in0=1) = 1 + // L(in0=2) = 4 + // L(in1=1) = 2 + // L(in1=2) = 8 + // + // then the largest value we can produce is L(3,3) = 1 ⊕ 4 ⊕ 2 ⊕ 8 = 15 (and + // indeed we can produce all values in [0,16) by xor'ing subsets of the bases + // 1,2,4,8), so getOutDimSize(out_dim0) == 16. + // + // Asserts if the dimension is not present. + int32_t getOutDimSizeLog2(StringAttr outDim) const; + int32_t getOutDimSize(StringAttr outDim) const { + return 1 << getOutDimSizeLog2(outDim); + } + + // Reorders the in/out dimensions of the layout. This is mostly cosmetic + // (affecting e.g. the order of getIn/OutDimNames), but it also affects the + // behavior of reshape. + [[nodiscard]] LinearLayout + transposeIns(ArrayRef newInDimOrder) const; + [[nodiscard]] LinearLayout + transposeOuts(ArrayRef newOutDimOrder) const; + + // Creates a new layout which, roughly speaking, is equivalent to one where + // every element of the `outer` layout is replaced by a full instance of the + // `inner` layout. + // + // Examples: + // + // - empty() is the multiplicative identity: + // + // L * empty() == empty() * L == L. + // + // - Multiplying two identity1D layouts with disjoint in/out dimensions gives + // a 2D identity layout: + // + // identity1D(4, "i1", "o1") * identity1D(8, "i2", "o2") => + // L(i1,i2) = (i1,i2), + // + // with in-dims ("i1", "i2") and out-dims ("o1", "o2"), in that order. + // + // - If out-dims overlap, they are combined, as in the following examples. + // + // - identity1D(4, "i", "o") * identity1D(2, "i", "o") == + // identity1D(8, "i", "o") + // + // - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4 + // for x in [0,8). + // + // - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2 + // for x in [0,8). + // + // - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") => + // L(x) = (x % 4, x / 4) for x in [0,32). + // + // Notice that this operation is not commutative. It's also not associative. + // TODO(jlebar): Can I modify the definition to make it associative? Pretty + // confusing if not. If I can't, add an example. 
+ // + // Requires: Any in/out dimensions which are in both outer and inner appear in + // the same relative order. + friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); + LinearLayout &operator*=(LinearLayout outer) { + *this = *this * outer; + return *this; + } + + // Computes and returns L(x, y, z). + // + // If you want to apply the layout to mlir Values instead of integers, that + // function lives in TritonGPUToLLVM/Utility.h. + SmallVector> + apply(ArrayRef> ins) const; + + // Creates a new layout which is equivalent to running this layout, then + // running `outer`. That is, + // + // - let this layout be L(x), and + // - let `outer` be O(x). + // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). + // + // Requires: The output dimensions of this layout equal the input dimensions + // of outer (order doesn't matter). + [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeIns( + // std::vector> + // newInDims) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeOuts( + // std::vector> + // newOutDims) const; + + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); + friend bool operator!=(LinearLayout lhs, LinearLayout rhs) { + return !(lhs == rhs); + } +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +} // namespace mlir::triton + +#endif diff --git a/third_party/metax/include/triton/Tools/StrUtil.h b/third_party/metax/include/triton/Tools/StrUtil.h new file mode 100644 index 000000000..8b59f7d2b --- /dev/null +++ b/third_party/metax/include/triton/Tools/StrUtil.h @@ -0,0 +1,54 @@ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::triton { + +// Better version of llvm::join. This one works when T is an integer or any +// other type which defines operator<<(raw_ostream). +template +std::string join(C &&container, llvm::StringRef sep = ", ") { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + s << elem; + } + return ret; +} + +// Joins a container of elements into a string, using `sep` as a separator. +// +// fn is called to transform each element of the container before it's added to +// the string. fn must have one of the following two signatures. +// +// - void fn(llvm::raw_ostream&, E), where E is the element type of the +// container, or +// - T fn(E), where T is a type which can be passed to +// raw_ostream::operator<<. 
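A brief usage note may help here; the results below follow directly from the two join definitions in this header, and the vector v is just an illustrative input.
+//
+// For example:
+//
+//   std::vector<int> v = {1, 2, 3};
+//   join(v)                                        == "1, 2, 3"
+//   join(v, "-", [](int x) { return 10 * x; })     == "10-20-30"
+//   join(v, ", ",
+//        [](llvm::raw_ostream &os, int x) { os << "#" << x; })
+//                                                  == "#1, #2, #3"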
+// +template +std::string join(C &&container, llvm::StringRef sep, Fn &&fn) { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + + if constexpr (std::is_invocable_v) { + static_assert( + std::is_void_v< + std::invoke_result_t>); + fn(s, elem); + } else { + s << fn(elem); + } + } + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/metax/include/triton/Tools/Sys/GetEnv.hpp b/third_party/metax/include/triton/Tools/Sys/GetEnv.hpp new file mode 100644 index 000000000..12584aa8f --- /dev/null +++ b/third_party/metax/include/triton/Tools/Sys/GetEnv.hpp @@ -0,0 +1,81 @@ +#ifndef TRITON_TOOLS_SYS_GETENV_HPP +#define TRITON_TOOLS_SYS_GETENV_HPP + +#include +#include +#include +#include +#include +#include + +namespace mlir::triton { + +inline const std::set CACHE_INVALIDATING_ENV_VARS = { + // clang-format off + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", + "DISABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", + "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "MLIR_ENABLE_TIMING", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", + "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_LLVM_DEBUG_ONLY", + "USE_TTGIR_LOC", + "NVPTX_ENABLE_DUMP", + // clang-format on +}; + +inline const std::set CACHE_NEUTRAL_ENV_VARS = { + "TRITON_REPRODUCER_PATH", +}; + +namespace tools { + +inline void assertIsRecognized(const std::string &env) { + bool is_invalidating = CACHE_INVALIDATING_ENV_VARS.find(env.c_str()) != + CACHE_INVALIDATING_ENV_VARS.end(); + bool is_neutral = + CACHE_NEUTRAL_ENV_VARS.find(env.c_str()) != CACHE_NEUTRAL_ENV_VARS.end(); + std::string errmsg = env + "is not recognized. " + "Please add it to triton/tools/sys/getenv.hpp"; + assert((is_invalidating || is_neutral) && errmsg.c_str()); +} + +inline std::string getStrEnv(const std::string &env) { + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return ""; + std::string result(cstr); + return result; +} + +// return value of a cache-invalidating boolean environment variable +inline bool getBoolEnv(const std::string &env) { + assertIsRecognized(env); + const char *s = std::getenv(env.c_str()); + std::string str(s ? 
s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::optional isEnvValueBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (str == "on" || str == "true" || str == "1") + return true; + if (str == "off" || str == "false" || str == "0") + return false; + return std::nullopt; +} +} // namespace tools +} // namespace mlir::triton + +#endif diff --git a/third_party/metax/lib/CMakeLists.txt b/third_party/metax/lib/CMakeLists.txt new file mode 100644 index 000000000..7369bfadf --- /dev/null +++ b/third_party/metax/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Target) \ No newline at end of file diff --git a/third_party/metax/lib/Target/CMakeLists.txt b/third_party/metax/lib/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/metax/lib/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/metax/lib/Target/LLVMIR/CMakeLists.txt b/third_party/metax/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..7b0d34d35 --- /dev/null +++ b/third_party/metax/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,30 @@ +add_triton_library(TritonLLVMIR + LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp + + DEPENDS + LLVMIRIncGen + + LINK_LIBS + ${CMAKE_DL_LIBS} + PUBLIC + MLIRPass + MLIRGPUDialect + MLIRGPUTransforms + MLIRArithToLLVM + MLIRBuiltinToLLVMIRTranslation + MLIRIndexToLLVM + MLIRIR + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation + MLIRROCDLToLLVMIRTranslation + MLIRSCFToControlFlow + MLIRSupport + MLIRTargetLLVMIRExport + ) + +set_source_files_properties( + LLVMIRTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/third_party/metax/lib/Target/LLVMIR/LLVMDIScope.cpp b/third_party/metax/lib/Target/LLVMIR/LLVMDIScope.cpp new file mode 100644 index 000000000..af7079060 --- /dev/null +++ b/third_party/metax/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -0,0 +1,161 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" + +//===----------------------------------------------------------------------===// +// This file implements a pass to add debug info scope to LLVM operations, and +// is inspired by the DIScopeForLLVMFuncOpPass in LLVM/MLIR. Different from the +// DIScopeForLLVMFuncOpPass, this pass also handles inlined functions. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Target/LLVMIR/Passes.h.inc" + +namespace { + +/// Attempt to extract a filename for the given loc. 
+FileLineColLoc extractFileLoc(Location loc) { + if (auto fileLoc = dyn_cast(loc)) + return fileLoc; + if (auto nameLoc = dyn_cast(loc)) + return extractFileLoc(nameLoc.getChildLoc()); + if (auto opaqueLoc = dyn_cast(loc)) + return extractFileLoc(opaqueLoc.getFallbackLocation()); + if (auto fusedLoc = dyn_cast(loc)) + return extractFileLoc(fusedLoc.getLocations().front()); + if (auto callerLoc = dyn_cast(loc)) + return extractFileLoc(callerLoc.getCaller()); + StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), ""); + return mlir::FileLineColLoc::get(unknownFile, 0, 0); +} + +/// Add a debug info scope to LLVMFuncOp that are missing it. +struct LLVMDIScopePass : public LLVMDIScopeBase { + LLVMDIScopePass() = default; + + void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) { + Location loc = funcOp.getLoc(); + if (loc->findInstanceOf>()) + return; + + MLIRContext *context = &getContext(); + + // To find a DICompileUnitAttr attached to a parent (the module for + // example), otherwise create a default one. + LLVM::DICompileUnitAttr compileUnitAttr; + if (ModuleOp module = funcOp->getParentOfType()) { + auto fusedCompileUnitAttr = + module->getLoc() + ->findInstanceOf>(); + if (fusedCompileUnitAttr) + compileUnitAttr = fusedCompileUnitAttr.getMetadata(); + } + + // Filename, line and colmun to associate to the function. + LLVM::DIFileAttr fileAttr; + int64_t line = 1, col = 1; + FileLineColLoc fileLoc = extractFileLoc(loc); + if (!fileLoc && compileUnitAttr) { + fileAttr = compileUnitAttr.getFile(); + } else if (!fileLoc) { + fileAttr = LLVM::DIFileAttr::get(context, "", ""); + } else { + line = fileLoc.getLine(); + col = fileLoc.getColumn(); + StringRef inputFilePath = fileLoc.getFilename().getValue(); + fileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + } + auto subroutineTypeAttr = + LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {}); + + // Figure out debug information (`subprogramFlags` and `compileUnitAttr`) to + // attach to the function definition / declaration. External functions are + // declarations only, and are defined in a different compile unit, so mark + // them appropriately in `subprogramFlags`, and set an empty + // `compileUnitAttr`. + DistinctAttr distinctId; + auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; + if (!funcOp.isExternal()) { + distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + if (!compileUnitAttr) { + compileUnitAttr = LLVM::DICompileUnitAttr::get( + distinctId, llvm::dwarf::DW_LANG_C, fileAttr, + StringAttr::get(context, "triton"), + /*isOptimized=*/true, LLVM::DIEmissionKind::LineTablesOnly); + } + subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + } else { + compileUnitAttr = {}; + } + + StringAttr funcNameAttr = funcOp.getNameAttr(); + // Note that scopeline is set differently from LLVM's + // DIScopeForLLVMFuncOpPass. 
I don't find reasons why scopeline should be + // the column offset + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, + /*line=*/line, + /*scopeline=*/line, subprogramFlags, subroutineTypeAttr); + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + + // Get a nested loc for inlined functions + Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, + Location calleeLoc) { + auto calleeFileName = extractFileLoc(calleeLoc).getFilename(); + auto context = op->getContext(); + LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(calleeFileName), + llvm::sys::path::parent_path(calleeFileName)); + auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get( + context, scopeAttr, calleeFileAttr, /*discriminator=*/0); + Location loc = calleeLoc; + if (mlir::isa(calleeLoc)) { + auto nestedLoc = mlir::cast(calleeLoc).getCallee(); + loc = getNestedLoc(op, lexicalBlockFileAttr, nestedLoc); + } + return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); + } + + void setLexicalBlockFileAttr(Operation *op) { + auto opLoc = op->getLoc(); + if (auto callSiteLoc = dyn_cast(opLoc)) { + auto callerLoc = callSiteLoc.getCaller(); + auto calleeLoc = callSiteLoc.getCallee(); + LLVM::DIScopeAttr scopeAttr; + // We assemble the full inline stack so the parent of this loc must be a + // function + auto funcOp = op->getParentOfType(); + auto funcOpLoc = mlir::cast(funcOp.getLoc()); + scopeAttr = mlir::cast(funcOpLoc.getMetadata()); + auto loc = + CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc); + op->setLoc(loc); + } + } + + void runOnOperation() override { + getOperation()->walk([&](Operation *op) -> void { + if (isa(op)) + setSubprogramAttr(cast(op)); + else + setLexicalBlockFileAttr(op); + }); + } +}; + +} // end anonymous namespace + +std::unique_ptr mlir::createLLVMDIScopePass() { + return std::make_unique(); +} diff --git a/third_party/metax/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/third_party/metax/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000..44afcfd21 --- /dev/null +++ b/third_party/metax/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. 
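Concretely, the rewrite implemented below can be pictured on a two-field struct; the field types and value names in this sketch are invented for illustration only.
/// Before:
///   %v = phi { i32, float } [ %a, %bb0 ], [ %b, %bb1 ]
///
/// After (conceptually):
///   ; in each predecessor, right before its terminator:
///   %a.0 = extractvalue { i32, float } %a, 0
///   %a.1 = extractvalue { i32, float } %a, 1
///   ; in the block that held the original phi:
///   %v.0 = phi i32   [ %a.0, %bb0 ], [ %b.0, %bb1 ]
///   %v.1 = phi float [ %a.1, %bb0 ], [ %b.1, %bb1 ]
///   %tmp = insertvalue { i32, float } undef, i32 %v.0, 0
///   %new = insertvalue { i32, float } %tmp, float %v.1, 1
///   ; every use of %v is replaced with %new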
+//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/third_party/metax/lib/Target/LLVMIR/LLVMPasses.h b/third_party/metax/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000..1dcdb2992 --- /dev/null +++ b/third_party/metax/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/third_party/metax/python/src/interpreter.cc b/third_party/metax/python/src/interpreter.cc new file mode 100644 index 000000000..6ab7c6c75 --- /dev/null +++ b/third_party/metax/python/src/interpreter.cc @@ -0,0 +1,435 @@ +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { + +enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; + +enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; + +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL}, + {MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE}, + {MemSemantic::RELEASE, __ATOMIC_RELEASE}, + {MemSemantic::RELAXED, __ATOMIC_RELAXED}, +}; + +// Use compiler builtin atomics instead of std::atomic which requires +// each variable to be declared as atomic. +// Currently work for clang and gcc. 
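As a minimal stand-alone illustration of the builtin-atomics pattern used by the helpers below: a compare-exchange loop over the GCC/Clang __atomic builtins, the same shape as atomic_cmp. The function name exampleFetchMaxI32 is invented for the example and is not part of the interpreter.

    #include <cassert>
    #include <cstdint>

    // Fetch-max on an int32_t: load the current value, then keep trying to
    // publish `val` as long as it is larger than what is stored.
    static int32_t exampleFetchMaxI32(int32_t *ptr, int32_t val) {
      int32_t old_val = __atomic_load_n(ptr, __ATOMIC_SEQ_CST);
      while (old_val < val) {
        // On failure, __atomic_compare_exchange refreshes old_val with the
        // value currently stored at *ptr, so the loop re-checks the condition.
        if (__atomic_compare_exchange(ptr, &old_val, &val, /*weak=*/false,
                                      __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST))
          break;
      }
      return old_val; // value observed before the (possible) exchange
    }

    int main() {
      int32_t x = 3;
      assert(exampleFetchMaxI32(&x, 7) == 3 && x == 7); // larger value stored
      assert(exampleFetchMaxI32(&x, 5) == 7 && x == 7); // smaller value ignored
      return 0;
    }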
+template T atomic_cmp(T *ptr, T val, int order) { + auto cmp = [](T old, T val) { + if constexpr (is_min) { + return old > val; + } else { + return old < val; + } + }; + // First load + T old_val = __atomic_load_n(ptr, order); + while (cmp(old_val, val)) { + if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) { + break; + } + } + return old_val; +} + +template T atomic_fadd(T *ptr, T val, int order) { + T old_val; + T new_val; + // First load + // Load ptr as if uint32_t or uint64_t and then memcpy to T + if constexpr (sizeof(T) == 4) { + uint32_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else if constexpr (sizeof(T) == 8) { + uint64_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else { + throw std::invalid_argument("Unsupported data type"); + } + while (true) { + new_val = old_val + val; + if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order, + order)) { + break; + } + } + return old_val; +} + +class AtomicOp { +public: + AtomicOp(const uint64_t *ptr, size_t numel, int order) + : ptr(ptr), numel(numel), order(order) {} + + void apply() { + for (size_t i = 0; i < numel; ++i) { + applyAt(reinterpret_cast(ptr[i]), i); + } + } + + virtual ~AtomicOp() = default; + +protected: + virtual void applyAt(void *, size_t i) = 0; + + const uint64_t *ptr; + size_t numel; + int order; +}; + +template class AtomicRMWOpBase : public AtomicOp { +public: + AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, + const bool *mask, size_t numel, int order) + : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} + +protected: + void applyAt(void *loc, size_t i) override final { + if (mask[i]) { + *(static_cast(ret) + i) = + applyAtMasked(static_cast(loc), + *(static_cast(val) + i), order); + } + } + + virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0; + + const void *val; + void *ret; + const bool *mask; +}; + +template +class AtomicRMWOp : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_add(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_fadd(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_and(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_or(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_xor(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType 
*loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_exchange_n(loc, value, order); + } +}; + +class AtomicCASOp : public AtomicOp { +public: + AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, + size_t itemsize, size_t numel, int order) + : AtomicOp(ptr, numel, order), expected(expected), desired(desired), + itemsize(itemsize) {} + +protected: + void applyAt(void *loc, size_t i) override { + // Atomic operations perform bitwise comparison, so it's safe to + // use number of bytes (itemsize) to determine the type of pointers + if (itemsize == 1) { + uint8_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 2) { + uint16_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 4) { + uint32_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 8) { + uint64_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else { + // The ‘__atomic’ builtins can be used with any integral scalar or pointer + // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are + // also allowed if ‘__int128’ (see 128-bit Integers) is supported by the + // architecture. 
+ // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html + throw std::invalid_argument("Invalid byte size"); + } + } + +private: + void *expected; + const void *desired; + size_t itemsize; +}; + +// This is a workaround because explicit template parameter list for lambdas is +// a C++20 extension: +// auto try_make_op = [&]() { +// if (dtype.is(pybind11::dtype::of())) { +// atomic_op = std::make_unique>(ptr, val, ret, mask, +// numel, order); +// } +// }; +template struct OpCreator { + pybind11::dtype dtype; + const uint64_t *ptr; + const void *val; + void *ret; + const bool *mask; + size_t numel; + int order; + std::unique_ptr &atomic_op; + + template void create() { + if (!atomic_op && dtype.is(pybind11::dtype::of())) { + atomic_op = std::make_unique>(ptr, val, ret, mask, + numel, order); + } + } +}; + +template +std::unique_ptr +makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, + void *ret, const bool *mask, size_t numel, int order) { + // Iterate over all supported data types, make one that matches, and return + std::unique_ptr atomic_op; + OpCreator try_make_op{dtype, ptr, val, ret, + mask, numel, order, atomic_op}; + + (try_make_op.template create(), ...); + if (!atomic_op) { + throw std::invalid_argument("Unsupported data type"); + } + // Make it a unique_ptr + return atomic_op; +} + +} // namespace + +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "RMW_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX) + .export_values(); + + m.def("load", + [](py::array_t ptr, py::array_t mask, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptr.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", + [](py::array_t ptr, py::array value, py::array_t mask) { + int numel = ptr.size(); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_value = value.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) { + memcpy(reinterpret_cast(reshaped_ptr.mutable_at(i)), + reshaped_value.data(i), value.dtype().itemsize()); + } + } + }); + + m.def("atomic_rmw", + [](RMWOp rmw_op, py::array_t ptr, py::array val, + py::array_t mask, MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = val.dtype(); + py::array 
ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto *ptr_data = reshaped_ptr.data(); + auto *mask_data = reshaped_mask.data(); + auto *val_data = static_cast(reshaped_val.data()); + auto *ret_data = static_cast(ret.mutable_data()); + + std::unique_ptr atomic_op; + +#define MAKE_ATOMIC_RMW_OP(OP_NAME, ...) \ + case OP_NAME: \ + atomic_op = makeAtomicRMWOp( \ + ret_dtype, ptr_data, val_data, ret_data, mask_data, numel, order); \ + break; + + switch (rmw_op) { + MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MAX, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMAX, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MIN, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMIN, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XCHG, int32_t, uint32_t, int64_t, + uint64_t) + default: + throw std::invalid_argument("Unsupported RMW operation"); + } + +#undef MAKE_ATOMIC_RMW_OP + + atomic_op->apply(); + return ret.reshape(shape); + }); + + m.def("atomic_cas", + [](py::array_t ptr, py::array &cmp, py::array &val, + MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = cmp.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array reshaped_cmp = cmp.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto itemsize = cmp.itemsize(); + memcpy(static_cast(ret.mutable_data()), + static_cast(reshaped_cmp.data()), + itemsize * numel); + AtomicCASOp(reshaped_ptr.data(), ret.mutable_data(), + static_cast(reshaped_val.data()), itemsize, + numel, order) + .apply(); + return ret.reshape(shape); + }); +} diff --git a/third_party/metax/python/src/ir.cc b/third_party/metax/python/src/ir.cc new file mode 100644 index 000000000..0befdc491 --- /dev/null +++ b/third_party/metax/python/src/ir.cc @@ -0,0 +1,1646 @@ +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include "mlir/Transforms/Passes.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using 
namespace triton; + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. 
+ template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .export_values(); + + py::class_(m, "context", py::module_local()).def(py::init<>()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__str__", [](Type &self) { + std::string str; + 
llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return 
(uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self->print(os, printingFlags); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }); + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + .def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = 
OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), + /*fileName=*/fileName, + /*op=*/self, /*flags=*/{}); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", + [](FuncOp &self) -> void { + // Remove dead code + // 1. Unreachable code after return + self.walk([&](Block *block) { + Operation *retOp = nullptr; + // It's better to not use walk here because we only want to + // check operations in the current block + for (auto &op : block->getOperations()) { + if (isa(op)) + if (retOp == nullptr) { + retOp = &op; + break; + } + } + if (retOp && retOp != &block->back()) { + auto pos = retOp->getIterator(); + pos++; + auto *newBlock = block->splitBlock(pos); + newBlock->erase(); + } + }); + // 2. Check if the result of tl.advance is used + self.walk([&](Operation *op) { + if (isa(op) && op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. 
" + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); + }) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto 
floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + // TODO: fp8e4nv is using Float8E4M3FNUZType, which + // does not seem right. It should use FloatE4M3FNType + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + 
NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, int start, int end) -> Value { + auto retType = RankedTensorType::get( + {end - start}, self.getBuilder().getI32Type()); + return self.create(retType, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + 
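As a usage note for the constant, type, cast, and arithmetic builders bound in this file: the Python side of the compiler drives them through the `ir` submodule. The sketch below is illustrative only and not part of this patch; it assumes the built extension is importable as `triton._C.libtriton`, and that an `ir.context()` binding and a dialect-loading helper (named `load_dialects` upstream) exist earlier in ir.cc, outside this hunk.

    # Hedged sketch: the import path and the context/load_dialects helpers are assumptions.
    from triton._C.libtriton import ir

    ctx = ir.context()                 # assumed MLIRContext binding from earlier in ir.cc
    ir.load_dialects(ctx)              # assumed helper that registers the Triton dialects
    builder = ir.builder(ctx)          # the TritonOpBuilder wrapper defined above
    builder.set_loc("demo.py", 1, 0)   # every created op needs a source location

    module = builder.create_module()
    fn_ty = builder.get_function_ty([builder.get_int32_ty()],
                                    [builder.get_float_ty()])
    fn = builder.get_or_insert_function(module, "demo", fn_ty, "public", False)
    module.push_back(fn)
    builder.set_insertion_point_to_start(fn.add_entry_block())

    x = fn.args(0)                                            # the i32 argument
    y = builder.create_si_to_fp(x, builder.get_float_ty())    # i32 -> f32
    y = builder.create_fmul(y, builder.get_fp32(0.5))         # multiply by a constant
    builder.ret([y])
    print(module.str())                # textual IR with debug info enabled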
.def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_to_index", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getIndexType(), input); + }) + .def("create_index_to_si", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getI64Type(), input); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value 
&rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder 
&self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier 
cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value &desc_ptr, + std::vector &indices, Type type, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + return self.create( + type, desc_ptr, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value &desc_ptr, Value value, + std::vector &indices) -> void { + self.create(desc_ptr, value, + indices); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + auto argType = + cast(arg.getType()).getElementType(); + return self.create( + RankedTensorType::get(shape, argType), arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + retShape.insert(retShape.begin() + axis, 1); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + // Implements tl.trans and tl.permute. 
+ .def("create_trans", + [](TritonOpBuilder &self, Value &arg, + std::vector &order) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + auto retShape = applyPermutation(argType.getShape(), order); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, order); + }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + auto argType = arg.getType(); + auto ret = self.createOrFold( + RankedTensorType::get(shape, argType), arg); + return ret; + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + 
.def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values) -> void { + self.create( + StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)), + hex, values); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message, const std::string &fileName, + const std::string &funcName, unsigned lineNo) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + auto fileNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(fileName)); + auto funcNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(funcName)); + auto lineNoAttr = self.getBuilder().getI32IntegerAttr(lineNo); + self.create(condition, messageAttr, fileNameAttr, + funcNameAttr, lineNoAttr); + }) + // Undef + .def("create_undef", + [](TritonOpBuilder &self, Type &type) -> Value { + return self.create(type); + 
}) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins) -> Value { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) { + auto *context = self.getContext(); + bool haveDiagnostics = + ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + if (haveDiagnostics || haveDump) { + context->disableMultithreading(); + } + if (haveDiagnostics) { + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->getDiagEngine().registerHandler([](Diagnostic &diag) { + llvm::outs() << diag << "\n"; + return success(); + }); + } + if (haveDump) { + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto printAlways = [](Pass *, Operation *) { return true; }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), + printingFlags); + } + }) + .def("run", [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. 
+ storage.push_back(str.str()); + return storage.back().c_str(); + }); + + ::llvm::DebugFlag = true; + ::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/metax/python/src/llvm.cc b/third_party/metax/python/src/llvm.cc new file mode 100644 index 000000000..0039d1a2f --- /dev/null +++ b/third_party/metax/python/src/llvm.cc @@ -0,0 +1,405 @@ +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include +#include +#include + +namespace py = pybind11; + +namespace llvm { +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; +} // namespace llvm + +using namespace llvm; + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + // options + auto options = llvm::cl::getRegisteredOptions(); + for (std::string flag : flags) { + auto *shortPtr = static_cast *>(options[flag]); + assert(shortPtr); + shortPtr->setValue(true); + } + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optIt = options.find("print-after-all"); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. 
+ auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + } + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + // module->print(llvm::outs(), nullptr); + + // create machine + module.setTargetTriple(triple); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error); + llvm::TargetOptions opt; + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + std::unique_ptr machine{target->createTargetMachine( + module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + // set data layout + module.setDataLayout(machine->createDataLayout()); + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + for (llvm::Function &f : module.functions()) + f.addFnAttr(llvm::Attribute::AlwaysInline); + llvm::legacy::PassManager pass; + // emit + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); + pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + } + return result; +} + +using ret = py::return_value_policy; + +void init_triton_llvm(py::module &&m) { + + py::class_(m, "context", py::module_local()) + .def(py::init<>()); + + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + + // Module Flag behavior. See + // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 + // for details. 
+ py::class_(m, "module_flag_behavior", + py::module_local()); + m.attr("MODULE_FLAG_BEHAVIOR_ERROR") = llvm::Module::Error; + m.attr("MODULE_FLAG_BEHAVIOR_WARNING") = llvm::Module::Warning; + m.attr("MODULE_FLAG_BEHAVIOR_REQUIRE") = llvm::Module::Require; + m.attr("MODULE_FLAG_BEHAVIOR_OVERRIDE") = llvm::Module::Override; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND") = llvm::Module::Append; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND_UNIQUE") = llvm::Module::AppendUnique; + m.attr("MODULE_FLAG_BEHAVIOR_MAX") = llvm::Module::Max; + m.attr("MODULE_FLAG_BEHAVIOR_MIN") = llvm::Module::Min; + + py::class_(m, "module", py::module_local()) + .def( + "__str__", + [](llvm::Module *self) { + std::string str; + llvm::raw_string_ostream os(str); + os << *self; + return os.str(); + }, + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's that's called by the CPU) and that it's + // the first function in this list. + return mod->getFunctionList(); + }, + ret::reference_internal) + .def("add_flag", + [](llvm::Module *mod, llvm::Module::ModFlagBehavior behavior, + std::string &key, uint32_t value) { + return mod->addModuleFlag(behavior, key, value); + }); + + py::class_(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + + // Sets the nvvm.maxreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); + }) + // External functions that are definitions (i.e. not declarations) are + // kernel functions. + .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); + + // optimization levels + py::class_(m, "optimization_level", + py::module_local()); + m.attr("OPTIMIZE_O0") = llvm::OptimizationLevel::O0; + m.attr("OPTIMIZE_O1") = llvm::OptimizationLevel::O1; + m.attr("OPTIMIZE_O2") = llvm::OptimizationLevel::O2; + m.attr("OPTIMIZE_O3") = llvm::OptimizationLevel::O3; + m.attr("OPTIMIZE_Os") = llvm::OptimizationLevel::Os; + m.attr("OPTIMIZE_Oz") = llvm::OptimizationLevel::Oz; + + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + return mlir::translateModuleToLLVMIR(mod, ctx); + }, + py::keep_alive<0, 2>()); + + m.def( + "optimize_module", + [](llvm::Module *mod, const llvm::OptimizationLevel &opt, + const std::string triple) { + if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) + return; + // Check to see if we are passing a list of flags to disable + // optimizations. 
+ auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + auto options = llvm::cl::getRegisteredOptions(); + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + using namespace llvm; + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PassInstrumentationCallbacks *instrCbPtr = nullptr; + PassInstrumentationCallbacks passInstrCb; + StandardInstrumentations standardInstr(mod->getContext(), + /*DebugLogging*/ true); + if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optMap = llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("print-after-all"); + if (optIt != optMap.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + standardInstr.registerCallbacks(passInstrCb, &mam); + instrCbPtr = &passInstrCb; + } + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. + // This cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also + // applies some scheduling that helps performance in some cases. We + // should work on using NVPTX target instead and address the performance + // regressions with some scheduling solution. + tuningOptions.SLPVectorization = true; + + if (!triple.empty()) + mod->setTargetTriple(triple.c_str()); + + PassBuilder pb(nullptr /*targetMachine*/, tuningOptions, std::nullopt, + instrCbPtr); + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make + // sure all the struct are removed for the following passes. 
+ fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); + mpm.run(*mod, mam); + }, + py::arg("mod"), py::arg("opt"), py::arg("triple") = ""); + + m.def( + "translate_to_asm", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToASM(*module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else + return py::str(obj); + }, + ret::take_ownership); + + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + }); + }); + + m.def("link_extern_libs", [](llvm::Module *dstMod, + const std::vector &paths) { + if (paths.empty()) + return; + + LLVMContext &ctx = dstMod->getContext(); + llvm::Linker linker(*dstMod); + for (const std::string &path : paths) { + llvm::SMDiagnostic err; + std::unique_ptr libMod = llvm::parseIRFile(path, err, ctx); + if (!libMod) { + std::string message = "Failed to parse library at " + path; + throw std::invalid_argument(message); + } + libMod->setTargetTriple(dstMod->getTargetTriple()); + libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + } + + if (linker.linkInModule(std::move(libMod), + llvm::Linker::Flags::LinkOnlyNeeded)) { + std::string message = "Failed to link library at " + path; + throw std::invalid_argument(message); + } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } + } + }); +} diff --git a/third_party/metax/python/src/main.cc b/third_party/metax/python/src/main.cc new file mode 100644 index 000000000..5ad4be7d5 --- /dev/null +++ b/third_party/metax/python/src/main.cc @@ -0,0 +1,50 @@ +#include +namespace py = pybind11; + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) 
\ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_env_vars(m); + init_triton_ir(m.def_submodule("ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/metax/python/src/passes.cc b/third_party/metax/python/src/passes.cc new file mode 100644 index 000000000..513e811d2 --- /dev/null +++ b/third_party/metax/python/src/passes.cc @@ -0,0 +1,90 @@ +#include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" +#include +#include + +namespace py = pybind11; + +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + +void init_triton_passes_common(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); + ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass); + ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); + ADD_PASS_WRAPPER_0("add_cse", createCSEPass); + ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); +} + +void init_triton_passes_ttir(py::module &&m) { + using namespace mlir::triton; + ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass); + ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", + createRewriteTensorPointerPass); + ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", + createConvertTritonToTritonGPUPass, const std::string &, + int, int, int); +} + +void init_triton_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton::gpu; + ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); + ADD_PASS_WRAPPER_0("add_optimize_thread_locality", + createTritonGPUOptimizeThreadLocality); + ADD_PASS_OPTION_WRAPPER_1("add_pipeline", createTritonGPUPipeline, int); + ADD_PASS_WRAPPER_0("add_prefetch", 
createTritonGPUPrefetch); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + createTritonGPUReorderInstructions); + ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); + ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", + createTritonGPUOptimizeDotOperands, bool); + ADD_PASS_WRAPPER_0("add_remove_layout_conversions", + createTritonGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_reduce_data_duplication", + createTritonGPUReduceDataDuplication); + ADD_PASS_WRAPPER_0("add_allocate_shared_memory", + createAllocateSharedMemoryPass); + ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", + createTritonGPUCombineTensorSelectAndIf); +} + +void init_triton_passes_convert(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); + ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); + ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); +} + +void init_triton_passes_llvmir(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_di_scope", createLLVMDIScopePass); +} + +void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); + init_triton_passes_common(m.def_submodule("common")); + init_triton_passes_convert(m.def_submodule("convert")); + init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); + init_triton_passes_llvmir(m.def_submodule("llvmir")); +} diff --git a/third_party/metax/python/src/passes.h b/third_party/metax/python/src/passes.h new file mode 100644 index 000000000..46801d802 --- /dev/null +++ b/third_party/metax/python/src/passes.h @@ -0,0 +1,40 @@ +#define ADD_PASS_WRAPPER_0(name, builder) \ + m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) + +#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) + +#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder(val0, val1)); \ + }) + +#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder(val0, val1, val2)); \ + }) + +#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) + +#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) + +#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder({val0, val1})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder({val0, val1, val2})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3) { \ + pm.addPass(builder({val0, val1, val2, val3})); \ + }) diff --git a/third_party/metax/python/triton/_C/include/CMakeLists.txt b/third_party/metax/python/triton/_C/include/CMakeLists.txt new file mode 100644 index 
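The ADD_PASS_WRAPPER_N and ADD_PASS_OPTION_WRAPPER_N macros above each expand to a small `m.def(...)` that takes an `mlir::PassManager` plus N extra arguments and forwards them to the pass builder. From Python they combine with the `pass_manager` binding in ir.cc roughly as follows; this is a hedged sketch, with the import path assumed and `ctx`/`mod` reused from the earlier ir example:

    # Sketch only: assumes the extension is importable as triton._C.libtriton.
    from triton._C.libtriton import ir, passes

    pm = ir.pass_manager(ctx)       # wraps mlir::PassManager
    pm.enable_debug()               # honors MLIR_ENABLE_DUMP / MLIR_ENABLE_DIAGNOSTICS
    passes.common.add_inliner(pm)   # ADD_PASS_WRAPPER_0: no extra arguments
    passes.ttir.add_combine(pm)
    passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    pm.run(mod)                     # raises RuntimeError if the pipeline fails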
000000000..72181b98f --- /dev/null +++ b/third_party/metax/python/triton/_C/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) \ No newline at end of file diff --git a/third_party/metax/python/triton/_C/include/triton/CMakeLists.txt b/third_party/metax/python/triton/_C/include/triton/CMakeLists.txt new file mode 100644 index 000000000..7369bfadf --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Target) \ No newline at end of file diff --git a/third_party/metax/python/triton/_C/include/triton/Target/CMakeLists.txt b/third_party/metax/python/triton/_C/include/triton/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/CMakeLists.txt b/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..1f6c1b351 --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR) +add_public_tablegen_target(LLVMIRIncGen) diff --git a/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.h b/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.h new file mode 100644 index 000000000..27ecb5c3d --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.h @@ -0,0 +1,17 @@ +#ifndef TRITON_TARGET_LLVM_IR_PASSES_H +#define TRITON_TARGET_LLVM_IR_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Create a pass to add DIScope +std::unique_ptr createLLVMDIScopePass(); + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "triton/Target/LLVMIR/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_TARGET_LLVM_IR_PASSES_H diff --git a/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.td b/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.td new file mode 100644 index 000000000..999b0b889 --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/Target/LLVMIR/Passes.td @@ -0,0 +1,15 @@ +#ifndef TRITON_TARGET_LLVMIR_PASSES +#define TRITON_TARGET_LLVMIR_PASSES + +include "mlir/Pass/PassBase.td" + +def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> { + let summary = "Materialize LLVM line info"; + let description = [{ + This pass materializes line mapping information for LLVM IR dialect operations. + }]; + + let constructor = "mlir::createLLVMDIScopePass()"; +} + +#endif diff --git a/third_party/metax/python/triton/_C/include/triton/Tools/LinearLayout.h b/third_party/metax/python/triton/_C/include/triton/Tools/LinearLayout.h new file mode 100644 index 000000000..fb2680241 --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/Tools/LinearLayout.h @@ -0,0 +1,532 @@ +#ifndef TRITON_TOOLS_LINEARLAYOUT_H +#define TRITON_TOOLS_LINEARLAYOUT_H + +#include +#include +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton { + +// # High-level overview of linear layouts +// +// The idea for linear layouts is due to Adam P. Goucher. 
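Referring back to the pass plumbing defined above (the ADD_PASS_WRAPPER_* bindings in passes.cc/passes.h and the LLVMDIScope pass declared in Passes.td), here is a minimal Python sketch of how such a pipeline is typically driven. It assumes the `ir` module exposes MLIR's PassManager as `ir.pass_manager` with `enable_debug` and `run`, as in upstream Triton; the pass ordering, the target string, and the integer arguments are placeholders, not the Metax backend's actual pipeline.

    from triton._C.libtriton import ir, passes

    def build_pipeline_sketch(mod, target, num_warps=4):
        # Assumed binding: ir.pass_manager wraps mlir::PassManager (as in upstream Triton).
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)            # bound via ADD_PASS_WRAPPER_0
        passes.ttir.add_combine(pm)
        # ADD_PASS_WRAPPER_4 binding: a string plus three ints; upstream uses
        # (target, num_warps, threads_per_warp, num_ctas). Values here are illustrative.
        passes.ttir.add_convert_to_ttgpuir(pm, target, num_warps, 64, 1)
        passes.ttgpuir.add_coalesce(pm)
        passes.common.add_canonicalizer(pm)
        passes.llvmir.add_di_scope(pm)           # the LLVMDIScope pass from Passes.td
        pm.run(mod)
        return mod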
+// +// In Triton, a linear layout (LL) is a function that maps from a "hardware +// location" to a "logical tensor index". +// +// For example, suppose we have a 2D tensor T stored in GPU registers. T's +// layout is the function that, given a "hardware location" tuple of (thread-id, +// warp-id), returns an index (x,y) into T. In other words, if L(t,w) = (x,y) +// is our linear layout func, then a register in thread t in warp w contains the +// value T[x,y]. +// +// The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. +// We only need to specify the value of L(t,w) at certain special points +// (namely, the values L(t,0) and L(0,w) where t and w are powers of 2), and +// from those we can compute all the other values of L. +// +// Here's an example LL where we have 4 warps and 4 threads per warp, and the +// tensor T has shape 4x4. We define the function L by choosing the values of +// L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. +// +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? +// +// You only need to specify these four values to define the whole linear layout. +// These special values are called the "basis vectors" or "bases" of the layout. +// We complete the table by xor'ing together the bases, according to the +// following rule. (I write "⊕" for xor.) +// +// L(t1 ⊕ t2, w1 ⊕ w2) = L(t1, w1) ⊕ L(t2, w2) (linearity rule). +// +// The linearity rule plus our four choices allows us to fill in the whole +// table. Here's how we might compute some of the values. +// +// L(0,0) = L(1 ⊕ 1, 0 ⊕ 0) = L(1,0) ⊕ L(1,0) = (1,1) ⊕ (1,1) = (0,0) +// L(0,3) = L(0 ⊕ 0, 2 ⊕ 1) = L(0,2) ⊕ L(0,1) = (0,2) ⊕ (0,1) = (0,3) +// L(3,0) = L(2 ⊕ 1, 0 ⊕ 0) = L(2,0) ⊕ L(1,0) = (2,2) ⊕ (1,1) = (3,3) +// L(3,3) = L(3 ⊕ 0, 0 ⊕ 3) = L(3,0) ⊕ L(0,3) = (3,3) ⊕ (0,3) = (3,0). +// +// (Notice it's a consequence of the linearity rule that L(0,0) = (0,0), no +// matter what values we chose for the table.) +// +// The whole table looks like this. +// +// t/w 0 1 2 3 +// 0 (0,0) (0,1) (0,2) (0,3) +// L(t,w) = 1 (1,1) (1,0) (1,3) (1,2) +// 2 (2,2) (2,3) (2,0) (2,1) +// 3 (3,3) (3,2) (3,1) (3,0). +// +// Careful readers will recognize this as a classic "swizzled" layout where +// (t, w) -> (t, w ⊕ t). To go from this formula to an LL, you only need to +// compute the results at input points (0,1), (0,2), (1,0), and (2,0). + +// Indeed the whole point of LLs is that they allow us to specify transposed and +// swizzled layouts as a "general case". Instead of a layout class for +// registers in a thread, and another layout for registers in a thread but in +// MMAv2 order, and so on, all of these can be represented by different LLs. +// This gets rid of special cases and lets us write more general code. +// +// In this example, L was a 2D -> 2D function, but LLs are general MD -> ND +// functions. In practice, a GPU register layout usually has input dims (reg, +// thread-id, warp-id, block-id), where reg represents the fact that one thread +// may store values for the tensor in multiple registers. +// +// To summarize, a linear layout is a function from tuples of integers to tuples +// of integers. We specify some key values of the function, and then we can +// compute all the other values using the linearity rule. +// +// Here are the key things you can do with linear layout objects. +// +// 1. Given an LL, construct a new LL by modifying it or combining it with +// another LL. +// +// 2. "Apply" an LL, i.e. 
use it to map an input index to an output index. +// A function for this that uses LLVM-dialect MLIR as its input and output +// lives in TritonGPUToLLVM.h. +// +// 3. Convert an existing Triton layout (e.g. BlockedLayoutAttr) to an LL. +// These functions live in TritonGPU/LinearLayoutConversions.h. During +// TTGIR -> LLVM codegen, we convert Triton layouts to linear layouts and +// then apply them. In the future, we intend to remove the Triton layouts +// entirely. +// +// # Examples of linear layouts +// +// 1. The 1D identity layout. This maps L(x) = x. +// +// Recall that our bases are the values of L(x) where x is a power of two. +// So for e.g. an 8-element layout, we have L(1) = 1, L(2) = 2, L(4) = 4, and +// therefore our bases are [1, 2, 4]. +// +// 2. The 1D zeros layout. This maps L(x) = 0. +// +// For an 8-element layout, we have L(1) = L(2) = L(4) = 0, so our bases are +// [0, 0, 0]. +// +// 3. A 2D -> 2D identity layout. Our basis vectors are the values of L(x,0) +// and L(0,y) where x and y are powers of two. The bases are +// +// - L(0,1) = (0,1) +// - L(0,2) = (0,2) +// - L(1,0) = (1,0) +// - L(2,0) = (2,0). +// +// 4. A 2D -> 2D transpose layout. For a 4x4 layout, we have: +// +// - L(0,1) = (1,0) +// - L(0,2) = (2,0) +// - L(1,0) = (0,1) +// - L(2,0) = (0,2). +// +// 5. A 1D -> 1D "transpose" layout. Consider the 16-element layout that maps +// +// x = 0 1 2 3 4 5 6 7 8 9 A B C D E F +// L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F. +// +// The bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2]. You can also think +// of this as a rearrangement of the 1D identity layout [1, 2, 4, 8]. +// +// 6. A 2D -> 1D broadcasted layout. L(x,y) = x. For a 4x4 -> 4 layout, our +// bases are +// +// - L(0,1) = 0 +// - L(0,2) = 0 +// - L(1,0) = 1 +// - L(2,0) = 2. +// +// # Implementation notes +// +// ## Dimension order +// +// An LL's input and output dimensions have an order. This order only affects +// the reshapeIns/Outs operations, where the layout is logically flattened +// according to the dimension order and then chopped up again. +// +// ## Surjectivity +// +// We require that all output values are covered by some input value, i.e. the +// function L is surjective. But multiple input values can map to the same +// output value. This represents the idea that the same logical tensor element +// can be stored in multiple places in the hardware. +// +// ## Why map hardware loc -> tensor index and not the other way around? +// +// In Triton, a linear layout usually tells us which logical tensor value is +// stored at a particular place in the hardware. For example, an LL might map +// the tuple (thread-id, warp-id, block-id) to a 2D index into a tensor, (x,y), +// meaning that the register at (t,w,b) has value tensor[x,y]. Or it might map +// from a shared memory (offset, block) to a tensor index. +// +// It might seem more natural to go the other way around, from tensor index to +// place in the hardware. But a particular tensor[x,y] value might be stored in +// more than one place in the hardware, so if we went in this direction, the +// layout would no longer be a proper function. This would complicate +// everything else. +// +// # Optional mathematical background: Linear functions over GF(2) +// +// (You shouldn't need to understand this math to use linear layouts, but it +// helps with the implementation.) +// +// One way to define a linear function is to say it's any function F that can be +// written as +// +// L(a) = a1 * B1 + a2 * B2 + ... 
+ aM * BM, +// +// where +// +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for +// example, ai might be a real number), and +// - each Bj is a vector [b1j, b1j, ..., bNj] of N scalars in 𝔽. +// +// We can also write this as a matrix-vector product Ba, where +// +// - a is the column vector [a1, ..., aM] and +// +// - B is the matrix formed by concatenating the column vectors B1, ..., BM: +// +// | ↑ ↑ ↑ | +// B = | B1, B2, ..., BM| +// | ↓ ↓ ↓ | +// +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| +// = | ↓ ↓ ↓ | +// |bN1, bN2, ..., bNM|. +// +// Usually when we do linear algebra, the field 𝔽 from which `ai` and `bij` are +// drawn is the real or complex numbers. But in linear layouts, we let 𝔽 be a +// different field: GF(2). +// +// GF(2) is the two-element field of bits. To define a field, I need to give +// you the set of elements and also addition and multiplication operations. For +// GF(2) the elements are simply {0,1}. We define addition as xor, and +// multiplication as binary `and`. +// +// Here's an example of a 4x4 matrix-vector multiply where the elements are in +// GF(2). I'm using ⊕ to represent GF(2)'s addition operation (i.e xor) and × +// to represent multiplication (i.e. binary `and`). +// +// | 1 0 0 0 | | 0 | | 1 | | 0 | | 0 | | 0 | +// | 0 1 1 0 | | 1 | = | 0 | × 0 ⊕ | 1 | × 1 ⊕ | 1 | × 1 ⊕ | 0 | × 0 +// | 0 0 1 1 | | 1 | | 0 | | 0 | | 1 | | 1 | +// | 0 0 1 1 | | 0 | | 0 | | 0 | | 1 | | 1 | +// +// | 0 | | 0 | +// = | 1 | ⊕ | 1 | +// | 0 | | 1 | +// | 0 | | 1 | +// +// | 0 | +// = | 0 |. +// | 1 | +// | 1 | +// +// This works, but it's cumbersome. It's more compact to think of the vector +// `a` as an M-bit integer, and each column Bi of the matrix B as an N-bit +// integer. Here's the same matrix-vector product written this way. +// +// = | 1 2 14 12 | × 6 +// = | 1 2 14 12 | × 0b0110 +// = (1 × 0) ⊕ (2 × 1) ⊕ (14 × 1) ⊕ (12 × 0) +// = 2 ⊕ 14 +// = 12. +// +// And we confirm that our answer of 12 is equal to the binary value 0b1100 we +// got before. +// +// Notice that the function F(a) is fully specified by the matrix B, and that +// the four columns of B tell us the values of F at power-of-two values for `a`, +// namely F(1), F(2), F(4), and F(8). In other words, we specify four results +// of F(x) (we call these the function's "basis vectors" or its "bases") and we +// can then compute any other value by xor'ing together subsets of the bases. +// +// In the case of a 1D -> 1D layout, the implementation of an LL is +// straightforward from the mathematical description. If the LL is +// higher-dimensional, we can "stack" the bit vectors to create 1D vectors. +// For example, if we have a 2D LL and we're given input tuple (0b0011, 0b1100), +// we can treat this like a 1D input 0b0011'1100 and then do the regular 1D LL +// computation. Similarly we can "unstack" the output from 1D to ND. +// +// The linearity rule presented earlier is perhaps misleading at this point. In +// the 1D view of things, we really only need +// +// L(x ⊕ y) = L(x) ⊕ L(y) (1D linearity rule), +// +// which is part of the definition of L being a linear function. The new 1D +// linearity rule plus stacking/unstacking is equivalent to the earlier +// N-dimensional linearity rule. +// +// That's all we need in order to define linear layouts mathematically! +// +// # Comaprison to Nvidia CuTe +// +// (Note, I'm not an expert on CuTe; this is my best understanding.) 
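As a concrete illustration of the 1D linearity rule discussed above: once the bases are fixed, applying a layout is nothing more than xor-ing together the bases selected by the set bits of the input. The sketch below is illustrative only, not the LinearLayout C++ API defined in this header; it reproduces example 5 from the comment, whose bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2].

    def apply_1d(bases, x):
        """L(x) = xor of bases[i] for every set bit i of x (1D linearity rule)."""
        out = 0
        for i, b in enumerate(bases):
            if (x >> i) & 1:
                out ^= b
        return out

    bases = [4, 8, 1, 2]   # example 5 above
    print([apply_1d(bases, x) for x in range(16)])
    # -> [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
    # which matches the table: L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F

Higher-dimensional layouts reduce to this by stacking the input coordinates into one bit vector, as described in the implementation notes above.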
+// +// CuTe is a programmatic layout system that's part of Nvidia CUTLASS; see +// https://github.com/NVIDIA/cutlass/blob/629f465/media/docs/cute/00_quickstart.md +// +// LLs and CuTe solve similar problems. Before CuTe, CUTLASS v2 had many +// handcrafted layouts, "RowMajor", "VoltaTensorOpMultiplicandCongruous", etc, +// see https://www.youtube.com/watch?v=QLdUML5MCfE&t=574s. Each of these was a +// special case. CUTLASS v3 introduced CuTe layouts, which are programmable and +// subsume all of these special cases. The CUTLASS folks say this simplified +// CUTLASS, in the same way that we hope LLs will simplify Triton. +// +// Like CuTe layouts, LLs are also programmable and composible. But there are +// also some differences. +// +// - Dimensions in LLs are named; CuTe dimensions are numbered. +// - CuTe layouts can be nested; LLs cannot be. (Nesting doesn't give CuTe +// layouts additional power; any nested layout can be flattened.) +// - CuTe layouts support non-power-of-two shapes; LLs do not. In particular +// this means that LLs cannot represent padded layouts. +// - In CuTe, swizzling is a separate step applied after specifying a layout. +// In LLs, swizzling is part of the layout itself. +// - The structure of LLs allows us to programmatically search for layouts that +// satisfy certain requirements, for example a shared layout that doesn't +// have bank conflicts when read into a particular register layout. CuTe +// expects a human to choose the layout using their brain. +// - CuTe emits code that is in the critical path of your CPU and GPU programs, +// therefore it needs to be fast. It uses C++ template magic to specialize +// on known-sized dimensions, and so on. LLs themselves do not need to be +// fast; only the emitted `apply` code is on the critical path. +// - CuTe requires a CUDA compiler such as nvcc; LLs do not. +// +class LinearLayout { +private: + // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are + // computed by xor'ing bases together, using the linearity rule. In addition: + // + // - Each inDim has the same set of outDims, in the same order. + // - The order of dims is minor-to-major, although this only affects reshape. + llvm::MapVector /*size=getNumOutDims()*/> + /*size=getInDimSizeLog2(inDim)*/> + bases; + + llvm::SetVector outDimNames; + +public: + using BasesT = decltype(bases); + + // The 0-dimensional layout that maps everything to 0. This is useful as a + // starting point when doing something like + // + // LinearLayout ret = LinearLayout::empty(); + // for (...) ret *= ...; + // return ret; + static LinearLayout empty() { return LinearLayout(BasesT{}, {}); } + + // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x + // for x in [0, size). + static LinearLayout identity1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 + // for x in [0, size). + static LinearLayout zeros1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a LinearLayout from a list of bases. These are interpreted + // according to the rules written for the member variable `bases`. + explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + + // Construct a LinearLayout from an explicit list of bases. (This constructor + // is needed because llvm::MapVector does not have a constructor that accepts + // an initializer_list.) 
+ // + // For example, given these bases + // + // L(in1=1, in2=0) = (out1=0, out2=1) + // L(in1=2, in2=0) = (out1=0, out2=2) + // L(in1=0, in2=1) = (out1=0, out2=4) + // L(in1=0, in2=2) = (out1=0, out2=8) + // L(in1=0, in2=4) = (out1=1, out2=1) + // + // we can use this constructor to build an equivalent LL: + // + // LinearLayout({ + // {"in1", {/*L(in1=1)=*/{0,1}, /*L(in1=2)=*/{0,2}}}, + // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, + // }, + // {"out1", "out2"}) + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames); + + const BasesT &getBases() const { return bases; } + + // Get the pos'th basis vector for the inDim -> outDim mapping. + // getBasis(inDim, pos) = L(0, ..., inDim = 2^pos, ..., 0). + ArrayRef getBasis(StringAttr inDim, int32_t pos) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + assert(pos < it->second.size()); + return it->second[pos]; + } + + int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { + return getBasis(inDim, pos)[getOutDimIndex(outDim)]; + ; + } + + // These are in minor-to-major order, although if you don't flatten the dims + // (e.g. by reshaping) then the order doesn't really affect anything. + auto getInDimNames() const { return llvm::make_first_range(bases); } + ArrayRef getOutDimNames() const { + return outDimNames.getArrayRef(); + } + + // Gets the position that this outDim occupies in getOutDimNames(). Asserts + // if the dim is not present. + int32_t getOutDimIndex(StringAttr outDim) const; + + bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } + bool hasOutDim(StringAttr outDim) const { + return outDimNames.contains(outDim); + } + + int32_t getNumInDims() const { return bases.size(); } + int32_t getNumOutDims() const { return outDimNames.size(); } + + // Asserts if the dimension is not present. + int32_t getInDimSizeLog2(StringAttr inDim) const; + int32_t getInDimSize(StringAttr inDim) const { + return 1 << getInDimSizeLog2(inDim); + } + + // getOutDimSize(dim) == s means that there exists an input value that will + // produce each output value in [0,s). + // + // For example, if our bases are + // + // L(in0=1) = 1 + // L(in0=2) = 4 + // L(in1=1) = 2 + // L(in1=2) = 8 + // + // then the largest value we can produce is L(3,3) = 1 ⊕ 4 ⊕ 2 ⊕ 8 = 15 (and + // indeed we can produce all values in [0,16) by xor'ing subsets of the bases + // 1,2,4,8), so getOutDimSize(out_dim0) == 16. + // + // Asserts if the dimension is not present. + int32_t getOutDimSizeLog2(StringAttr outDim) const; + int32_t getOutDimSize(StringAttr outDim) const { + return 1 << getOutDimSizeLog2(outDim); + } + + // Reorders the in/out dimensions of the layout. This is mostly cosmetic + // (affecting e.g. the order of getIn/OutDimNames), but it also affects the + // behavior of reshape. + [[nodiscard]] LinearLayout + transposeIns(ArrayRef newInDimOrder) const; + [[nodiscard]] LinearLayout + transposeOuts(ArrayRef newOutDimOrder) const; + + // Creates a new layout which, roughly speaking, is equivalent to one where + // every element of the `outer` layout is replaced by a full instance of the + // `inner` layout. + // + // Examples: + // + // - empty() is the multiplicative identity: + // + // L * empty() == empty() * L == L. 
+ // + // - Multiplying two identity1D layouts with disjoint in/out dimensions gives + // a 2D identity layout: + // + // identity1D(4, "i1", "o1") * identity1D(8, "i2", "o2") => + // L(i1,i2) = (i1,i2), + // + // with in-dims ("i1", "i2") and out-dims ("o1", "o2"), in that order. + // + // - If out-dims overlap, they are combined, as in the following examples. + // + // - identity1D(4, "i", "o") * identity1D(2, "i", "o") == + // identity1D(8, "i", "o") + // + // - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4 + // for x in [0,8). + // + // - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2 + // for x in [0,8). + // + // - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") => + // L(x) = (x % 4, x / 4) for x in [0,32). + // + // Notice that this operation is not commutative. It's also not associative. + // TODO(jlebar): Can I modify the definition to make it associative? Pretty + // confusing if not. If I can't, add an example. + // + // Requires: Any in/out dimensions which are in both outer and inner appear in + // the same relative order. + friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); + LinearLayout &operator*=(LinearLayout outer) { + *this = *this * outer; + return *this; + } + + // Computes and returns L(x, y, z). + // + // If you want to apply the layout to mlir Values instead of integers, that + // function lives in TritonGPUToLLVM/Utility.h. + SmallVector> + apply(ArrayRef> ins) const; + + // Creates a new layout which is equivalent to running this layout, then + // running `outer`. That is, + // + // - let this layout be L(x), and + // - let `outer` be O(x). + // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). + // + // Requires: The output dimensions of this layout equal the input dimensions + // of outer (order doesn't matter). + [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeIns( + // std::vector> + // newInDims) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeOuts( + // std::vector> + // newOutDims) const; + + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); + friend bool operator!=(LinearLayout lhs, LinearLayout rhs) { + return !(lhs == rhs); + } +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +} // namespace mlir::triton + +#endif diff --git a/third_party/metax/python/triton/_C/include/triton/Tools/StrUtil.h b/third_party/metax/python/triton/_C/include/triton/Tools/StrUtil.h new file mode 100644 index 000000000..8b59f7d2b --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/Tools/StrUtil.h @@ -0,0 +1,54 @@ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::triton { + +// Better version of llvm::join. This one works when T is an integer or any +// other type which defines operator<<(raw_ostream). +template +std::string join(C &&container, llvm::StringRef sep = ", ") { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + s << elem; + } + return ret; +} + +// Joins a container of elements into a string, using `sep` as a separator. 
+// +// fn is called to transform each element of the container before it's added to +// the string. fn must have one of the following two signatures. +// +// - void fn(llvm::raw_ostream&, E), where E is the element type of the +// container, or +// - T fn(E), where T is a type which can be passed to +// raw_ostream::operator<<. +// +template +std::string join(C &&container, llvm::StringRef sep, Fn &&fn) { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + + if constexpr (std::is_invocable_v) { + static_assert( + std::is_void_v< + std::invoke_result_t>); + fn(s, elem); + } else { + s << fn(elem); + } + } + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/metax/python/triton/_C/include/triton/Tools/Sys/GetEnv.hpp b/third_party/metax/python/triton/_C/include/triton/Tools/Sys/GetEnv.hpp new file mode 100644 index 000000000..12584aa8f --- /dev/null +++ b/third_party/metax/python/triton/_C/include/triton/Tools/Sys/GetEnv.hpp @@ -0,0 +1,81 @@ +#ifndef TRITON_TOOLS_SYS_GETENV_HPP +#define TRITON_TOOLS_SYS_GETENV_HPP + +#include +#include +#include +#include +#include +#include + +namespace mlir::triton { + +inline const std::set CACHE_INVALIDATING_ENV_VARS = { + // clang-format off + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", + "DISABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", + "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "MLIR_ENABLE_TIMING", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", + "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_LLVM_DEBUG_ONLY", + "USE_TTGIR_LOC", + "NVPTX_ENABLE_DUMP", + // clang-format on +}; + +inline const std::set CACHE_NEUTRAL_ENV_VARS = { + "TRITON_REPRODUCER_PATH", +}; + +namespace tools { + +inline void assertIsRecognized(const std::string &env) { + bool is_invalidating = CACHE_INVALIDATING_ENV_VARS.find(env.c_str()) != + CACHE_INVALIDATING_ENV_VARS.end(); + bool is_neutral = + CACHE_NEUTRAL_ENV_VARS.find(env.c_str()) != CACHE_NEUTRAL_ENV_VARS.end(); + std::string errmsg = env + "is not recognized. " + "Please add it to triton/tools/sys/getenv.hpp"; + assert((is_invalidating || is_neutral) && errmsg.c_str()); +} + +inline std::string getStrEnv(const std::string &env) { + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return ""; + std::string result(cstr); + return result; +} + +// return value of a cache-invalidating boolean environment variable +inline bool getBoolEnv(const std::string &env) { + assertIsRecognized(env); + const char *s = std::getenv(env.c_str()); + std::string str(s ? 
s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::optional isEnvValueBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (str == "on" || str == "true" || str == "1") + return true; + if (str == "off" || str == "false" || str == "0") + return false; + return std::nullopt; +} +} // namespace tools +} // namespace mlir::triton + +#endif diff --git a/third_party/metax/python/triton/__init__.py b/third_party/metax/python/triton/__init__.py new file mode 100644 index 000000000..a5f77f91e --- /dev/null +++ b/third_party/metax/python/triton/__init__.py @@ -0,0 +1,73 @@ +"""isort:skip_file""" +__version__ = '3.1.0' + +# --------------------------------------- +# Note: import order is significant here. + +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + InterpreterError, + MockTensor, +) +from .runtime.jit import jit +from .compiler import compile, CompilationError +from .errors import TritonError + +from . import language +from . import testing +from . import tools + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "impl", + "InterpreterError", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "next_power_of_2", + "ops", + "OutOfResources", + "reinterpret", + "runtime", + "TensorWrapper", + "TritonError", + "testing", + "tools", +] + +# ------------------------------------- +# misc. utilities that don't fit well +# into any specific module +# ------------------------------------- + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/metax/python/triton/backends/__init__.py b/third_party/metax/python/triton/backends/__init__.py new file mode 100644 index 000000000..fbf65d9e9 --- /dev/null +++ b/third_party/metax/python/triton/backends/__init__.py @@ -0,0 +1,50 @@ +import os +import importlib.util +import inspect +from dataclasses import dataclass +from .driver import DriverBase +from .compiler import BaseBackend + + +def _load_module(name, path): + spec = importlib.util.spec_from_file_location(name[:-3], path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _find_concrete_subclasses(module, base_class): + ret = [] + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr): + ret.append(attr) + if len(ret) == 0: + raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}") + if len(ret) > 1: + raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}") + return ret[0] + + +@dataclass(frozen=True) +class Backend: + compiler: BaseBackend = None + driver: DriverBase = None + + +def _discover_backends(): + backends = dict() + root = os.path.dirname(__file__) + for name in os.listdir(root): + if not os.path.isdir(os.path.join(root, name)): + continue + if name.startswith('__'): + continue + compiler = _load_module(name, os.path.join(root, name, 
'compiler.py')) + driver = _load_module(name, os.path.join(root, name, 'driver.py')) + backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), + _find_concrete_subclasses(driver, DriverBase)) + return backends + + +backends = _discover_backends() diff --git a/third_party/metax/python/triton/backends/compiler.py b/third_party/metax/python/triton/backends/compiler.py new file mode 100644 index 000000000..990690045 --- /dev/null +++ b/third_party/metax/python/triton/backends/compiler.py @@ -0,0 +1,76 @@ +import os +import re +import subprocess + +from abc import ABCMeta, abstractmethod, abstractclassmethod +from dataclasses import dataclass +from typing import Union + + +@dataclass(frozen=True) +class GPUTarget(object): + # Target backend, e.g., cuda, hip + backend: str + # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip) + arch: Union[int, str] + warp_size: int + + +class BaseBackend(metaclass=ABCMeta): + + def __init__(self, target: GPUTarget) -> None: + self.target = target + assert self.supports_target(target) + + @staticmethod + def _path_to_binary(binary: str): + base_dir = os.path.join(os.path.dirname(__file__), os.pardir) + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(base_dir, "third_party", "cuda", "bin", binary), + ] + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + @abstractclassmethod + def supports_target(target: GPUTarget): + raise NotImplementedError + + @abstractmethod + def hash(self) -> str: + """Returns a unique identifier for this backend""" + raise NotImplementedError + + @abstractmethod + def parse_options(self, options: dict) -> object: + """ + Converts an `options` dictionary into an arbitrary object and returns it. + This function may contain target-specific heuristics and check the legality of the provided options + """ + raise NotImplementedError + + @abstractmethod + def add_stages(self, stages: dict, options: object) -> None: + """ + Populates `stages` dictionary with entries of the form: + ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes] + The value of each entry may populate a `metadata` dictionary. + Stages will be run sequentially (in inseriton order) and can communicate using `metadata`. + All stages are expected to return a `str` object, except for the last stage which returns + a `bytes` object for execution by the launcher. 
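For illustration, a hypothetical backend could populate `stages` along the following lines. This is a sketch only: the stage names mirror the upstream CUDA-style ttir/ttgir/llir pipeline rather than anything Metax-specific, and the `make_*` helpers are assumed, not functions defined in this patch.

    def add_stages(self, stages, options):
        # Each entry maps ir_name -> fn(src, metadata) -> str | bytes,
        # and stages run in insertion order, communicating via `metadata`.
        stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options)
        stages["ttgir"] = lambda src, metadata: make_ttgir(src, metadata, options)
        stages["llir"] = lambda src, metadata: make_llir(src, metadata, options)
        # The final stage returns bytes consumed by the launcher.
        stages["bin"] = lambda src, metadata: make_binary(src, metadata, options)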
+ """ + raise NotImplementedError + + @abstractmethod + def load_dialects(self, context): + """ + Load additional MLIR dialects into the provided `context` + """ + raise NotImplementedError diff --git a/third_party/metax/python/triton/backends/driver.py b/third_party/metax/python/triton/backends/driver.py new file mode 100644 index 000000000..e66442943 --- /dev/null +++ b/third_party/metax/python/triton/backends/driver.py @@ -0,0 +1,34 @@ +from abc import ABCMeta, abstractmethod, abstractclassmethod + + +class DriverBase(metaclass=ABCMeta): + + @abstractclassmethod + def is_active(self): + pass + + @abstractmethod + def get_current_target(self): + pass + + def __init__(self) -> None: + pass + + +class GPUDriver(DriverBase): + + def __init__(self): + # TODO: support other frameworks than torch + import torch + self.get_device_capability = torch.cuda.get_device_capability + try: + from torch._C import _cuda_getCurrentRawStream + self.get_current_stream = _cuda_getCurrentRawStream + except ImportError: + self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + + # TODO: remove once TMA is cleaned up + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args diff --git a/third_party/metax/python/triton/compiler/__init__.py b/third_party/metax/python/triton/compiler/__init__.py new file mode 100644 index 000000000..ce0cfedfc --- /dev/null +++ b/third_party/metax/python/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict +from .errors import CompilationError + +__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/third_party/metax/python/triton/compiler/code_generator.py b/third_party/metax/python/triton/compiler/code_generator.py new file mode 100644 index 000000000..6903052ca --- /dev/null +++ b/third_party/metax/python/triton/compiler/code_generator.py @@ -0,0 +1,1302 @@ +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from .. 
import language +from .._C.libtriton import ir +from ..language import constexpr, tensor, str_to_ty +from ..runtime.jit import _normalize_ty +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + assert False, "Unsupported type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _unwrap_if_constexpr(o: Any): + return o.value if isinstance(o, constexpr) else o + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _get_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) 
+ # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + for s in body: + if self.visit(s): + return True + return False + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) == ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... 
+ return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.debug = options.debug if debug is None else debug + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. 
+ # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if (val is absent # + or name in self.builtin_namespace # + or type(val) == ModuleType # + or isinstance(val, JITFunction) # + or getattr(val, "__triton_builtin__", False) # + or getattr(val, "__module__", "").startswith("triton.language") # + or isinstance(val, language.dtype) # + or self._is_constexpr_global(name) # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + or self.visiting_arg_default_value # + or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. 
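The constexpr-globals rule enforced by `global_lookup` above is user-visible, so a small sketch may help; these kernels are illustrative and not part of this patch. A global wrapped in `tl.constexpr` resolves inside a `@triton.jit` function, a plain Python global raises the NameError described above (unless the TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1 escape hatch mentioned above is set).

    import triton
    import triton.language as tl

    BLOCK = tl.constexpr(64)   # constexpr global: accepted by global_lookup
    SIZE = 64                  # plain int global: rejected with the NameError above

    @triton.jit
    def ok(x_ptr):
        offs = tl.arange(0, BLOCK)   # fine: BLOCK is a constexpr global
        tl.load(x_ptr + offs)

    @triton.jit
    def bad(x_ptr):
        offs = tl.arange(0, SIZE)    # NameError is raised when this kernel is JIT-compiled
        tl.load(x_ptr + offs)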
+ if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + # ret_block = self.builder.create_block() + # post_ret_block = self.builder.create_block() + # self.builder.create_branch(ret_block) + # self.builder.set_insertion_point_to_end(ret_block) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.core._to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + # self.builder.create_branch(post_ret_block) + # self.builder.set_insertion_point_to_end(post_ret_block) + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + # finalize function + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + # update return type + if isinstance(self.ret_type, tuple): + self.prototype.ret_types = list(self.ret_type) + self.fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + # Remove dead code 
+ self.fn.finalize() + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_tensor(value) and \ + not isinstance(value, native_nontensor_types): + value = language.core._to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + 
self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in then_defs.keys() & else_defs.keys(): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + has_endif_block = True + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create basic-block after conditional + endif_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # then terminator + self.builder.set_insertion_point_to_end(then_block) + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + if not else_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + if has_endif_block: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: 
+ liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if self.scf_stack and contains_return: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + elif self.scf_stack or not contains_return: + self.visit_if_scf(cond, node) + else: + self.visit_if_top_level(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.core._to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.core._to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + 
self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) == ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) == ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_builder=self.builder) + try: + return getattr(operand, fn)() + except AttributeError: + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop' + assert loop_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to 
{loop_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else 
self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.core._to_tensor(lb, self.builder) + ub = language.core._to_tensor(ub, self.builder) + step = language.core._to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_undef(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert _is_triton_tensor(liveins[name]) + assert self.local_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {self.local_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + names.append(name) + init_args.append(language.core._to_tensor(liveins[name], self.builder)) + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. 
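+ # Illustrative (hypothetical) fragment that exercises the loop-carried path:
+ #     acc = tl.zeros((BLOCK,), dtype=tl.float32)
+ #     for k in range(0, K, BLOCK):
+ #         acc += tl.load(ptr + k)
+ # `acc` is live-in and re-assigned in the body, so it was collected above as an
+ # init arg of the ForOp; inside the body, block argument 0 is the induction
+ # variable and the carried values start at arg(1), hence `arg(i + 1)` below.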
+ self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + if not self.debug: + return + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + # Convert assert to triton's device_assert which happens on the device + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = _get_fn_file_line(fn) + debug = self.debug if fn.debug is None else fn.debug + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. 
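+ # (Illustrative: if a hypothetical callee `add_one`, invoked as `y = add_one(x)`
+ # from another jitted function, fails to generate, the error re-raised below
+ # carries the caller's source and the `add_one(x)` call node, while `from e`
+ # preserves the callee's own traceback for debugging.)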
+ raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if not self.debug: + return + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + return fn(*args, **extra_kwargs, **kws) + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. 
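+ # (Illustrative: an error raised while lowering `tl.dot(a, b)`, e.g. a shape
+ # mismatch, keeps its original frames in the chained traceback, while the
+ # CompilationError below points at the `tl.dot(a, b)` call in the kernel source.)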
+ raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs): + if node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. 
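+ # set_loc above has already tagged the builder with the user's file/line for
+ # this node, and the CompilationError raised below carries the same node, so
+ # the report points at the kernel source rather than at generator internals.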
+ raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = _get_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns) + generator.visit(fn.parse()) + + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/metax/python/triton/compiler/compiler.py b/third_party/metax/python/triton/compiler/compiler.py new file mode 100644 index 000000000..367aa1b1a --- /dev/null +++ b/third_party/metax/python/triton/compiler/compiler.py @@ -0,0 +1,412 @@ +from __future__ import 
annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import GPUTarget +from .. import __version__ +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +# TODO: this shouldn't be here +from dataclasses import dataclass +from .code_generator import ast_to_ttir +from pathlib import Path +import re +import functools +import os + + +@dataclass +class AttrsDescriptor: + divisible_by_16: set = None + equal_to_1: set = None + + def __post_init__(self): + if self.divisible_by_16 is None: + self.divisible_by_16 = set() + if self.equal_to_1 is None: + self.equal_to_1 = set() + + def to_dict(self): + return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)} + + @staticmethod + def from_dict(data): + return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), + equal_to_1=set(data.get('equal_to_1', []))) + + def hash(self): + key = str([sorted(x) for x in self.__dict__.values()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?' +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". 
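+ # The attribute this is expected to match sits on the module, roughly
+ # (illustrative TTGIR):
+ #     module attributes {"triton_gpu.num-warps" = 4 : i32, ...} { ... }
+ # from which the regex below extracts the `4`.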
+ num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return 
Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + metadata_filename = f"{src.name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. 
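+ # Illustrative: for a CUDA-like backend the stages dict is ordered roughly as
+ #     {"ttir": ..., "ttgir": ..., "llir": ..., "ptx": ..., "cubin": ...}
+ # so compiling from a hand-written `foo.ttgir` (src.ext == "ttgir") skips the
+ # "ttgir" stage itself and resumes at the next one.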
+ if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + try: + module = src.make_ir(options, codegen_fns, context) + except Exception as e: + filter_traceback(e) + raise + use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1" + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{src.name}.{ext}" + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): + print(f"\nOverriding kernel with file {ir_filename}") + full_name = fn_override_manager.get_file(ir_filename) + next_module = parse(full_name, ext, context) + # use an env variable to parse ttgir from file + if use_ttgir_loc and ext == "ttgir": + ttgir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ttgir_full_name) + print(f"Create new locations for {ttgir_full_name}") + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. 
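+ # e.g. a serialized {"backend": "cuda", "arch": 80, "warp_size": 32}
+ # (illustrative values) becomes GPUTarget("cuda", 80, 32) again below.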
+ target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = { + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + } + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/third_party/metax/python/triton/compiler/errors.py b/third_party/metax/python/triton/compiler/errors.py new file mode 100644 index 000000000..39e6c4dfb --- /dev/null +++ b/third_party/metax/python/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if 
source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/third_party/metax/python/triton/compiler/make_launcher.py b/third_party/metax/python/triton/compiler/make_launcher.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/metax/python/triton/errors.py b/third_party/metax/python/triton/errors.py new file mode 100644 index 000000000..3a0a86355 --- /dev/null +++ b/third_party/metax/python/triton/errors.py @@ -0,0 +1,5 @@ +"""Base class for all errors raised by Triton""" + + +class TritonError(Exception): + ... diff --git a/third_party/metax/python/triton/language/__init__.py b/third_party/metax/python/triton/language/__init__.py new file mode 100644 index 000000000..168dccfea --- /dev/null +++ b/third_party/metax/python/triton/language/__init__.py @@ -0,0 +1,284 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . 
import extra +from .standard import ( + argmax, + argmin, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + advance, + arange, + associative_scan, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + const_pointer_type, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + function_type, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + program_id, + range, + reduce, + reshape, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "_experimental_descriptor_load", + "_experimental_descriptor_store", + "abs", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "broadcast", + "broadcast_to", + "builtin", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "const", + "const_pointer_type", + "constexpr", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "function_type", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "ir", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reshape", + "rsqrt", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "static_assert", + "static_print", + "static_range", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "triton", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name): + if name[0] == "*": + name = name[1:] + if name[0] == "k": + name = name[1:] + ty = str_to_ty(name) + return const_pointer_type(ty) + ty = str_to_ty(name) + 
return pointer_type(ty) + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/third_party/metax/python/triton/language/core.py b/third_party/metax/python/triton/language/core.py new file mode 100644 index 000000000..a343a2887 --- /dev/null +++ b/third_party/metax/python/triton/language/core.py @@ -0,0 +1,2726 @@ +''' 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. ''' +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional +import builtins +from ..runtime.jit import jit +import inspect +import os + +from .._C.libtriton import ir +from . import semantic + +T = TypeVar('T') + +TRITON_MAX_TENSOR_NUMEL = 1048576 + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + +if os.getenv("MACA_PATH") is not None: + USE_MACA = True + # map all the libdevice.10.bc funcs to maca in maca_mathlib.bc, do not modify the specific func in libdevice.py + # TODO(MACA): add python/triton/language/extra/maca/ directory later + nv_to_maca_map = { + "__nv_floorf" : "mc_math_func_floorf", + "__nv_floor" : "mc_math_func_floor", + "__nv_log2f" : "mc_math_func_log2f", + "__nv_log2" : "mc_math_func_log2", + "__nv_logf" : "mc_math_func_logf", + "__nv_powf" : "mc_math_func_powf_inline", + "__nv_pow" : "mc_math_func_pow", + "__nv_norm4df" : "mc_math_func_norm4df", + "__nv_norm4d" : "mc_math_func_norm4d", + "__nv_expf" : "mc_math_func_expf", + "__nv_exp" : "mc_math_func_exp", + "__nv_ffs" : "mc_math_func_ffs", + "__nv_umulhi" : "mc_math_func_umulhi", + "__nv_umul64hi" : "mc_math_func_umul64hi", + "__nv_rsqrtf" : "mc_math_func_rsqrtf", + "__nv_erff" : "mc_math_func_erff", + "__nv_tanhf" : "mc_math_func_tanhf", + "__nv_max" : "mc_math_func_max", + "__nv_umax" : "mc_math_func_umax", + "__nv_fmaxf" : "mc_math_func_fmaxf", + "__nv_fmax" : "mc_math_func_fmax", + "__nv_llmax" : "mc_math_func_llmax", + "__nv_ullmax" : "mc_math_func_ullmax", + "__nv_min" : "mc_math_func_min", + "__nv_umin" : "mc_math_func_umin", + "__nv_fminf" : "mc_math_func_fminf", + "__nv_fmin" : "mc_math_func_fmin", + "__nv_llmin" : "mc_math_func_llmin", + "__nv_ullmin" : "mc_math_func_ullmin", + "__nv_isinff" : "mc_math_func_isinff", + "__nv_log1pf" : "mc_math_func_log1pf", + "__nv_truncf" : "mc_math_func_truncf", + "__nv_expm1f" : "mc_math_func_expm1f", + "__nv_exp2f" : "mc_math_func_exp2f", + "__nv_fmodf" : "mc_math_func_fmodf", + "__nv_lgammaf" : "mc_math_func_lgammaf", + "__nv_log" : "mc_math_func_log", + "__nv_nearbyintf" : "mc_math_func_nearbyintf", + "__nv_signbitf" : "mc_math_func_signbitf", + "__nv_tanf" : "mc_math_func_tanf", + "__nv_ceilf" : "mc_math_func_ceilf", + "__nv_acosf" : "mc_math_func_acosf", + "__nv_acoshf" : "mc_math_func_acoshf", + "__nv_acos" : "mc_math_func_acos", + "__nv_acosh" : "mc_math_func_acosh", + "__nv_asinf" : "mc_math_func_asinf", + "__nv_asin" : "mc_math_func_asin", + "__nv_asinhf" : "mc_math_func_asinhf", + "__nv_asinh" : "mc_math_func_asinh", + "__nv_atan2f" : "mc_math_func_atan2f", + "__nv_atan2" : 
"mc_math_func_atan2", + "__nv_atanf" : "mc_math_func_atanf", + "__nv_atan" : "mc_math_func_atan", + "__nv_atanhf" : "mc_math_func_atanhf", + "__nv_atanh" : "mc_math_func_atanh", + "__nv_erf" : "mc_math_func_erf", + "__nv_erfcf" : "mc_math_func_erfcf", + "__nv_copysignf" : "mc_math_func_copysignf", + "__nv_copysign" : "mc_math_func_copysign", + "__nv_cos" : "mc_math_func_cos", + "__nv_coshf" : "mc_math_func_coshf", + "__nv_cosh" : "mc_math_func_cosh", + "__nv_isnanf" : "mc_math_func_isnanf", + "__nv_isnand" : "mc_math_func_isnan", + "__nv_hypotf" : "mc_math_func_hypotf", + "__nv_hypot" : "mc_math_func_hypot", + "__nv_sqrt" : "mc_math_func_sqrt", + "__nv_rsqrt" : "mc_math_func_rsqrt", + "__nv_nextafterf" : "mc_math_func_nextafterf", + "__nv_nextafter" : "mc_math_func_nextafter", + "__nv_sin" : "mc_math_func_sin", + "__nv_sinhf" : "mc_math_func_sinhf", + "__nv_sinh" : "mc_math_func_sinh", + "__nv_scalbnf" : "mc_math_func_scalbnf", + "__nv_fdiv_rn" : "mc_math_func_fdiv_rn", + "__nv_fdiv_rz" : "mc_math_func_fdiv_rz", + "__nv_powif" : "mc_math_func_powif_inline", + "__nv_finitef" : "mc_math_func_finitef", + "__nv_isfinited" : "mc_math_func_isfinite", + "__nv_fast_fdividef" : "mc_math_func_fdividef", + "__nv_fast_sinf" : "mc_math_func_sinf", + "__nv_fast_cosf" : "mc_math_func_cosf", + "__nv_fast_log2f" : "mc_math_func_log2f", + "__nv_fast_logf" : "mc_math_func_logf", + "__nv_fast_expf" : "mc_math_func_expf", + "__nv_fast_tanf" : "mc_math_func_tanf", + "__nv_fast_exp10f" : "mc_math_func_exp10f", + "__nv_fast_log10f" : "mc_math_func_log10f", + "__nv_fast_powf" : "mc_math_func_powf_inline", + "__nv_rintf" : "mc_math_func_rintf", + "__nv_roundf" : "mc_math_func_roundf", + "__nv_sqrtf" : "mc_math_func_sqrtf", + "__nv_fmaf" : "mc_math_func_fmaf", + } +else: + USE_MACA = False + nv_to_maca_map = {} +assert USE_MACA, "Please set MACA_PATH!" + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _builder, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. 
+ new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _builder=None): + return _to_tensor(x, _builder) + + +def _to_tensor(x, builder): + if isinstance(x, bool): + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + return tensor(builder.get_int32(x), int32) + elif 2**31 <= x < 2**32: + return tensor(builder.get_uint32(x), uint32) + elif -2**63 <= x < 2**63: + return tensor(builder.get_int64(x), int64) + elif 2**63 <= x < 2**64: + return tensor(builder.get_uint64(x), uint64) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + return tensor(builder.get_fp32(x), float32) + else: + return tensor(builder.get_fp64(x), float64) + + elif isinstance(x, constexpr): + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): + return x + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +class dtype: + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + if hasattr(name, 'value'): + name = name.value + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + 
self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name 
== 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1): + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty is a {type(element_ty).__name__}.') + self.element_ty = element_ty + self.address_space = address_space + + self.name = f'pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class const_pointer_type(pointer_type): + + def __init__(self, element_ty: dtype, address_space: int = 1): + super().__init__(element_ty, address_space) + + def __str__(self): + return f'const_pointer<{self.element_ty}>' + + def is_const(self): + return True + + def __eq__(self, other) -> bool: + if not isinstance(other, const_pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + + # shape can be empty ([]) when an input is a 0D tensor. 
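+        # Such 0-d inputs are rejected just below: every block_type carries at
+        # least one dimension.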
+ if not shape: + raise TypeError('0d block_type is forbidden') + if isinstance(shape[0], constexpr): + shape = [s.value for s in shape] + + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + if self.numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_types}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. 
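+
+    For example, :code:`constexpr(16)` wraps the Python value :code:`16` so the
+    compiler frontend can treat it as a compile-time constant.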
+ """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + 
return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +class tensor: + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = [constexpr(s) for s in self.shape] + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) + + @builtin + def __radd__(self, other, _builder=None): + return self.__add__(other, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) + + @builtin + def __rsub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) + + @builtin + def __mul__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) + + @builtin + def __rmul__(self, other, _builder=None): + return self.__mul__(other, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) + + @builtin + def __rtruediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def 
__neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # > + @builtin + def __gt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __req__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def 
__rne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + if isinstance(slices, (slice, constexpr)) or slices is None: + slices = [slices] + ret = self + for dim, sl in enumerate(slices): + if sl is None or isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Alias for :py:func:`tensor.cast`. + """ + # Triton doesn't like core functions calling other core functions, so we + # just copy-paste the implementation of cast here. It's not too bad. + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder, fp_downcast_rounding) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... 
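+
+    # Illustrative member-function usage inside a @triton.jit kernel (sketch;
+    # `x` is any tl.tensor and `my_combine` a placeholder @triton.jit function):
+    #   total = x.sum(axis=0)            # same as tl.sum(x, axis=0)
+    #   acc = x.reduce(0, my_combine)    # same as tl.reduce(x, 0, my_combine)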
+ + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +def get_bool_env_var(var_name): + v = os.getenv(var_name, "0") + return v == "1" or v == "true" or v == "on" + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + """ + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = 131072` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. 
+ :type end: int32 + """ + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +def _shape_check_impl(shape): + shape = _constexpr_to_value(shape) + for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + if d.value & (d.value - 1) != 0: + raise ValueError(f"Shape element {i} must be a power of 2") + return [_constexpr_to_value(x) for x in shape] + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :value value: A scalar value to fill the array with + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + If no permutation is specified, tries to do a (1,0) permutation, i.e. tries + to transpose a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + if not dims: + dims = (1, 0) + return semantic.permute(input, dims, _builder) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. 
+ """ + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: + :param other: The second input tensor. + :type other: + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops) + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def join(a, b, _builder=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return semantic.join(a, b, _builder) + + +@jit +def _take_first(a, b): + return a + + +@_tensor_member_fn +@builtin +def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = semantic.expand_dims(a, 0, _builder) + + out_lhs, out_rhs = semantic.split(a, _builder) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator)) + out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator)) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder=True, builder=_builder) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. 
+ :type input: Block + :param shape: The new shape. + + :code:`shape ` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = _to_tensor(input, _builder) + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + """ + input = _to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must be two-dimensional and have compatible inner dimensions. + + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. 
+ :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + + input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be None, and + - `boundary_check` and `padding_option` can be specified to control + the behavior of out-of-bound access. 
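+
+    A typical masked 1-D load looks like this (illustrative sketch; :code:`x_ptr`,
+    :code:`n_elements` and :code:`BLOCK_SIZE` are placeholder names):
+
+    .. code-block:: python
+
+        offs = tl.arange(0, BLOCK_SIZE)
+        x = tl.load(x_ptr + offs, mask=offs < n_elements, other=0.0)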
+ + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _constexpr_to_value(mask) + other = _constexpr_to_value(other) + if mask is not None: + mask = _to_tensor(mask, _builder) + if other is not None: + other = _to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, _builder) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): + """ + Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This loads a tensor of data based on the descriptor and offsets. + """ + type = block_type(dtype, shape) + return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): + """ + Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This stores a tensor of data based on the descriptor and offsets. + """ + return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. 
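+
+    A typical masked 1-D store looks like this (illustrative sketch; :code:`y_ptr`,
+    :code:`n_elements` and :code:`BLOCK_SIZE` are placeholder names):
+
+    .. code-block:: python
+
+        offs = tl.arange(0, BLOCK_SIZE)
+        tl.store(y_ptr + offs, x, mask=offs < n_elements)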
+ + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + """ + # `value` can be constexpr + value = _to_tensor(value, _builder) + mask = _constexpr_to_value(mask) + if mask is not None: + mask = _to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@_tensor_member_fn +@builtin +def advance(base, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. 
+ + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default), + "ACQUIRE", "RELEASE", or "RELAXED") + :type sem: str + :param scope: Scope of threads that observe synchronizing effect of the + atomic operation ("GPU" (default), "CTA", or "SYSTEM") + :type scope: str + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_min(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_or(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return 
semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.minimum(x, y, propagate_nan, _builder) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.maximum(x, y, propagate_nan, _builder) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. 
seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + min = _to_tensor(min, _builder) + max = _to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the reduction should be done + :param keep_dims: if true, keep the reduced dimensions with length 1""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, return the left-most indices in case of ties for values that aren't NaN""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done. 
If None, reduce all dimensions + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param keep_dims: if true, keep the reduced dimensions with length 1 + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _builder=_builder) + return t + + axis = _constexpr_to_value(axis) + keep_dims = _constexpr_to_value(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = semantic.reduction(input, axis, make_combine_region, _builder) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _builder=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the scan should be done""" + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param reverse: apply the associative scan in the reverse direction along axis. 
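+
+    For example, an inclusive cumulative sum along the first axis can be written
+    as follows (illustrative sketch; :code:`_add_combine` is a placeholder helper):
+
+    .. code-block:: python
+
+        @triton.jit
+        def _add_combine(a, b):
+            return a + b
+
+        y = tl.associative_scan(x, 0, _add_combine)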
+ + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(scan_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = scan_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_scan_ret(*handles) + + axis = _constexpr_to_value(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, _builder=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :param num_bins: number of histogram bins + + """ + num_bins = _constexpr_to_value(num_bins) + return semantic.histogram(input, num_bins, _builder) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. 
+ """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_constancy(input, values) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"{BLOCK_SIZE=}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _builder=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, hex, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. 
+ + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + import inspect + frame = inspect.currentframe() + module = inspect.getmodule(frame) + # The triton function module doesn't have the name attribute. + # We use this trick to find the caller. + while hasattr(module, "__name__"): + frame = frame.f_back + module = inspect.getmodule(frame) + lineno = 0 + func_name = 'unknown' + file_name = 'unknown' + if frame is not None and frame.f_back is not None: + func_name = frame.f_code.co_name + file_name = frame.f_back.f_code.co_filename + # TODO: The line number currently indicates the line + # where the triton function is called but not where the + # device_assert is called. Need to enhance this. + lineno = frame.f_back.f_lineno + return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _builder=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. 
+ "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string) + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [_to_tensor(arg, _builder) for arg in args]: + bin_op_type_checking = partial( + semantic.binary_op_type_checking_impl, + builder=_builder, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype] + handles = [t.handle for t in dispatch_args] + call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. 
+ """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr) + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr) + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr) + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." 
+ f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = nv_to_maca_map[arg_type_symbol_dict[arg_types][0]] if USE_MACA else arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/third_party/metax/python/triton/language/extra/__init__.py b/third_party/metax/python/triton/language/extra/__init__.py new file mode 100644 index 000000000..14e1778d2 --- /dev/null +++ b/third_party/metax/python/triton/language/extra/__init__.py @@ -0,0 +1,4 @@ +from . import cuda +from . import hip + +__all__ = ['cuda', 'hip'] diff --git a/third_party/metax/python/triton/language/extra/cuda/__init__.py b/third_party/metax/python/triton/language/extra/cuda/__init__.py new file mode 100644 index 000000000..3ca510e02 --- /dev/null +++ b/third_party/metax/python/triton/language/extra/cuda/__init__.py @@ -0,0 +1,8 @@ +from . 
import libdevice + +from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) + +__all__ = [ + "libdevice", "globaltimer", "num_threads", "num_warps", "smid", "convert_custom_float8_sm70", + "convert_custom_float8_sm80" +] diff --git a/third_party/metax/python/triton/language/extra/cuda/libdevice.py b/third_party/metax/python/triton/language/extra/cuda/libdevice.py new file mode 100644 index 000000000..3490e6b0e --- /dev/null +++ b/third_party/metax/python/triton/language/extra/cuda/libdevice.py @@ -0,0 +1,1629 @@ +from triton.language import core + + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) 
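A minimal usage sketch (illustrative only, not part of the diff above): the wrappers in this file are thin overload tables for core.extern_elementwise, where the tuple of argument dtypes selects a libdevice symbol and result dtype, and dispatch() swaps the __nv_* symbol for its MACA counterpart when USE_MACA is set. The kernel, tensor names, and launch parameters below are hypothetical.

import torch
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice  # module added by this patch

@triton.jit
def floor_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):  # hypothetical example kernel
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # fp32 operand selects the ("__nv_floorf", fp32) entry of the table above
    tl.store(out_ptr + offs, libdevice.floor(x), mask=mask)

x = torch.rand(1024, device="cuda")
out = torch.empty_like(x)
floor_kernel[(triton.cdiv(x.numel(), 256), )](x, out, x.numel(), BLOCK=256)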
+ + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def 
div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): 
("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__nv_dmul_ru", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def 
double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return 
core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { 
+ (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, 
is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern 
+def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + 
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + 
+@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return 
core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): 
("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, 
_builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", 
core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/metax/python/triton/language/extra/cuda/utils.py b/third_party/metax/python/triton/language/extra/cuda/utils.py new file mode 100644 index 000000000..01bc040b2 --- /dev/null +++ b/third_party/metax/python/triton/language/extra/cuda/utils.py @@ -0,0 +1,109 @@ +from triton.language import core + + +@core.extern +def globaltimer(_builder=None): + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, + _builder=_builder) + + +@core.extern +def smid(_builder=None): + return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, + _builder=_builder) + + +@core.builtin +def 
num_threads(_builder=None): + return core.constexpr(_builder.options.num_warps * 32) + + +@core.builtin +def num_warps(_builder=None): + return core.constexpr(_builder.options.num_warps) + + +# ----- FP8E4M3B15 ------ +# This data-type is a variant of the standard FP8E4M3 format. +# It was designed for fast software conversion to FP16 on +# nvidia GPUs that do not support it natively. +# This is the same format as FP8E4M3Nv, but: +# - the exponent bias is 15 instead of 7 +# - 0xff and 0x7f are mapped to +-1.750 instead of +-nan +@core.builtin +def convert_fp8e4b15_to_float16(arg, _builder=None): + return core.inline_asm_elementwise( + "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5746; \n" + "and.b32 b0, a0, 0x7f007f00; \n" + "and.b32 b1, a0, 0x00ff00ff; \n" + "and.b32 a1, a0, 0x00800080; \n" + "shr.b32 b0, b0, 1; \n" + "add.u32 b1, b1, a1; \n" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" + "shl.b32 $1, b1, 7; \n" + "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None): + asm = """{ + .reg .pred p<4>; + .reg .b32 a<2>, b<2>; + .reg .b16 c<4>; + .reg .b16 max_val_f16; + .reg .b32 max_val_f16x2; + mov.b16 max_val_f16, 0x3F00; + mov.b32 max_val_f16x2, 0x3F003F00; + and.b32 a0, $1, 0x7fff7fff; + and.b32 a1, $2, 0x7fff7fff;""" + if has_minx2: + asm += """min.f16x2 a0, a0, max_val_f16x2; + min.f16x2 a1, a1, max_val_f16x2;""" + else: + asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2; + setp.lt.f16x2 p2|p3, a1, max_val_f16x2; + mov.b32 {c0, c1}, a0; + mov.b32 {c2, c3}, a1; + selp.b16 c0, c0, max_val_f16, p0; + selp.b16 c1, c1, max_val_f16, p1; + selp.b16 c2, c2, max_val_f16, p2; + selp.b16 c3, c3, max_val_f16, p3; + mov.b32 a0, {c0, c1}; + mov.b32 a1, {c2, c3};""" + asm += """mad.lo.u32 a0, a0, 2, 0x00800080; + mad.lo.u32 a1, a1, 2, 0x00800080; + lop3.b32 b0, $1, 0x80008000, a0, 0xea; + lop3.b32 b1, $2, 0x80008000, a1, 0xea; + prmt.b32 $0, b0, b1, 0x7531; + }""" + return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None): + if arg.type.scalar.is_fp8e4b15(): + upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder) + if dst_ty.scalar.is_fp32(): + upcast_val = upcast_val.to(core.float32, _builder=_builder) + return upcast_val + + assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32() + downcast_val = arg + if arg.type.scalar.is_fp32(): + downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder) + downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder) + return downcast_val + + +@core.builtin +def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder) + + +@core.builtin +def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder) diff --git a/third_party/metax/python/triton/language/extra/hip/__init__.py b/third_party/metax/python/triton/language/extra/hip/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/metax/python/triton/language/extra/hip/__init__.py @@ -0,0 +1,3 @@ +from . 
import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/metax/python/triton/language/extra/hip/libdevice.py b/third_party/metax/python/triton/language/extra/hip/libdevice.py new file mode 100644 index 000000000..02e5d2d0b --- /dev/null +++ b/third_party/metax/python/triton/language/extra/hip/libdevice.py @@ -0,0 +1,468 @@ +from triton.language import core + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__triton_hip_iabs", core.dtype("int32")), + (core.dtype("int64"), ): ("__triton_hip_iabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__triton_hip_fabs", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__triton_hip_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_floor_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_floor_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_rsqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_rsqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ceil_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_ceil_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_trunc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_trunc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__triton_hip_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_llrint", core.dtype("int64")), + (core.dtype("fp64"), ): ("__triton_hip_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_nearbyint_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_nearbyint_f64", core.dtype("fp64")), + }, is_pure=True, 
_builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_isnan_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isnan_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_signbit_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_signbit_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_copysign_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_copysign_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_isinf_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isinf_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_nextafter_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_nextafter_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return 
core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_atan2_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_atan2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log10_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log10_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log1p_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log1p_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_expm1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_expm1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_hypot_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_hypot_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j0_f32", core.dtype("fp32")), + (core.dtype("fp64"), 
): ("__ocml_j0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erf_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfinv_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfcx_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfcx_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_lgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_lgamma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_ldexp_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_ldexp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fmod_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fmod_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, 
_builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fma_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_pown_f64", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_pow_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_pow_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ilogb_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_ilogb_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/metax/python/triton/language/extra/libdevice.py b/third_party/metax/python/triton/language/extra/libdevice.py new file mode 100644 index 000000000..22d0e61a7 --- /dev/null +++ b/third_party/metax/python/triton/language/extra/libdevice.py @@ -0,0 +1,1214 @@ +from .cuda import libdevice as cuda_libdevice +from .hip import libdevice as hip_libdevice +from triton.language import core +from functools import wraps +from typing import TypeVar + +T = TypeVar('T') + + +def dispatch(fn: T) -> T: + """Dispatch a function to a correct implementation.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + _backend = kwargs["_builder"].options.backend_name + # USE_MACA:maca use cuda entry and map to maca builtin func + if _backend == 'cuda' or _backend == 'maca': + _curr_libdevice_module = cuda_libdevice + elif _backend == 'hip': + _curr_libdevice_module = hip_libdevice + else: + raise RuntimeError('unknown backend') + + try: + _impl = getattr(_curr_libdevice_module, fn.__name__) + except AttributeError: + raise RuntimeError(f'`{_backend}` does not provide support for `{fn.__name__}` extra function') + + return _impl(*args, **kwargs) + + return wrapper + + +@core.extern +@dispatch +def clz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def popc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def byte_perm(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def mulhi(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul24(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def brev(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sad(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def abs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def floor(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp64h(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ceil(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def trunc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def saturatef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rn(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rz(arg0, arg1, arg2, _builder=None): + ... 
+ + +@core.extern +@dispatch +def fma_rd(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_ru(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fast_dividef(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def add_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rn(arg0, _builder=None): + ... 
+ + +@core.extern +@dispatch +def int2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hiloint2double(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2loint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2hiint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_int(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_uint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def longlong_as_double(arg0, _builder=None): + ... 
+ + +@core.extern +@dispatch +def double_as_longlong(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_sinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_cosf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log2f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_logf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_expf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_tanf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_exp10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_powf(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def hadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ffs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llrint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nearbyint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isnan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def signbit(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def copysign(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def finitef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nextafter(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinpi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cospi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atan2(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def atan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log1p(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def expm1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hypot(arg0, arg1, _builder=None): + ... 
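# Note (illustrative): because @dispatch forwards each call to whichever
# backend module defines the same name, the stubs in this file are only usable
# on backends whose libdevice table actually provides them. rhypot below, for
# example, has a __nv_rhypotf/__nv_rhypot mapping in cuda/libdevice.py but no
# __ocml_* entry in hip/libdevice.py, so on a 'hip' backend the wrapper raises
# RuntimeError("`hip` does not provide support for `rhypot` extra function").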
+ + +@core.extern +@dispatch +def rhypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def norm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def cbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def yn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def jn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcx(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def lgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ldexp(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def scalbn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fmod(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def remainder(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fma(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def pow(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def tgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def round(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llround(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fdim(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def ilogb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def logb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isfinited(arg0, _builder=None): + ... diff --git a/third_party/metax/python/triton/language/math.py b/third_party/metax/python/triton/language/math.py new file mode 100644 index 000000000..de5b5be6b --- /dev/null +++ b/third_party/metax/python/triton/language/math.py @@ -0,0 +1,250 @@ +from . import core +from . import semantic +from functools import wraps +from typing import List + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. 
+ """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return 
core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core.builtin +@_add_math_1arg_docstr("absolute value") +@core._tensor_member_fn +def abs(x, _builder=None): + x = core._to_tensor(x, _builder) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) + return core.tensor(_builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _builder=None): + ieee_rounding = core._constexpr_to_value(ieee_rounding) + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + return semantic.fdiv(x, y, ieee_rounding, _builder) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest)") +def div_rn(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + z = core._to_tensor(z, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + z, x = core.binary_op_type_legalization(z, x, _builder) + z, y = core.binary_op_type_legalization(z, y, _builder) + return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/third_party/metax/python/triton/language/random.py b/third_party/metax/python/triton/language/random.py new file mode 100644 index 000000000..430aeb09e --- /dev/null +++ b/third_party/metax/python/triton/language/random.py @@ -0,0 +1,207 @@ +from ..runtime.jit import jit +from . import core as tl +from . 
import math + +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + seed = seed.to(tl.uint64) + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). 
+ """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/third_party/metax/python/triton/language/semantic.py b/third_party/metax/python/triton/language/semantic.py new file mode 100644 index 000000000..8af9790b1 --- /dev/null +++ b/third_party/metax/python/triton/language/semantic.py @@ -0,0 +1,1624 @@ +''' 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. ''' +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Sequence, Tuple, TypeVar + +from .._C.libtriton import ir +from . import core as tl +from . 
import math + +T = TypeVar('T') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 5 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, + allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def truediv(input: tl.tensor, other: 
tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + # input - input.div(other, rounding_mode="floor") * other + ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. 
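To make the promotion rules above concrete: with two fp16 operands, `*` stays in fp16 while `/` and `%` are computed in fp32 (there is no native fp16 division), and mixed int/float operands are converted to the float type. An illustrative sketch, not part of the patch (hypothetical kernel, fp16 inputs assumed):

import triton
import triton.language as tl


@triton.jit
def _promotion_demo(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)   # assumed fp16
    y = tl.load(y_ptr + offs)   # assumed fp16
    prod = x * y                # mul(): result stays fp16
    quot = x / y                # truediv(): both operands promoted to fp32
    tl.store(out_ptr + offs, prod + quot)  # add(): prod is promoted to fp32 to match quot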
Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +############## +# other arithmetic ops +############## + + +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, max, builder) + + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. 
Only floating point clamp is supported") + + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + return invert(input, builder) + + +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, builder: tl.tensor) -> 
tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == 
int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + else: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) + + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + + +def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert 
can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) + + +def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor: + a, b = broadcast_impl_value(a, b, builder) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = expand_dims(a, 0, builder) + b = expand_dims(b, 0, builder) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = reshape(ret, [2], can_reorder=False, builder=builder) + + return ret + + +def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + assert (len(a.shape) > 0) + assert (tl._constexpr_to_value(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = builder.create_split(a.handle) + return ( + tl.tensor(outLHS, ret_type), + tl.tensor(outRHS, ret_type), + ) + + +def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return tl.tensor(builder.create_trans(input.handle, dims), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = 
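The expand_dims / broadcast helpers above are what back NumPy-style None indexing inside kernels. A small outer-product sketch, not part of the patch (hypothetical kernel, 1D fp32 inputs assumed):

import triton
import triton.language as tl


@triton.jit
def _outer(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    # x[:, None] lowers to expand_dims; the multiply then broadcasts both sides to (BLOCK, BLOCK)
    o = x[:, None] * y[None, :]
    tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], o)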
tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + + +####### +# cast +####### + + +def _str_to_rounding_mode(rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtne_no_nan': + return ir.ROUNDING_MODE.RTNE_NO_NAN + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz' and 'rtne_no_nan'.") + + +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. 
" + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" + + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return 
not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def _str_to_load_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_store_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _str_to_sem(sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + 
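The string-to-enum helpers above are what give tl.load / tl.store their cache_modifier, eviction_policy and padding_option keywords. A sketch of the legacy tensor-of-pointers path, which accepts mask/other but not boundary_check; not part of the patch, hypothetical kernel name:

import triton
import triton.language as tl


@triton.jit
def _masked_copy(in_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # ".cg" / "evict_first" map onto the CACHE_MODIFIER / EVICTION_POLICY enums parsed above
    x = tl.load(in_ptr + offs, mask=mask, other=0.0,
                cache_modifier=".cg", eviction_policy="evict_first")
    tl.store(out_ptr + offs, x, mask=mask, cache_modifier=".cg")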
return () + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `ele_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction 
and padding options + cache = _str_to_load_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + + +def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type, + builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder), + _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tl.tensor(x, type) + + +def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. 
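The block-pointer path above is selected purely from the pointer type and is the one that understands boundary_check / padding_option. A tile-copy sketch, not part of the patch (hypothetical kernel, row-major fp32 matrix assumed):

import triton
import triton.language as tl


@triton.jit
def _copy_tile(in_ptr, out_ptr, M, N, stride_m, stride_n,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    src = tl.make_block_ptr(base=in_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    # out-of-bounds rows/columns are zero-padded on load and skipped on store
    tile = tl.load(src, boundary_check=(0, 1), padding_option="zero")
    tl.store(dst, tile, boundary_check=(0, 1))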
Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_store_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty in [tl.float16, tl.bfloat16] and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16 or bf16") + if element_ty in [tl.int1, tl.int8, tl.int16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), 
builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_val.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = 
tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_ptr.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def _str_to_dot_input_precision(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + + def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): + if not options.allow_fp8e4nv: + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( + ), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. 
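The float paths of atomic_max / atomic_min above rely on an IEEE-754 property: for values of the same sign, numeric order matches the order of the raw bit patterns (signed order for non-negative values, reversed unsigned order for negative ones), hence the signed-max / unsigned-min split. A quick host-side check of that property; plain Python, not part of the kernel API:

import struct


def f32_bits(x: float) -> int:
    # raw little-endian bit pattern of a float32, as an unsigned int
    return struct.unpack("<I", struct.pack("<f", x))[0]


# non-negative floats: larger value -> larger bit pattern
assert f32_bits(2.0) > f32_bits(1.0) > f32_bits(0.0)
# negative floats: the unsigned bit-pattern order is reversed,
# which is why the negative side uses an unsigned *min*
assert f32_bits(-1.0) < f32_bits(-2.0)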
First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( + ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + if options.allow_fp8e4b15: + allowed_types = ['fp8e4nv', 'fp8e5', 'fp8e4b15'] + else: + allowed_types = ['fp8e4nv', 'fp8e5'] + + def _validate_dtype(dtype, allowed_types, operand_name): + if not any(getattr(dtype, f'is_{dtype_name}')() for dtype_name in allowed_types): + supported_types = ', '.join(allowed_types) + raise AssertionError(f"Only supports {supported_types}. {operand_name} ({dtype})") + + _validate_dtype(lhs_dtype, allowed_types, "First operand") + _validate_dtype(rhs_dtype, allowed_types, "Second operand") + else: + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( + ), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( + ), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + + assert lhs.type.is_block() and rhs.type.is_block() + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \ + and rhs.shape[-1].value >= 16, \ + f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + # TODO: This is CUDA specific, check if ROCm has the same limitation + assert lhs.shape[1].value >= 32, "small blocks not supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. 
Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + condition, x = broadcast_impl_value(condition, x, builder) + + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + if not condition.type.is_block(): + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: + if axis is None: + inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool, + builder: ir.builder) -> Tuple[tl.tensor, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have 
the same shape" + + scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + scan_op.verify() + + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + +def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, ))) + + +## + + +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". 
+ if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) + return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == 
list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) diff --git a/third_party/metax/python/triton/language/standard.py b/third_party/metax/python/triton/language/standard.py new file mode 100644 index 000000000..de30cf260 --- /dev/null +++ b/third_party/metax/python/triton/language/standard.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . import core +from . import math + +# constexpr utilities (triton metaprogramming sucks) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, core.constexpr) else o + + +def _log2(i: core.constexpr): + log2 = 0 + n = i.value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :param div: Block + """ + return (x + div - 1) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + z = x - max(x, 0) + num = math.exp(z) + den = sum(num, 0) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=True) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms indices of a row-major :code:`size_i * size_j` matrix into those + of one where the indices are col-major for each group of :code:`size_g` + rows. 
+ + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # new row and column indices + new_i = off_i + (ij % size_g) + new_j = (ij % size_gj) // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Creates a tensor of zeros with the same shape and type as a given tensor. + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit 
+def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum") +def sum(input, axis=None, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@core.builtin +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = core._promote_bfloat16_to_float32(input, _builder=_builder) + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) + left = core.reshape(left, x.shape) + right = 
core.reshape(right, x.shape) + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix)) + return ret.to(x.dtype, bitcast=True) + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims) + return x + + +@core._tensor_member_fn +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims) + return x + + +# flip + + +def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + assert dim == len(shape) - 1, "Currently only support flipping the last dimension" + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) + core.static_assert(_is_power_of_two(x.numel)) + # # reshape the tensor to have all dimensions be 2. + # # TODO: We shouldn't have to change the dimensions not sorted. + steps: core.constexpr = _log2(x.numel) + start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) + y = core.reshape(x, [2] * steps) + y = core.expand_dims(y, start) + flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) + for i in core.static_range(start, steps): + flip2 = flip + for j in core.static_range(0, steps + 1): + if j != i and j != i + 1: + flip2 = core.expand_dims(flip2, j) + y = sum(y * flip2, i + 1, keep_dims=True) + x = core.reshape(y, x.shape) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. + + The two tensors must have the same shape. 
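+
+    For example (an illustrative case): interleaving `[a0, a1]` with `[b0, b1]`
+    yields `[a0, b0, a1, b1]`.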
+ + Equivalent to `tl.join(a, b).reshape(a.shape[-1:] + [2 * a.shape[-1]])` + """ + c = core.join(a, b) + + assert isinstance(c.shape, list) + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/third_party/metax/python/triton/ops/__init__.py b/third_party/metax/python/triton/ops/__init__.py new file mode 100644 index 000000000..18f1d782d --- /dev/null +++ b/third_party/metax/python/triton/ops/__init__.py @@ -0,0 +1,7 @@ +# from .conv import _conv, conv +from . import blocksparse +from .cross_entropy import _cross_entropy, cross_entropy +from .flash_attention import attention +from .matmul import _matmul, get_higher_dtype, matmul + +__all__ = ["blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "attention", "get_higher_dtype"] diff --git a/third_party/metax/python/triton/ops/blocksparse/__init__.py b/third_party/metax/python/triton/ops/blocksparse/__init__.py new file mode 100644 index 000000000..6b24b5377 --- /dev/null +++ b/third_party/metax/python/triton/ops/blocksparse/__init__.py @@ -0,0 +1,7 @@ +from .matmul import matmul +from .softmax import softmax + +__all__ = [ + "matmul", + "softmax", +] diff --git a/third_party/metax/python/triton/ops/blocksparse/matmul.py b/third_party/metax/python/triton/ops/blocksparse/matmul.py new file mode 100644 index 000000000..098e15438 --- /dev/null +++ b/third_party/metax/python/triton/ops/blocksparse/matmul.py @@ -0,0 +1,432 @@ +import torch + +from ... import cdiv, heuristics, jit +from ... import language as tl + +# ******************************************************** +# -------------------------------------------------------- +# Sparse = Dense x Dense (SDD) +# This operation uses super-blocking to make sure that +# it's done efficiently when small blocks can be grouped +# together +# -------------------------------------------------------- +# ******************************************************** + + +@heuristics({ + 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, +}) +@jit +def _sdd_kernel(A, B, C, # + stride_za, stride_ha, stride_ma, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_nb, # + stride_zc, stride_hc, stride_mc, stride_nc, # + K, grid_offset, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + BLOCK: tl.constexpr, EVEN_K: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + block_id = tl.program_id(0) + grid_offset + lut += block_id * 3 + # offsets + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + + # initialize pointers to A + start_am = tl.load(lut + 1) + offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) + offs_ak = tl.arange(0, TILE_K) + a_ptrs = A \ + + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B + start_bn = tl.load(lut + 2) + offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) + offs_bk = tl.arange(0, TILE_K) + b_ptrs = B \ + + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(K, 0, -TILE_K): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a 
= tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) + b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) + acc += tl.dot(a, b, out_dtype=tl.float32) + a_ptrs += TILE_K * stride_ak + b_ptrs += TILE_K * stride_bk + c = acc.to(C.dtype.element_ty) + # ---------------- # + # Epilogue # + # ---------------- # + offs_cm = tl.arange(0, TILE_M) % BLOCK + offs_cn = tl.arange(0, TILE_N) % BLOCK + pc = C \ + + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc + tl.store(pc, c, mask=True) + + +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # (A * B)^T = B^T * A^T + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + # shape constraints + a_dim = -2 if trans_a else -1 + b_dim = -1 if trans_b else -2 + Ka, Kb = a.shape[a_dim], b.shape[b_dim] + if Ka != Kb: + raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") + # allocate output + if out is None: + c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device) + else: + assert out.shape == (a.shape[0], lut.shape[0], block, block) + c = out + grid = [c.shape[1], 1, c.shape[0]] + _sdd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(2), c.stride(3), # + Ka, 0, lut, # + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, # + num_warps=4 # + ) + return c + + +def sdd_lut(layout, block, device): + lut = layout.nonzero(as_tuple=False).to(device).int() + lut = lut.contiguous() + return lut, None + + +# ----------------------------- +# Dense = Sparse x Dense (DSD) +# This operation uses a look-up table that contains pre-computed pointer increments +# in order to minimize computations in the inner loop of the matmul kernel. 
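+# As a rough sketch of the layout produced by dsd_lut below (see that function for the
+# exact construction): the LUT begins with a 4-integer header per output block-column
+# -- [offset into the increment table, reduction length K, output block index, head
+# index] -- followed by interleaved (dense, sparse) pointer-increment pairs, one pair
+# consumed per TILE_K step of the inner loop, zero-padded so the kernel may pre-fetch
+# past the end.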
+# ----------------------------- + + +@jit +def _dsd_kernel(A, B, C, # + stride_az, stride_ha, stride_am, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_bn, # + stride_zc, stride_hc, stride_cm, stride_cn, # + DS0, DS1, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) + pidz = tl.program_id(2) + header = lut + pid_n * 4 + offset = tl.load(header + 0) + K = tl.load(header + 1) + column = tl.load(header + 2) + off_h = tl.load(header + 3) + pinc = lut + offset + # initialize pointers to A (sparse) + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint + offs_am = tl.arange(0, TILE_M) + offs_ak = tl.arange(0, TILE_K) + pa = A + pidz * stride_az \ + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B (dense) + offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) + start_bk = tl.load(pinc) + start_bk = tl.multiple_of(start_bk, 8) # compiler hint + offs_bk = start_bk + tl.arange(0, TILE_K) + pb = B + pidz * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + for k in range(K, 0, -TILE_K): + a = tl.load(pa) + b = tl.load(pb) + acc += tl.dot(a, b, out_dtype=tl.float32) + pa += inc_a + pb += inc_b * stride_bk + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + c = acc.to(C.dtype.element_ty) + # initialize pointers to C + offs_cm = column * TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) + pc = C \ + + off_h * stride_hc \ + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask=offs_cn[None, :] < DS0) + + +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # shapes / dtypes + AS1 = block * spdims[2 if trans_a else 1] + BS0 = b.size(0) + BS1 = b.size(1) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # allocate output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out + # meta-parameter heuristics + TILE_N = 128 + # compute output + grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0] + _dsd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), # + BS3, AS1, lut, # + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 
32), BLOCK=block, num_stages=4, # + num_warps=4, GROUP_SIZE_M=4 # + ) + # exit() + return c + + +def dsd_lut(layout, block, step, trans, device): + """ + Generates the look-up table for incrementing pointers in the DSD/DDS matmul. + Example (BLOCK=32, STEP=16) + [[1, 0, 0, 1, 0], + [0, 1, 1, 0, 1], + [1, 0, 1, 0, 0]] + + Then the offsets for A are + [0 , 16, 32, 48] <- row 0 + \\----/ \\----/ + col=0 col=3 + [64, 80, 96, 112, 128, 144] <- row 1 + \\----/ \\----/ \\------/ + col=1 col=2 col=3 + [160, 176, 192, 208] + which leads to increments table + [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] + + Because B is dense, the offsets are + [0, 16, 96, 112] <- row 0 + [32, 48, 64, 80] <- row 1 + [0, 16, 64, 80] <- row 2 + """ + sizes = torch.sum(layout, 2 if trans else 1) + head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) + sizes = sizes.flatten() + segments = sizes * step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) + else: + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) + # ------------------------------- + # dense input pointer increments + # ------------------------------- + # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) + # that is smaller than the block size, so we need to do a bit of extra work + # to handle this case + B_idx = nnz[:, 2] * block + B_incs = B_idx.clone() + B_incs[1:] -= B_idx[:-1] + div = block // step + B_incs = B_incs.view(-1, 1).repeat(1, div) + B_incs[:, 1:] = step + B_incs[:, 0] -= (div - 1) * step + # first increment for each reduction is actually the offset + B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]] + B_incs = B_incs.view(-1) + # ------------------------------- + # sparse input pointer increments + # ------------------------------- + # same as above, except that the increments are in the sparse memory layout + if trans: + A_idx = torch.arange(num_blocks, device=layout.device) + else: + A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone().long() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device) + A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + A_incs = A_idx * block * block + A_incs[1:] -= A_idx[:-1] * block * block + A_incs = A_incs.view(-1, 1).repeat(1, div) + if trans: + A_incs[:, 1:] = step + A_incs[:, 0] -= (div - 1) * step + else: + A_incs[:, 1:] = step * block + A_incs[:, 0] -= (div - 1) * step * block + A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] + A_incs = A_incs.view(-1) + # create header + width = col_id.size(0) + offsets = offsets * 2 * div + 4 * width + segments = segments * div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + # create increments + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + # pad by a factor 2*MAX_NUM_STAGES + # to accommodate pre-fetching inside the kernel + pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) + incs = torch.cat((incs, pad)) + # create lut + lut = torch.cat((header, incs)) + lut = lut.type(torch.int32).to(device) + # create locks + return lut, width + + +# ----------------------------- +# Dense = Dense x Sparse (DDS) +# 
----------------------------- +# AB = (B^T A^T)^T + + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) + + +############## +# MAIN API # +############## + + +class _matmul(torch.autograd.Function): + + fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} + + @staticmethod + def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, + db_width, out): + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) + # save for backward + ctx.save_for_backward(a, b) + ctx.da_lut = da_lut + ctx.da_width = da_width + ctx.db_lut = db_lut + ctx.db_width = db_width + ctx.mode = mode + ctx.spdims = spdims + ctx.block = block + ctx.trans_a = trans_a + ctx.trans_b = trans_b + ctx.trans_c = trans_c + ctx.has_out = out is not None + return c + + @staticmethod + def backward(ctx, dc): + # saved for backward + a, b = ctx.saved_tensors + da, db = None, None + mode = ctx.mode + # gradients w.r.t. a + if ctx.needs_input_grad[0]: + mode_da = mode[1] + mode[0] + mode[2] + da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, + ctx.da_lut, ctx.da_width) + # gradients w.r.t. b + if ctx.needs_input_grad[1]: + mode_db = mode[2] + mode[1] + mode[0] + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, + ctx.db_lut, ctx.db_width) + dout = dc if ctx.has_out else None + return da, db, None, None, None, \ + None, None, None, None, \ + None, None, None, None, None, dout + + +class matmul: + + def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False): + if mode not in ['sdd', 'dsd', 'dds']: + raise NotImplementedError('Supported modes are: sdd, dsd, dds') + self.block = block + self.mode = mode + self.trans_a = trans_a + self.trans_b = trans_b + self.trans_c = trans_c + self.layout = layout + self.spdims = layout.shape + step = min(block, 32) + if self.mode == 'sdd': + self.c_lut, self.c_width = sdd_lut(layout, block, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) + if self.mode == 'dsd': + self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device) + self.da_lut, self.da_width = sdd_lut(layout, block, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) + if self.mode == 'dds': + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) + self.db_lut, self.db_width = sdd_lut(layout, block, device) + + def __call__(self, a, b, out=None): + c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, # + self.c_lut, self.c_width, # + self.da_lut, self.da_width, # + self.db_lut, self.db_width, # + out) + return c diff --git a/third_party/metax/python/triton/ops/blocksparse/softmax.py b/third_party/metax/python/triton/ops/blocksparse/softmax.py new file mode 100644 index 000000000..bcffff26b --- /dev/null +++ b/third_party/metax/python/triton/ops/blocksparse/softmax.py @@ -0,0 +1,228 @@ +import torch + +from ... import jit +from ... import language as tl +from ... 
import next_power_of_2 + + +def num_warps(n): + if n <= 128: + return 1 + if n <= 256: + return 2 + if n <= 512: + return 4 + if n <= 4096: + return 8 + return 16 + + +@jit +def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, # + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr # + ): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # pointer offset + off_a = z * stride_xz + off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx + off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load X + mask = block_n < size + a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) + a = a.to(tl.float32) + # compute + out = a + out *= scale + # apply relative attention + if R is not None: + R += z * stride_zr + R += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) + rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) + out += rel_logits + out = out.to(tl.float32) + # apply causal mask + out = tl.where((ns > m) & is_causal, -float("inf"), out) + # computation + out = tl.softmax(out) + # write-back + tl.store(Out + off_a + lane_n, out, mask=mask) + + +@jit +def _blocksparse_softmax_bwd(DA, stride_zdx, # + DOut, stride_zdout, # + Out, stride_zout, # + scale, # + LUT, # + DR, extent, stride_zr, stride_hr, stride_er, # + is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # row-col offset + off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE + off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE + mask = block_n < size + # pointers + As = Out + z * stride_zout + off_mn + DOuts = DOut + z * stride_zdout + off_mn + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load data + a = tl.load(As + lane_n, mask=mask, other=0.0) + a = a.to(tl.float32) + dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) + dout = dout.to(tl.float32) + # compute + a = tl.where((ns > m) & is_causal & (a == a), 0., a) + da = a * (dout - tl.sum(a * dout, 0)) + # apply relative attention + if DR is not None: + DR += z * stride_zr + DR += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) & mask + tl.store(DR + m * extent + off_lo, da, mask=mask_lo) + da = da * scale + # convert da 
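+    # (illustrative note) `da` above follows the usual softmax backward identity
+    #   dL/dx = y * (dL/dy - sum(y * dL/dy)),
+    # with `a` holding the saved forward output; the extra `scale` factor mirrors the
+    # `out *= scale` applied in the forward kernel.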
+ # write-back + DAs = DA + z * stride_zdx + off_mn + tl.store(DAs + lane_n, da, mask=mask) + + +class _softmax(torch.autograd.Function): + + @staticmethod + def make_lut(layout, block, device): + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + sizes = _empty.clone() + # sizes along rows + for h in range(layout.shape[0]): + sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + total_sizes = sizes * block + # offsets in block format + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + # block indices + columns = layout.nonzero(as_tuple=False)[:, 2] + header = torch.stack((sizes, offsets), dim=1).view(-1) + lut = torch.cat((header, columns)).type(torch.int32).to(device) + return lut, int(total_sizes.max()) + + @staticmethod + def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense): + if scale is not None and isinstance(scale, torch.Tensor): + assert scale.device.type == "cpu" + scale = scale.item() + M = a.shape[0] + grid = [spdims[0], spdims[1] * block, M] + rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape + rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() + # enqueue kernel + out = torch.empty_like(a) + _blocksparse_softmax_fwd[grid]( + out, a, a.stride(0), lut, # + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn# + scale, # + is_causal, # + BLOCK_SIZE=block, # + ROW_SIZE=next_power_of_2(maxlut), # + IS_DENSE=is_dense, # + num_warps=num_warps(maxlut) # + ) + # save to context + # ctx.mark_dirty(x) + ctx.save_for_backward(out, lut) + ctx.spdims = spdims + ctx.block = block + ctx.maxlut = maxlut + ctx.scale = scale + ctx.rel_shape = rel_shape + ctx.rel_strides = rel_strides + ctx.rel_dtype = a.dtype + ctx.is_dense = is_dense + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, dout): + # retrieve from context + out, lut = ctx.saved_tensors + # relative logits gradients + dr = None + if ctx.needs_input_grad[3]: + dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) + # run kernel + M = out.shape[0] + grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) + da = torch.empty_like(dout) + _blocksparse_softmax_bwd[grid]( + da, da.stride(0), # + dout, dout.stride(0), # + out, out.stride(0), # + ctx.scale, # + lut, # + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], # + ctx.is_causal, # + BLOCK_SIZE=ctx.block, # + ROW_SIZE=next_power_of_2(ctx.maxlut), # + IS_DENSE=ctx.is_dense, # + num_warps=num_warps(ctx.maxlut) # + ) + return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None) + + +class softmax: + + def __init__(self, layout, block, device, is_dense=False): + self.spdims = layout.shape + self.layout = layout + self.block = block + self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) + self.is_dense = is_dense + + def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): + if rel_logits is not None and rel_logits.dtype != a.dtype: + raise ValueError(f"relative position embedding must be {a.dtype}") + a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut, + self.is_dense) + return a diff --git a/third_party/metax/python/triton/ops/cross_entropy.py b/third_party/metax/python/triton/ops/cross_entropy.py new file mode 100644 index 000000000..88e8dae50 --- /dev/null +++ b/third_party/metax/python/triton/ops/cross_entropy.py @@ -0,0 +1,96 
@@ +import torch + +from .. import heuristics, jit +from .. import language as tl +from .. import next_power_of_2 + + +def num_warps(N): + if N < 2048: + return 4 + elif N < 8192: + return 8 + return 16 + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to logit and probs + LOGITS = LOGITS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + READ_PROBS = PROBS + row * N + idx + # write-back negative log-probs + logits = tl.load(LOGITS, mask=cols < N, other=-float('inf')) + logits = logits.to(tl.float32) + logits = logits - tl.max(logits, 0) + probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits + tl.store(WRIT_PROBS, probs, mask=cols < N) + # There is a bug in the compiler, which fails to insert a barrier here. + # We add it explicitly for now. Will be fixed soon. + tl.debug_barrier() + # write-back loss + probs = tl.load(READ_PROBS) + tl.store(LOSS + row, probs) + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to probs + PROBS = PROBS + row * N + cols + # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + # and we have -log(p[k]) stored in PROBS, so this is easy + probs = -tl.load(PROBS, mask=cols < N, other=float('inf')) + probs = tl.exp(probs.to(tl.float32)) + delta = cols == idx + # write result in-place in PROBS + dout = tl.load(DPROBS + row) + din = (probs - delta) * dout + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) + + +class _cross_entropy(torch.autograd.Function): + + @classmethod + def forward(cls, ctx, logits, indices): + # make sure we can use triton + assert (indices.dtype == torch.int64), "Indices are expected to be of type long." + # make kernel + device, dtype = logits.device, logits.dtype + n_cols = logits.shape[-1] + # run the kernel + result = torch.empty_like(indices, dtype=dtype, device=device) + neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device) + grid = lambda opt: (logits.numel() // n_cols, ) + _forward[grid](logits, neg_logprobs, indices, result, n_cols) + # save for backward + ctx.save_for_backward(neg_logprobs, indices) + return result + + @classmethod + def backward(cls, ctx, dneg_logprobs): + """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + so we initialize the gradient as neg_logprobs, so we can just exponentiate + to get p[k], which is most of what we need... 
neg_logprobs will be + modified in place to become the gradient we want + """ + # load saved tensors + neg_logprobs, indices = ctx.saved_tensors + # run the kernel + # neg_logprobs will be modified in place to become our gradient: + n_cols = neg_logprobs.shape[-1] + grid = lambda opt: (neg_logprobs.numel() // n_cols, ) + _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) + return neg_logprobs, None + + +cross_entropy = _cross_entropy.apply diff --git a/third_party/metax/python/triton/ops/flash_attention.py b/third_party/metax/python/triton/ops/flash_attention.py new file mode 100644 index 000000000..0825ef26c --- /dev/null +++ b/third_party/metax/python/triton/ops/flash_attention.py @@ -0,0 +1,466 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) + +Sequence Parallel implementation inspired by HazyResearch +(see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py) +""" + +import torch +import triton + +from .. import cdiv, jit +from .. import language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@jit +def _fwd_kernel(Q, K, V, sm_scale, # + L, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + IS_CAUSAL: tl.constexpr # + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + vk_offset = qvk_offset // stride_qm + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, Z_H_N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, vk_offset), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(vk_offset, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # credits to: Adam P. 
Goucher (https://github.com/apgoucher): + # scale sm_scale by 1/log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + + offs_k = tl.arange(0, BLOCK_DMODEL) + Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + q = tl.load(Q_ptrs) + + q = (q * qk_scale).to(K.dtype.element_ty) + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back l and m + acc = acc / l_i[:, None] + l_ptrs = L + off_hz * N_CTX + offs_m + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(vk_offset + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) + + +@jit +def _bwd_preprocess( + Out, + DO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_m, delta) + + +@jit +def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + if CAUSAL: + lo = start_n * BLOCK_M + else: + lo = 0 + + Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm + DQ_offset = off_z * stride_qz + off_h * stride_qh + K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn + V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn + if SEQUENCE_PARALLEL: + DQ_offset += stride_dqa * start_n + DQ_offset = DQ_offset // stride_qm + + Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0)) + K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (lo + 
Q_offset, 0)) + DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0)) + DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + + # initialize row/col offsets + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(Q_block_ptr) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + if CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= qk_scale + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + # compute dv + do = tl.load(DO_block_ptr) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp = tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty) + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q) + # compute dq + if not SEQUENCE_PARALLEL: + dq = tl.load(DQ_block_ptr) + dq += tl.dot(ds, k) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + elif SEQUENCE_PARALLEL: + if MMA_V3: + dq = tl.dot(ds, k) + else: + # not work with mma v3, because M % 64 != 0 + dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + + # increment pointers + DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0)) + Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) + # write-back + tl.store(DV_block_ptr, dv.to(V.dtype.element_ty)) + tl.store(DK_block_ptr, dk.to(K.dtype.element_ty)) + + +@jit +def _bwd_kernel(Q, K, V, sm_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + Z_H_N_CTX, # + SQ_Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + qk_scale = sm_scale * 1.44269504 + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DO_block_ptr = tl.make_block_ptr( + 
base=DO, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + if SEQUENCE_PARALLEL: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + else: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + DK_block_ptr = tl.make_block_ptr( + base=DK, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DV_block_ptr = tl.make_block_ptr( + base=DV, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + num_block_n = tl.cdiv(N_CTX, BLOCK_N) + if not SEQUENCE_PARALLEL: + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + else: + start_n = tl.program_id(1) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported for compute capability >= 80") + BLOCK_M = 128 + BLOCK_N = 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, k, v, sm_scale, # + L, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, # + IS_CAUSAL=causal, # + num_warps=num_warps, # + num_stages=4 # + ) + + ctx.save_for_backward(q, k, v, o, L) + ctx.grid = grid + ctx.sm_scale = 
sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.sequence_parallel = sequence_parallel + return o + + @staticmethod + def backward(ctx, do): + capability = torch.cuda.get_device_capability() + MMA_V3 = capability[0] >= 9 + BLOCK = 128 + + if is_hip(): + # Bwd pass runs out of shared memory on HIP with larger block size. + BLOCK = 64 + + q, k, v, o, L = ctx.saved_tensors + sequence_parallel = ctx.sequence_parallel + seq_len_kv = k.shape[2] + do = do.contiguous() + if sequence_parallel: + replicas = cdiv(seq_len_kv, BLOCK) + new_dq_shape = (replicas, ) + q.shape + dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros_like(q, dtype=q.dtype) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + delta = torch.empty_like(L) + _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )]( + o, + do, + delta, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( + q, k, v, ctx.sm_scale, # + o, do, # + dq, dk, dv, # + L, # + delta, # + o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + SEQUENCE_PARALLEL=sequence_parallel, # + CAUSAL=ctx.causal, # + MMA_V3=MMA_V3, # + num_warps=8, # + num_stages=1 # + ) + + if len(dq.shape) == 5: + dq = dq.sum(dim=0) + return dq, dk, dv, None, None, None + + +attention = _attention.apply diff --git a/third_party/metax/python/triton/ops/matmul.py b/third_party/metax/python/triton/ops/matmul.py new file mode 100644 index 000000000..f7f577a1b --- /dev/null +++ b/third_party/metax/python/triton/ops/matmul.py @@ -0,0 +1,219 @@ +import torch + +from .. import Config, autotune, cdiv, heuristics, jit +from .. 
import language as tl +from .matmul_perf_model import early_config_prune, estimate_matmul_time + +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] + + +def upcast_if_fp8(a): + if "fp8" in str(a): + return torch.float16 + return a + + +def get_higher_dtype(a, b): + a = upcast_if_fp8(a) + b = upcast_if_fp8(b) + if a is b: + return a + + assert a in _ordered_datatypes + assert b in _ordered_datatypes + + for d in _ordered_datatypes: + if a is d: + return b + if b is d: + return a + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@autotune( + configs=[ + # basic configs for compute-bound matmuls + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10, + }, +) +@heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@jit +def _kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: 
tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=input_precision) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +class _matmul(torch.autograd.Function): + kernel = _kernel + + _locks = {} + + @staticmethod + def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if (output_dtype is None): + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + # Allowed types for acc_type given the types of a and b. 
+ supported_acc_dtypes = { + torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32, ), torch.int8: (torch.int32, ) + } + + if acc_dtype is None: + acc_dtype = supported_acc_dtypes[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert acc_dtype in supported_acc_dtypes[a.dtype], "acc_dtype not compatible with the type of a" + assert acc_dtype in supported_acc_dtypes[b.dtype], "acc_dtype not compatible with the type of b" + + def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: + ab_dtype = None + # launch kernel + grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid]( + a, b, c, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + acc_dtype=acc_dtype, # + input_precision=input_precision, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, AB_DTYPE=ab_dtype) + return c + + @staticmethod + def forward(ctx, a, b, acc_dtype=None, input_precision=None, fp8_fast_accum=True, output_dtype=None): + return _matmul._call(a, b, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum, + output_dtype=output_dtype) + + +matmul = _matmul.apply diff --git a/third_party/metax/python/triton/ops/matmul_perf_model.py b/third_party/metax/python/triton/ops/matmul_perf_model.py new file mode 100644 index 000000000..b60b74540 --- /dev/null +++ b/third_party/metax/python/triton/ops/matmul_perf_model.py @@ -0,0 +1,171 @@ +import functools +import heapq + +import torch + +from .. 
import cdiv +from ..runtime import driver +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi) + + +@functools.lru_cache() +def get_clock_rate_in_khz(): + try: + return nvsmi(['clocks.max.sm'])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops( + dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def estimate_matmul_time( + # backend, device, + num_warps, num_stages, # + A, B, C, # + M, N, K, # + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, # + debug=False, **kwargs # +): + ''' return estimated running time in ms + = max(compute, loading) + store ''' + device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() + + num_cta_m = cdiv(M, BLOCK_M) + num_cta_n = cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
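+    # Worked example with assumed numbers (illustrative only): for num_ctas = 64 and a
+    # device whose get_dram_gbps() is ~1500 GB/s, active_cta_ratio_bw1 = 1.0 and
+    # active_cta_ratio_bw2 = (64 - 32) / (108 - 32) ~= 0.42, so
+    # dram_bw ~= 1500 * (0.95 + 0.42 * 0.05) ~= 1457 GB/s and l2_bw ~= 5830 GB/s.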
+ # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms + + +def early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['A'].element_size() + dtype = named_args['A'].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages + + max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + if dtype not in [torch.float16, torch.float32]: + configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? 
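+            # Illustrative arithmetic (values from one of the compute-bound configs in
+            # ops/matmul.py: BLOCK_M = BLOCK_N = 128, BLOCK_K = 32, num_warps = 4):
+            # mmas = 128 * 128 * 32 / (16 * 8 * 16) = 256, mma_cycles = 256 / 4 * 8 = 512,
+            # so optimal_num_stages ~= 300 / 512 ~= 0.6, and the nsmallest(2, ...) selection
+            # below keeps the two candidates whose num_stages is closest to that optimum.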
+ optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/third_party/metax/python/triton/runtime/__init__.py b/third_party/metax/python/triton/runtime/__init__.py new file mode 100644 index 000000000..0b3979d28 --- /dev/null +++ b/third_party/metax/python/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/third_party/metax/python/triton/runtime/autotuner.py b/third_party/metax/python/triton/runtime/autotuner.py new file mode 100644 index 000000000..903373ed8 --- /dev/null +++ b/third_party/metax/python/triton/runtime/autotuner.py @@ -0,0 +1,420 @@ +''' 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. ''' +from __future__ import annotations + +import builtins +import os +import time +import inspect +import json +import hashlib +import uuid +from typing import Dict + +from .cache import default_cache_dir +from ..testing import do_bench, do_bench_cudagraph +from .jit import KernelInterface +from .errors import OutOfResources + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + use_cuda_graph=False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
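+            For example (illustrative), ops/matmul.py passes
+            ``prune_configs_by={'early_config_prune': early_config_prune, 'perf_model': estimate_matmul_time, 'top_k': 10}``.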
+ """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args, exception: 0 + if pre_hook: + self.pre_hook = pre_hook + elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0): + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + elif len(self.restore_idx) > 0: + + def _post_hook(args, exception): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + self.num_warmups = warmup + self.num_reps = rep + import torch + self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available() + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
+ " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(args, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(args, exception=None) + + try: + if self.use_cuda_graph: + import torch + with torch.cuda.stream(torch.cuda.Stream()): + bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median") + return bench_res + return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure): + return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] + + def _dump_cache(self, path, config): + cache_dict = dict(**config.all_kwargs()) + os.makedirs(os.path.dirname(path), exist_ok=True) + + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{path}.tmp.pid_{pid}_{rnd_id}" + with open(temp_path, "w", encoding="utf-8") as file: + json.dump(cache_dict, file, ensure_ascii=False, indent=4) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, path) + + def _parse_config_dict(self, config_dict): + default_config_menbers = ("num_warps", "num_stages", "num_ctas", "maxnreg") + kwargs = config_dict.copy() + for menber in default_config_menbers: + kwargs.pop(menber, None) + maxnreg = config_dict["maxnreg"] if "maxnreg" in config_dict else None + return Config(kwargs, config_dict["num_warps"], config_dict["num_stages"], config_dict["num_ctas"], maxnreg) + + def _load_cache(self, path): + if not os.path.exists(path): + return None + with open(path, "r", encoding="utf-8") as file: + config_dict = json.load(file) + config = self._parse_config_dict(config_dict) + return config + + def _config_cache_path(self, key): + base_path = os.path.join(default_cache_dir(), "configs") + user_defined_path = os.environ.get("TRITON_AUTOTUNE_CONFIG_PATH") + if user_defined_path is not None: + base_path = user_defined_path + src = inspect.getsource(self.base_fn) + function_hash = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache_file_name = self.base_fn.__name__ + str(key) + ".json" + cache_path = os.path.join(base_path, function_hash, cache_file_name) + return cache_path + + def _enable_config_cache(self): + return os.environ.get("TRITON_ENABLE_PERSISTENT_AUTOTUNE_CONFIGS", "0") == "1" + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + if self._enable_config_cache(): + # load cache from persistent path + config = self._load_cache(self._config_cache_path(key)) + if config is not None: + self.cache[key] = config + if key not in self.cache: + # prune configs + 
used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + if self._enable_config_cache(): + self._dump_cache(self._config_cache_path(key), self.cache[key]) + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. 
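+
+    Example (illustrative, taken from one of the configs used in ops/matmul.py):
+
+    .. highlight:: python
+    .. code-block:: python
+
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4)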
+ """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=25, rep=100, use_cuda_graph=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'args': a list of arguments passed to the kernel. 
+ 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/metax/python/triton/runtime/build.py b/third_party/metax/python/triton/runtime/build.py new file mode 100644 index 000000000..988076697 --- /dev/null +++ b/third_party/metax/python/triton/runtime/build.py @@ -0,0 +1,83 @@ +import contextlib +import sys +import io +import sysconfig +import os +import shutil +import subprocess +import setuptools + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + +# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py +def check_env_flag(name: str, default: str = "") -> bool: + return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. 
+ if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + include_dirs = include_dirs + [srcdir, py_include_dir] + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + if check_env_flag("TRITON_USE_MACA", "ON"): # Default ON + cc_cmd += [f"-D__MACA__", "-lmcruntime"] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + ret = subprocess.check_call(cc_cmd) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so diff --git a/third_party/metax/python/triton/runtime/cache.py b/third_party/metax/python/triton/runtime/cache.py new file mode 100644 index 000000000..bd3c29b99 --- /dev/null +++ b/third_party/metax/python/triton/runtime/cache.py @@ -0,0 +1,281 @@ +import importlib +import json +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional +import hashlib + + +def default_cache_dir(): + return os.path.join(Path.home(), ".triton", "cache") + + +def default_override_dir(): + return os.path.join(Path.home(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".triton", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def 
get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"] + module_path, clz_nme = remote_cache_manager.split(":") + module = importlib.import_module(module_path) + remote_cache_cls = getattr(module, clz_nme) + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. 
+ return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(key, override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(key, dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return key diff --git a/third_party/metax/python/triton/runtime/driver.py b/third_party/metax/python/triton/runtime/driver.py new file mode 100644 index 000000000..c3b97a764 --- /dev/null +++ b/third_party/metax/python/triton/runtime/driver.py @@ -0,0 +1,60 @@ +from ..backends import backends +from ..backends import DriverBase + + +def _create_driver(): + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) != 1: + raise RuntimeError(f"{len(actives)} active 
drivers ({actives}). There should only be one.") + return actives[0]() + + +class LazyProxy: + + def __init__(self, init_fn): + self._init_fn = init_fn + self._obj = None + + def _initialize_obj(self): + if self._obj is None: + self._obj = self._init_fn() + + def __getattr__(self, name): + self._initialize_obj() + return getattr(self._obj, name) + + def __setattr__(self, name, value): + if name in ["_init_fn", "_obj"]: + super().__setattr__(name, value) + else: + self._initialize_obj() + setattr(self._obj, name, value) + + def __delattr__(self, name): + self._initialize_obj() + delattr(self._obj, name) + + def __repr__(self): + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self): + self._initialize_obj() + return str(self._obj) + + +class DriverConfig: + + def __init__(self): + self.default = LazyProxy(_create_driver) + self.active = self.default + + def set_active(self, driver: DriverBase): + self.active = driver + + def reset_active(self): + self.active = self.default + + +driver = DriverConfig() diff --git a/third_party/metax/python/triton/runtime/errors.py b/third_party/metax/python/triton/runtime/errors.py new file mode 100644 index 000000000..4dce91767 --- /dev/null +++ b/third_party/metax/python/triton/runtime/errors.py @@ -0,0 +1,26 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) diff --git a/third_party/metax/python/triton/runtime/interpreter.py b/third_party/metax/python/triton/runtime/interpreter.py new file mode 100644 index 000000000..a82832ecf --- /dev/null +++ b/third_party/metax/python/triton/runtime/interpreter.py @@ -0,0 +1,1127 @@ +import inspect +from typing import Tuple + +import math +import numpy as np + +import triton +import triton.language as tl +from dataclasses import dataclass +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +class TensorHandle: + + def __init__(self, data, dtype): + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. 
+ we don't store block_type here because the shape information is already availale in the data field + attr: a dictionary of attributes + ''' + self.data = data + self.dtype = dtype + self.attr = {} + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + arch: str = None + allow_fp8e4nv: bool = True + allow_fp8e4b15: bool = True + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = 
output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... 
+ 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder): + return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def 
get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, 
dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: 
self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + 
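The create_idiv helper above relies on a non-obvious identity: Triton integer division truncates toward zero, while numpy's // floors toward negative infinity, so subtracting the fmod remainder first recovers truncation semantics. A standalone check of that identity (not part of the patch, values chosen for illustration):

import numpy as np

lhs = np.array([-7, 7, -7, 7], dtype=np.int32)
rhs = np.array([2, 2, -2, -2], dtype=np.int32)

floor_div = lhs // rhs                         # numpy rounds toward -inf
trunc_div = (lhs - np.fmod(lhs, rhs)) // rhs   # identity used by create_idiv: rounds toward 0

assert list(floor_div) == [-4, 3, 3, -4]
assert list(trunc_div) == [-3, 3, 3, -3]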
create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins): + return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_ptr_to_int(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_cat(self, lhs, rhs): + return 
TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, arg, shape): + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values): + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message, fileName, funcName, lineNo): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message} in {fileName}:{funcName}:{lineNo}" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return 
TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpIneterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + if hasattr(ret, "shape") and ret.shape: + ret_type = tl.block_type(dtype, ret.shape) + else: + ret = np.array([ret], dtype=_get_np_dtype(dtype)) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + input = (input, ) + self.check_tensor(input) + return self.apply_impl(input) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl not implemented") + + +class ReduceOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, 
*input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret[0] if len(ret) == 1 else tuple(ret) + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not 
isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return len(ret) == 1 and ret[0] or tuple(ret) + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. 
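ScanOps.apply_impl above handles reverse scans by flipping the input along the scan axis, running the forward scan, and flipping the result back. A minimal numpy-only sketch of that pattern, with illustrative values:

import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]], dtype=np.int32)
axis = 1

# flip, forward cumsum, flip back == reverse cumulative sum along `axis`
reverse_cumsum = np.flip(np.cumsum(np.flip(x, axis=axis), axis=axis), axis=axis)

assert (reverse_cumsum == np.array([[6, 5, 3], [15, 11, 6]])).all()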
+ if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_builtin(lang[0], interpreter_builder) + _patch_builtin(lang[0].tensor, interpreter_builder) + if lang[0] == tl: + _patch_builtin(lang[0].math, interpreter_builder) + _patch_lang_tensor(lang[0].tensor) + _patch_lang_core(lang[0]) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +interpreter_builder = InterpreterBuilder() + +# These keywords are not supported by the interpreter +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + args_hst = [] + for arg in args_dev: + if hasattr(arg, "data_ptr"): + args_hst.append(arg.cpu()) + else: + args_hst.append(arg) + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + if hasattr(value, "data_ptr"): + kwargs_hst[key] = value.cpu() + else: + kwargs_hst[key] = value + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, "data_ptr"): + arg_dev.data.copy_(arg_hst.to(arg_dev.device).data) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + if hasattr(kwarg_dev, "data_ptr"): + kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) + + def __call__(self, *args_dev, **kwargs): + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + 
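GridExecutor works on host copies of the launch arguments: _init_args_hst moves tensor arguments to the CPU before interpretation and _restore_args_dev copies results back so in-place side effects survive. A hedged standalone sketch of that round trip, assuming torch is available; run_on_host is a hypothetical stand-in for the interpreted kernel body:

import torch

def roundtrip(run_on_host, *args_dev):
    # copy tensor-like arguments to the host, mirroring _init_args_hst
    args_hst = [a.cpu() if hasattr(a, "data_ptr") else a for a in args_dev]
    run_on_host(*args_hst)
    # propagate side effects back to the original buffers, mirroring _restore_args_dev
    for dev, hst in zip(args_dev, args_hst):
        if hasattr(dev, "data_ptr"):
            dev.data.copy_(hst.to(dev.device).data)

x = torch.zeros(4)
roundtrip(lambda t: t.add_(1), x)
assert (x == 1).all()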
if kwargs.pop("warmup", False): + return + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class InterpretedFunction: + + def __init__(self, fn) -> None: + self.fn = fn + + def run(*args, **kwargs): + grid = kwargs["grid"] + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + return GridExecutor(self.fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + try: + return self.fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/metax/python/triton/runtime/jit.py b/third_party/metax/python/triton/runtime/jit.py new file mode 100644 index 000000000..a12b1d235 --- /dev/null +++ b/third_party/metax/python/triton/runtime/jit.py @@ -0,0 +1,956 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. 
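The hashing strategy sketched in the class docstring is cumulative: the kernel's own source seeds a SHA-256 digest, and the cache key of every JITFunction it calls (plus that callee's noinline flag) is folded in, so editing a callee invalidates the caller. A minimal illustration of that accumulation; the helper name is hypothetical:

import hashlib

def cache_key(src, callee_keys):
    hasher = hashlib.sha256(src.encode("utf-8"))
    for key in callee_keys:
        hasher.update(key.encode("utf-8"))
    return hasher.hexdigest()

base = cache_key("def kernel(): pass", [])
assert cache_key("def kernel(): pass", ["callee-key-v2"]) != base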
+ self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) != ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins # + ): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + + def is_triton_builtin(func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + func = self.visit(node.func) + assert func is None or is_triton_builtin(func) or isinstance( + func, JITFunction + ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this' + + # Traverse arguments as well as node.func so we can find JITFunctions + # passed to tl.reduce or tl.associative_scan as the combine_fn + for obj in itertools.chain( + (func, ), + map(self.visit, node.args), + (self.visit(kw.value) for kw in node.keywords), + ): + if not isinstance(obj, JITFunction): + continue + if is_triton_builtin(obj): + continue + + func_cache_key = obj.cache_key + + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. 
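visit_Name above piggybacks on ast.NodeVisitor's default traversal: only Load contexts that resolve to "interesting" globals are recorded. A toy visitor showing the same traversal mechanics (illustrative only, not part of the patch):

import ast

class NameCollector(ast.NodeVisitor):
    def __init__(self):
        self.loads = set()

    def visit_Name(self, node):
        # record names that are read; stores are assignment targets
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)

src = "def f(x):\n    return x + BLOCK\n"
collector = NameCollector()
collector.visit(ast.parse(src))
assert "BLOCK" in collector.loads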
+ for k in self.used_global_vals.keys() & obj.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = obj.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + + self.used_global_vals.update(obj.used_global_vals) + + noinline = str(getattr(obj, "noinline", False)) + + key = func_cache_key + noinline + self.hasher.update(key.encode("utf-8")) + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. 
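The special-casing in visit_arguments exists because Python evaluates default values once, at definition time, so a default that names a global will not observe later reassignments. A standalone demonstration of that behaviour with hypothetical values:

A = 1

def foo(a=A):
    return a

A = 2
assert foo() == 1     # the default still holds the value A had at definition time
assert foo(A) == 2    # explicit arguments do track the new value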
+ self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v): + + if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. + if (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams): + """ + Equivalent to sig.bind followed by apply_defaults. 
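KernelInterface.__getitem__ above is what makes the familiar fn[grid](*args) launch syntax work: indexing returns a proxy that remembers the grid and forwards everything else to run. A hedged sketch of that pattern; FakeKernel is a stand-in, not part of the patch:

class FakeKernel:
    def run(self, *args, grid, warmup, **kwargs):
        return ("ran", grid, args)

    def __getitem__(self, grid):
        # memorize the grid, defer everything else to run()
        return lambda *args, **kwargs: self.run(*args, grid=grid, warmup=False, **kwargs)

k = FakeKernel()
assert k[(8, 1, 1)](3, 4) == ("ran", (8, 1, 1), (3, 4))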
This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + specialisations.append('compute_spec_key(%s)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + divisibility = 16 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return arg.data_ptr() % JITFunction.divisibility == 
0 + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + from ..compiler import AttrsDescriptor + + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + + divisible_by_16 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_16(arg) and not param.do_not_specialize + } + equal_to_1 = { + param.num + for param, arg in zip(self.params, args) + if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize + } + # folded equal_to_1 and None + # TODO: method to collect all folded args + return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, + # equal_to_1) + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + ): + if JITFunction.cache_hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + } + + return JITFunction.cache_hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=False, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. 
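_get_config and compute_spec_key apply the same specialization heuristic: pointer and integer arguments divisible by 16 share one bucket, integers equal to 1 get their own, and everything else is unspecialized. An illustrative-only sketch of that bucketing; FakeTensor is hypothetical:

class FakeTensor:
    def __init__(self, ptr):
        self._ptr = ptr

    def data_ptr(self):
        return self._ptr

def spec_key(v, divisibility=16):
    if hasattr(v, "data_ptr"):
        return "D" if v.data_ptr() % divisibility == 0 else "N"
    if isinstance(v, int):
        if v % divisibility == 0:
            return "D"
        if v == 1:
            return "1"
    return "N"

assert [spec_key(a) for a in (FakeTensor(0x7F0000000000), 1, 48, 5)] == ["D", "1", "D", "N"]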
+ """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + kwargs["debug"] = self.debug + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder() + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + kernel = self.cache[device].get(key, None) + + if kernel is None: + # Kernel is not cached; we have to compile. + target = driver.active.get_current_target() + backend = self.make_backend(target) + options = backend.parse_options(kwargs) + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (self._get_config(*bound_vals), ) + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + + # Check that used global values have not changed. + not_present = object() + for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. 
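+                # Editorial sketch (not part of the original patch): the contract
+                # assumed here is that a callable grid receives the dict of bound
+                # arguments, so constexpr meta-parameters are available, e.g. with a
+                # hypothetical BLOCK_SIZE argument:
+                #
+                #     grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)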
+ grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None, + launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize) + self.params.append(KernelParam(i, param, dns)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. 
+ self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import AttrsDescriptor, compile, ASTSource + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... 
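+
+# Editorial sketch (not part of the original patch): typical use of the decorator
+# defined below, assuming torch, a CUDA-capable device and `triton.language as tl`;
+# `add_kernel` and its arguments are hypothetical names:
+#
+#     @triton.jit
+#     def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+#         pid = tl.program_id(axis=0)
+#         offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+#         mask = offsets < n_elements
+#         x = tl.load(x_ptr + offsets, mask=mask)
+#         y = tl.load(y_ptr + offsets, mask=mask)
+#         tl.store(out_ptr + offsets, x + y, mask=mask)
+#
+#     grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+#     add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)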
+ + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, i): + return self.base.stride(i) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") diff --git a/third_party/metax/python/triton/testing.py b/third_party/metax/python/triton/testing.py new file mode 100644 index 000000000..a54149ffb --- /dev/null +++ b/third_party/metax/python/triton/testing.py @@ -0,0 +1,558 @@ +''' 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. ''' +import functools +import os +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . 
import language as tl + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + """ + import torch + assert return_mode in ["min", "max", "mean", "median"] + + if torch.cuda.current_stream() == torch.cuda.default_stream(): + raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.") + # warmup + fn() + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for i in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for i in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + times = torch.tensor(ret) + return getattr(torch, return_mode)(times).item() + +# return time (seconds) +# currently only support ns, us, ms & s +def get_time_s(info): + line = info.strip().split("\n")[-1] + # if not grep data return -1 + if "Self CUDA time total:" not in line: + return -1 + str_val = line[22:].strip() + if "ns" in line : + val = float(str_val[:-2]) + val = float(val / 1000000000) + elif "us" in line: + val = float(str_val[:-2]) + val = float(val / 1000000) + elif "ms" in line: + val = float(str_val[:-2]) + val = float(val / 1000) + elif "s" in line: + val = float(str_val[:-1]) + else: + assert 0 + + return val + +def profile(fn, max_retry=10): + import torch + def run(fn): + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU], + record_shapes=False, + ) as profiler: + fn() + info = profiler.key_averages(group_by_input_shape=False).table(sort_by="cuda_time_total", max_name_column_width=1000, row_limit=-1) + # print(info) + t = get_time_s(info) + return t + t = -1 + retry = 0 + while t == -1 and retry < max_retry: + t = run(fn) + if t == -1: + print(f"### retry {retry} times to re-run test to get kernel cuda time during run {fn.__name__}") + retry += 1 + assert t != -1, f"torch.profiler cannot grep kernel cuda time 
after try {max_retry} times during run {fn.__name__}" + return t + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, + quantiles=None, + fast_flush=True, + return_mode="mean"): + assert return_mode in ["min", "max", "mean", "median"] + import torch + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + + fn() + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + while not end_event.query(): + pass + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + use_profile = os.getenv("TRITON_USE_PROFILER_DO_BENCH", False) + if use_profile: + n_warmup = min(50, n_warmup) + n_repeat = min(30, n_repeat) + else: + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + tt = [] + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + if use_profile: + t = profile(fn) + tt.append(t) + else: + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + if use_profile: + times = torch.tensor(tt, dtype=torch.float) * 1000 # mulple 1000 to convert second to ms + else: + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. 
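+    # Editorial note (not part of the original patch): both atol and rtol may also
+    # be callables taking the dtype, which is what the callable(...) checks above
+    # and below resolve, e.g.
+    #     assert_close(x, y, atol=lambda dtype: 1e-1 if dtype == torch.float16 else 1e-2)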
+ rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + color=None, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. 
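+            # Editorial note (not part of the original patch): e.g. with
+            # x_names=["M", "N"], a scalar x=128 is broadcast to {"M": 128, "N": 128},
+            # while x=(128, 256) is zipped position-wise to {"M": 128, "N": 256}.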
+ if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. 
+ :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not 
device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/metax/python/triton/tools/__init__.py b/third_party/metax/python/triton/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/metax/python/triton/tools/build_extern.py b/third_party/metax/python/triton/tools/build_extern.py new file mode 100644 index 000000000..8f0168d59 --- /dev/null +++ b/third_party/metax/python/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. 
+ :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. 
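+        # Editorial note (not part of the original patch): e.g. the float overload
+        # `expf` is renamed to `exp`, so it lands in the same group as the double
+        # overload `exp`; _output_stubs then emits a single `exp` stub whose entry is
+        # selected by argument dtype.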
+ renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", 
input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. 
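+    Note added for clarity (not in the original patch): build() disassembles the
+    library with llvm-dis, parses the exported symbols, and writes a Python stub
+    module named after the library into output_dir.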
+ :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/third_party/metax/python/triton/tools/compile.c b/third_party/metax/python/triton/tools/compile.c new file mode 100644 index 000000000..971bf6191 --- /dev/null +++ b/third_party/metax/python/triton/tools/compile.c @@ -0,0 +1,67 @@ +/* clang-format off */ +#include +#include +#include +#include +#include + + +// helpers to check for cuda errors +#define CUDA_CHECK(ans) {{\ + gpuAssert((ans), __FILE__, __LINE__);\ + }}\ + +static inline void gpuAssert(CUresult code, const char *file, int line) {{ + if (code != CUDA_SUCCESS) {{ + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + printf("%s\\n", err); + exit(code); + }} +}} + +// globals +#define CUBIN_NAME {kernel_name}_cubin +CUmodule {kernel_name}_mod = NULL; +CUfunction {kernel_name}_func = NULL; +unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }}; + + +void unload_{kernel_name}(void) {{ + CUDA_CHECK(cuModuleUnload({kernel_name}_mod)); +}} + +// TODO: some code duplication with `runtime/backend/cuda.c` +void load_{kernel_name}() {{ + int dev = 0; + void *bin = (void *)&CUBIN_NAME; + int shared = {shared}; + CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin)); + CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev)); + if (shared > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED)); + CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)) + }} +}} + +/* +{kernel_docstring} +*/ +CUresult {kernel_name}(CUstream stream, {signature}) {{ + if ({kernel_name}_func == NULL) + load_{kernel_name}(); + unsigned int gX = {gridX}; + unsigned int gY = {gridY}; + unsigned int gZ = {gridZ}; + void *args[{num_args}] = {{ {arg_pointers} }}; + // TODO: shared memory + if(gX * gY * gZ > 0) + return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL); +}} diff --git a/third_party/metax/python/triton/tools/compile.h b/third_party/metax/python/triton/tools/compile.h new file mode 100644 index 000000000..d98b7063b --- /dev/null +++ b/third_party/metax/python/triton/tools/compile.h @@ -0,0 +1,14 
@@ +#ifndef TT_KERNEL_INCLUDES +#define TT_KERNEL_INCLUDES + +#include +#include +#include +#include + +#endif + +void unload_{kernel_name}(void); +void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/third_party/metax/python/triton/tools/compile.py b/third_party/metax/python/triton/tools/compile.py new file mode 100644 index 000000000..872332b03 --- /dev/null +++ b/third_party/metax/python/triton/tools/compile.py @@ -0,0 +1,145 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton +from triton.compiler.code_generator import kernel_suffix +from triton.backends.nvidia.driver import ty_to_cpp + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + +if __name__ == "__main__": + + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. 
File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + args = parser.parse_args() + + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + divisible_by_16 = [i for i, h in hints.items() if h == 16] + equal_to_1 = [i for i, h in hints.items() if h == 1] + attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + for i in equal_to_1: + constants.update({i: 1}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, options=opts) + arg_names = [] + arg_types = [] + for i in signature.keys(): + if i not in equal_to_1: + arg_names += [kernel.arg_names[i]] + arg_types += [signature[i]] + + # dump C stub code + suffix = kernel_suffix(signature.values(), attrs) + func_name = '_'.join([out_name, sig_hash, suffix]) + hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1] + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(hex_), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in 
zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), + "num_args": len(arg_names), + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / f"compile.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) diff --git a/third_party/metax/python/triton/tools/disasm.py b/third_party/metax/python/triton/tools/disasm.py new file mode 100644 index 000000000..1e309a2e4 --- /dev/null +++ b/third_party/metax/python/triton/tools/disasm.py @@ -0,0 +1,142 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? 
)(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/third_party/metax/python/triton/tools/link.py b/third_party/metax/python/triton/tools/link.py new file mode 100644 index 000000000..75a1157a5 --- /dev/null +++ b/third_party/metax/python/triton/tools/link.py @@ -0,0 +1,322 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + 
arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + + self.kernels = defaultdict(list) + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, first find the index, + # then see if it is followed by d or c + for i in range(len(args)): + pos = suffix.find(str(i)) + if pos == -1: + raise LinkerError(f"{suffix} is not a valid kernel suffix") + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + if i < len(args) - 1: + suffix = suffix[pos:] + else: + sizes.extend([None] * (len(args) - len(sizes))) + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate 
declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); +CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return CUDA_ERROR_INVALID_VALUE;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def 
make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + with args.out.with_suffix(".h").open("w") as fp: + out = "#include \n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = "" + out += "#include \n" + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out) From 0e30d6d93b8a82c07d7db45c7614938f65e39864 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 30 May 2025 10:05:23 +0000 Subject: [PATCH 2/6] [BACKEND] Fix soft link backends --- .../metax/python/triton/backends/__init__.py | 50 ------------ .../metax/python/triton/backends/compiler.py | 76 ------------------- .../metax/python/triton/backends/driver.py | 34 --------- 3 files changed, 160 deletions(-) delete mode 100644 third_party/metax/python/triton/backends/__init__.py delete mode 100644 third_party/metax/python/triton/backends/compiler.py 
delete mode 100644 third_party/metax/python/triton/backends/driver.py diff --git a/third_party/metax/python/triton/backends/__init__.py b/third_party/metax/python/triton/backends/__init__.py deleted file mode 100644 index fbf65d9e9..000000000 --- a/third_party/metax/python/triton/backends/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -import os -import importlib.util -import inspect -from dataclasses import dataclass -from .driver import DriverBase -from .compiler import BaseBackend - - -def _load_module(name, path): - spec = importlib.util.spec_from_file_location(name[:-3], path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - -def _find_concrete_subclasses(module, base_class): - ret = [] - for attr_name in dir(module): - attr = getattr(module, attr_name) - if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr): - ret.append(attr) - if len(ret) == 0: - raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}") - if len(ret) > 1: - raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}") - return ret[0] - - -@dataclass(frozen=True) -class Backend: - compiler: BaseBackend = None - driver: DriverBase = None - - -def _discover_backends(): - backends = dict() - root = os.path.dirname(__file__) - for name in os.listdir(root): - if not os.path.isdir(os.path.join(root, name)): - continue - if name.startswith('__'): - continue - compiler = _load_module(name, os.path.join(root, name, 'compiler.py')) - driver = _load_module(name, os.path.join(root, name, 'driver.py')) - backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), - _find_concrete_subclasses(driver, DriverBase)) - return backends - - -backends = _discover_backends() diff --git a/third_party/metax/python/triton/backends/compiler.py b/third_party/metax/python/triton/backends/compiler.py deleted file mode 100644 index 990690045..000000000 --- a/third_party/metax/python/triton/backends/compiler.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -import re -import subprocess - -from abc import ABCMeta, abstractmethod, abstractclassmethod -from dataclasses import dataclass -from typing import Union - - -@dataclass(frozen=True) -class GPUTarget(object): - # Target backend, e.g., cuda, hip - backend: str - # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip) - arch: Union[int, str] - warp_size: int - - -class BaseBackend(metaclass=ABCMeta): - - def __init__(self, target: GPUTarget) -> None: - self.target = target - assert self.supports_target(target) - - @staticmethod - def _path_to_binary(binary: str): - base_dir = os.path.join(os.path.dirname(__file__), os.pardir) - paths = [ - os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), - os.path.join(base_dir, "third_party", "cuda", "bin", binary), - ] - for p in paths: - bin = p.split(" ")[0] - if os.path.exists(bin) and os.path.isfile(bin): - result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) - if result is not None: - version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) - if version is not None: - return p, version.group(1) - raise RuntimeError(f"Cannot find {binary}") - - @abstractclassmethod - def supports_target(target: GPUTarget): - raise NotImplementedError - - @abstractmethod - def hash(self) -> str: - """Returns a unique identifier for this backend""" - raise NotImplementedError - - @abstractmethod - def parse_options(self, options: dict) -> 
object: - """ - Converts an `options` dictionary into an arbitrary object and returns it. - This function may contain target-specific heuristics and check the legality of the provided options - """ - raise NotImplementedError - - @abstractmethod - def add_stages(self, stages: dict, options: object) -> None: - """ - Populates `stages` dictionary with entries of the form: - ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes] - The value of each entry may populate a `metadata` dictionary. - Stages will be run sequentially (in inseriton order) and can communicate using `metadata`. - All stages are expected to return a `str` object, except for the last stage which returns - a `bytes` object for execution by the launcher. - """ - raise NotImplementedError - - @abstractmethod - def load_dialects(self, context): - """ - Load additional MLIR dialects into the provided `context` - """ - raise NotImplementedError diff --git a/third_party/metax/python/triton/backends/driver.py b/third_party/metax/python/triton/backends/driver.py deleted file mode 100644 index e66442943..000000000 --- a/third_party/metax/python/triton/backends/driver.py +++ /dev/null @@ -1,34 +0,0 @@ -from abc import ABCMeta, abstractmethod, abstractclassmethod - - -class DriverBase(metaclass=ABCMeta): - - @abstractclassmethod - def is_active(self): - pass - - @abstractmethod - def get_current_target(self): - pass - - def __init__(self) -> None: - pass - - -class GPUDriver(DriverBase): - - def __init__(self): - # TODO: support other frameworks than torch - import torch - self.get_device_capability = torch.cuda.get_device_capability - try: - from torch._C import _cuda_getCurrentRawStream - self.get_current_stream = _cuda_getCurrentRawStream - except ImportError: - self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream - self.get_current_device = torch.cuda.current_device - self.set_current_device = torch.cuda.set_device - - # TODO: remove once TMA is cleaned up - def assemble_tensormap_to_arg(self, tensormaps_info, args): - return args From 0b4522bba8ed2d209cb1245a79b5b69289f700d6 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 30 May 2025 10:05:58 +0000 Subject: [PATCH 3/6] [BACKEND] Fix soft link backends --- third_party/metax/python/triton/backends | 1 + 1 file changed, 1 insertion(+) create mode 120000 third_party/metax/python/triton/backends diff --git a/third_party/metax/python/triton/backends b/third_party/metax/python/triton/backends new file mode 120000 index 000000000..13a83a85c --- /dev/null +++ b/third_party/metax/python/triton/backends @@ -0,0 +1 @@ +../../../../python/triton/backends \ No newline at end of file From c7d3d75c206c82b35a1497730d35461079f3592a Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 30 May 2025 18:13:28 +0800 Subject: [PATCH 4/6] [BACKEND] Fix metax code format --- CMakeLists.txt | 2 +- third_party/metax/CMakeLists.txt | 2 +- third_party/metax/backend/compiler.py | 29 +-- third_party/metax/backend/driver.c | 19 +- third_party/metax/backend/driver.py | 3 + third_party/metax/bin/CMakeLists.txt | 2 +- .../metax/bin/RegisterTritonDialects.h | 19 +- third_party/metax/include/CMakeLists.txt | 2 +- .../metax/include/triton/CMakeLists.txt | 2 +- third_party/metax/lib/CMakeLists.txt | 2 +- .../python/triton/_C/include/CMakeLists.txt | 2 +- .../triton/_C/include/triton/CMakeLists.txt | 2 +- .../metax/python/triton/language/core.py | 200 +++++++++--------- .../metax/python/triton/language/semantic.py | 3 +- .../metax/python/triton/runtime/autotuner.py | 6 +- 
.../metax/python/triton/runtime/build.py | 4 +- third_party/metax/python/triton/testing.py | 26 +-- 17 files changed, 173 insertions(+), 152 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 92cbecf11..5ad333575 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -329,7 +329,7 @@ if(TRITON_BUILD_PYTHON_MODULE) set(TRITON_LIBRARIES ${triton_libs} ${triton_plugins} - + # mlir MLIRMACADialect MLIRGPUToMACATransforms diff --git a/third_party/metax/CMakeLists.txt b/third_party/metax/CMakeLists.txt index 55efc4cc6..5a39b450f 100644 --- a/third_party/metax/CMakeLists.txt +++ b/third_party/metax/CMakeLists.txt @@ -21,4 +21,4 @@ if(TRITON_BUILD_PYTHON_MODULE) endif() endif() -add_subdirectory(bin) \ No newline at end of file +add_subdirectory(bin) diff --git a/third_party/metax/backend/compiler.py b/third_party/metax/backend/compiler.py index 3e5a13808..b301b6a34 100644 --- a/third_party/metax/backend/compiler.py +++ b/third_party/metax/backend/compiler.py @@ -57,7 +57,7 @@ def ptx_get_version(cuda_version) -> int: def file_hash(path): with open(path, "rb") as f: return hashlib.sha256(f.read()).hexdigest() - + def maca_get_kernel_name(src: str) -> str: ''' @@ -71,6 +71,7 @@ def maca_get_kernel_name(src: str) -> str: if line.startswith('define metaxgpu_kernel void @'): return re.match(r"define metaxgpu_kernel void @(.+?)\(", line).groups()[0] + def parse_option(string): return [item for item in string.split(';') if item] @@ -107,7 +108,8 @@ def __post_init__(self): if not extern_libs.get('libdevice', None): # ext_maca_mathlib.bc env_ext_libdevice_path = os.getenv("TRITON_EXT_LIBDEVICE_PATH", None) - ext_libdevice_path = env_ext_libdevice_path if env_ext_libdevice_path is not None else str(ext_default_libdir) + '/ext_maca_mathlib.bc' + ext_libdevice_path = env_ext_libdevice_path if env_ext_libdevice_path is not None else str( + ext_default_libdir) + '/ext_maca_mathlib.bc' assert os.path.exists(ext_libdevice_path), "ext_maca_mathlib.bc do not exit, please check!" extern_libs['ext_libdevice'] = ext_libdevice_path # maca_kernellib.bc @@ -171,7 +173,6 @@ def get_codegen_implementation(self): def load_dialects(self, ctx): metax.load_dialects(ctx) - @staticmethod def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) @@ -191,7 +192,7 @@ def make_ttir(mod, metadata, opt): def make_ttgir(mod, metadata, opt, capability): assert opt.pipeline_load_num >= -1, "invalid pipeline_load_num value!" 
scenarios = parse_option(opt.scenario) - disable_prefetch = "unprefetch" in scenarios + disable_prefetch = "unprefetch" in scenarios fullstage = "fullstage" in scenarios store_coalesce = "storeCoalesce" in scenarios mla = "mla" in scenarios @@ -208,8 +209,8 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_f32_dot_tc(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) - - if opt.pipeline == "cpasync" : + + if opt.pipeline == "cpasync": disable_prefetch = True metax.passes.ttgpuir.add_accelerate_matmul(pm, opt.num_stages, disable_prefetch, store_coalesce, "c500") passes.ttgpuir.add_remove_layout_conversions(pm) @@ -225,15 +226,18 @@ def make_ttgir(mod, metadata, opt, capability): if opt.pipeline == "basic": if mla and single_shm: # only mla=True and single_shm=True - metax.passes.ttgpuir.add_pipeline_maca(pm, opt.num_stages, opt.pipeline_load_num, fullstage, True) + metax.passes.ttgpuir.add_pipeline_maca(pm, opt.num_stages, opt.pipeline_load_num, fullstage, + True) else: - metax.passes.ttgpuir.add_pipeline_maca(pm, opt.num_stages, opt.pipeline_load_num, fullstage, False) + metax.passes.ttgpuir.add_pipeline_maca(pm, opt.num_stages, opt.pipeline_load_num, fullstage, + False) elif opt.pipeline == "cpasync" and not mla: metax.passes.ttgpuir.add_pipeline_async_tn(pm, opt.num_stages) metax.passes.ttgpuir.add_pipeline_async_tt(pm, opt.num_stages) metax.passes.ttgpuir.add_pipeline_async_base(pm, opt.num_stages, fullstage) elif mla and opt.num_stages == 2 and opt.pipeline == "cpasync": - metax.passes.ttgpuir.add_pipeline_async_multidot_mla(pm, opt.num_stages, fullstage, opt.pipeline_load_num) + metax.passes.ttgpuir.add_pipeline_async_multidot_mla(pm, opt.num_stages, fullstage, + opt.pipeline_load_num) else: print("no avalilable pipeline for maca") else: @@ -299,7 +303,6 @@ def make_llir(src, metadata, options, capability): metadata["name"] = maca_get_kernel_name(llir) return llir - @staticmethod def make_mcfatbin(src, metadata, opt, capability): scenarios = parse_option(opt.scenario) @@ -312,16 +315,17 @@ def make_mcfatbin(src, metadata, opt, capability): compile_options = "" if (opt.pipeline == "basic" or opt.pipeline == "basic-prefetch") and "mla" not in scenarios: compile_options = " -mllvm -metaxgpu-sched-regpressure=false -mllvm -metaxgpu-PostRA-Scheduler=false -mllvm -metaxgpu-mma-sched=true " - if "fullstage" in scenarios: + if "fullstage" in scenarios: compile_options += " -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup " else: compile_options += " -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-sched-mma-maxnum=3 " - if "roll" not in scenarios: + if "roll" not in scenarios: compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " " elif opt.pipeline == "cpasync" and "mla" not in scenarios: compile_options = " -mllvm -metaxgpu-sched-regpressure=true -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true \ -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true -mllvm -metaxgpu-shl-add-combine=false \ -mllvm -misched-postra=true -mllvm -enable-post-misched=true " + if os.getenv("TRITON_ENABLE_MACA_COMPILER_INT8_OPT"): compile_options += " -mllvm -metaxgpu-slp-vectorize-i8=true" if "unroll" in scenarios: @@ -329,7 +333,6 @@ def make_mcfatbin(src, metadata, opt, capability): if opt.extra_options != "": compile_options = opt.extra_options return metax.translate_llvmir_to_mcfatbin(src, mxcc_arch, os.environ.get('MACA_PATH'), compile_options) - def 
add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) diff --git a/third_party/metax/backend/driver.c b/third_party/metax/backend/driver.c index 8b417b529..28706cd8e 100644 --- a/third_party/metax/backend/driver.c +++ b/third_party/metax/backend/driver.c @@ -1,6 +1,7 @@ -/* Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. */ -#include +/* Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All + * Rights Reserved. */ #include +#include #include #define PY_SSIZE_T_CLEAN #include @@ -59,7 +60,9 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { mcDeviceGet(&device, device_id); // create a struct to hold device properties - int max_shared_mem = 64 * 1024; // 64KB, no CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN + int max_shared_mem = + 64 * + 1024; // 64KB, no CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN int max_num_regs; int multiprocessor_count; int warp_size = 64; @@ -70,8 +73,8 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { &max_num_regs, mcDeviceAttributeMaxSharedMemoryPerBlock, device)); MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( &multiprocessor_count, mcDeviceAttributeMultiProcessorCount, device)); - MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( - &sm_clock_rate, mcDeviceAttributeClockRate, device)); + MACA_CHECK_AND_RETURN_NULL( + mcDeviceGetAttribute(&sm_clock_rate, mcDeviceAttributeClockRate, device)); MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( &mem_clock_rate, mcDeviceAttributeMemoryClockRate, device)); MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( @@ -110,7 +113,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxSetCurrent(pctx)); } MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcModuleLoadData(&mod, data)); - MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcModuleGetFunction(&fun, mod, name)); + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + mcModuleGetFunction(&fun, mod, name)); // get allocated registers and spilled registers from the function MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( mcFuncGetAttribute(&n_regs, MC_FUNC_ATTRIBUTE_NUM_REGS, fun)); @@ -123,7 +127,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { if (PyErr_Occurred()) { return NULL; } - return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills); + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); } static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { diff --git a/third_party/metax/backend/driver.py b/third_party/metax/backend/driver.py index 000f3302a..a096632c0 100644 --- a/third_party/metax/backend/driver.py +++ b/third_party/metax/backend/driver.py @@ -16,15 +16,18 @@ # libraries = ['cuda'] libraries = [] + @functools.lru_cache() def maca_home_dirs(): return os.getenv("MACA_PATH") + @functools.lru_cache() def libmaca_dirs(): maca_path = maca_home_dirs() return ["{}/lib/".format(maca_path)] + maca_lib_dir = libmaca_dirs() maca_include_dir = [os.path.join(maca_home_dirs(), "include")] diff --git a/third_party/metax/bin/CMakeLists.txt b/third_party/metax/bin/CMakeLists.txt index b3d7981ae..3268dccd2 100644 --- a/third_party/metax/bin/CMakeLists.txt +++ b/third_party/metax/bin/CMakeLists.txt @@ -83,4 +83,4 @@ target_link_libraries(triton-llvm-opt PRIVATE export_executable_symbols_for_plugins(triton-llvm-opt) set_target_properties(triton-llvm-opt PROPERTIES RUNTIME_OUTPUT_DIRECTORY 
${CMAKE_BINARY_DIR}/bin/ -) \ No newline at end of file +) diff --git a/third_party/metax/bin/RegisterTritonDialects.h b/third_party/metax/bin/RegisterTritonDialects.h index d002ce597..f9fc64b40 100644 --- a/third_party/metax/bin/RegisterTritonDialects.h +++ b/third_party/metax/bin/RegisterTritonDialects.h @@ -1,18 +1,21 @@ -/* 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. */ +/* 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights + * Reserved. */ #pragma once -#include "triton/Target/LLVMIR/Passes.h" #include "mlir/InitAllPasses.h" #include "python/src/plugin.h" +#include "triton/Target/LLVMIR/Passes.h" using BackendRegisterFunc = void (*)(); -BackendRegisterFunc load_backend_register_func(const char *backend_name, const char *func_name) { +BackendRegisterFunc load_backend_register_func(const char *backend_name, + const char *func_name) { void *symbol = load_backend_symbol(backend_name, func_name); return reinterpret_cast(symbol); } -using DialectRegisterFunc = void (*)(mlir::DialectRegistry*); -DialectRegisterFunc load_dialect_register_func(const char *backend_name, const char *func_name) { +using DialectRegisterFunc = void (*)(mlir::DialectRegistry *); +DialectRegisterFunc load_dialect_register_func(const char *backend_name, + const char *func_name) { void *symbol = load_backend_symbol(backend_name, func_name); return reinterpret_cast(symbol); } @@ -20,9 +23,11 @@ DialectRegisterFunc load_dialect_register_func(const char *backend_name, const c inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerAllPasses(); - auto registerAllTritonPasses = load_backend_register_func("metax", "registerAllTritonPasses"); + auto registerAllTritonPasses = + load_backend_register_func("metax", "registerAllTritonPasses"); registerAllTritonPasses(); - auto registerConvertTritonGPUToLLVMPass = load_backend_register_func("metax", "registerConvertTritonGPUToLLVMPass"); + auto registerConvertTritonGPUToLLVMPass = + load_backend_register_func("metax", "registerConvertTritonGPUToLLVMPass"); registerConvertTritonGPUToLLVMPass(); auto registerDialect = load_dialect_register_func("metax", "registerDialect"); diff --git a/third_party/metax/include/CMakeLists.txt b/third_party/metax/include/CMakeLists.txt index 72181b98f..109c292fe 100644 --- a/third_party/metax/include/CMakeLists.txt +++ b/third_party/metax/include/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(triton) \ No newline at end of file +add_subdirectory(triton) diff --git a/third_party/metax/include/triton/CMakeLists.txt b/third_party/metax/include/triton/CMakeLists.txt index 7369bfadf..310cf4abc 100644 --- a/third_party/metax/include/triton/CMakeLists.txt +++ b/third_party/metax/include/triton/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(Target) \ No newline at end of file +add_subdirectory(Target) diff --git a/third_party/metax/lib/CMakeLists.txt b/third_party/metax/lib/CMakeLists.txt index 7369bfadf..310cf4abc 100644 --- a/third_party/metax/lib/CMakeLists.txt +++ b/third_party/metax/lib/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(Target) \ No newline at end of file +add_subdirectory(Target) diff --git a/third_party/metax/python/triton/_C/include/CMakeLists.txt b/third_party/metax/python/triton/_C/include/CMakeLists.txt index 72181b98f..109c292fe 100644 --- a/third_party/metax/python/triton/_C/include/CMakeLists.txt +++ b/third_party/metax/python/triton/_C/include/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(triton) \ No newline at end of file 
+add_subdirectory(triton) diff --git a/third_party/metax/python/triton/_C/include/triton/CMakeLists.txt b/third_party/metax/python/triton/_C/include/triton/CMakeLists.txt index 7369bfadf..310cf4abc 100644 --- a/third_party/metax/python/triton/_C/include/triton/CMakeLists.txt +++ b/third_party/metax/python/triton/_C/include/triton/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(Target) \ No newline at end of file +add_subdirectory(Target) diff --git a/third_party/metax/python/triton/language/core.py b/third_party/metax/python/triton/language/core.py index a343a2887..8957628a3 100644 --- a/third_party/metax/python/triton/language/core.py +++ b/third_party/metax/python/triton/language/core.py @@ -24,107 +24,107 @@ PropagateNan = ir.PROPAGATE_NAN if os.getenv("MACA_PATH") is not None: - USE_MACA = True - # map all the libdevice.10.bc funcs to maca in maca_mathlib.bc, do not modify the specific func in libdevice.py - # TODO(MACA): add python/triton/language/extra/maca/ directory later - nv_to_maca_map = { - "__nv_floorf" : "mc_math_func_floorf", - "__nv_floor" : "mc_math_func_floor", - "__nv_log2f" : "mc_math_func_log2f", - "__nv_log2" : "mc_math_func_log2", - "__nv_logf" : "mc_math_func_logf", - "__nv_powf" : "mc_math_func_powf_inline", - "__nv_pow" : "mc_math_func_pow", - "__nv_norm4df" : "mc_math_func_norm4df", - "__nv_norm4d" : "mc_math_func_norm4d", - "__nv_expf" : "mc_math_func_expf", - "__nv_exp" : "mc_math_func_exp", - "__nv_ffs" : "mc_math_func_ffs", - "__nv_umulhi" : "mc_math_func_umulhi", - "__nv_umul64hi" : "mc_math_func_umul64hi", - "__nv_rsqrtf" : "mc_math_func_rsqrtf", - "__nv_erff" : "mc_math_func_erff", - "__nv_tanhf" : "mc_math_func_tanhf", - "__nv_max" : "mc_math_func_max", - "__nv_umax" : "mc_math_func_umax", - "__nv_fmaxf" : "mc_math_func_fmaxf", - "__nv_fmax" : "mc_math_func_fmax", - "__nv_llmax" : "mc_math_func_llmax", - "__nv_ullmax" : "mc_math_func_ullmax", - "__nv_min" : "mc_math_func_min", - "__nv_umin" : "mc_math_func_umin", - "__nv_fminf" : "mc_math_func_fminf", - "__nv_fmin" : "mc_math_func_fmin", - "__nv_llmin" : "mc_math_func_llmin", - "__nv_ullmin" : "mc_math_func_ullmin", - "__nv_isinff" : "mc_math_func_isinff", - "__nv_log1pf" : "mc_math_func_log1pf", - "__nv_truncf" : "mc_math_func_truncf", - "__nv_expm1f" : "mc_math_func_expm1f", - "__nv_exp2f" : "mc_math_func_exp2f", - "__nv_fmodf" : "mc_math_func_fmodf", - "__nv_lgammaf" : "mc_math_func_lgammaf", - "__nv_log" : "mc_math_func_log", - "__nv_nearbyintf" : "mc_math_func_nearbyintf", - "__nv_signbitf" : "mc_math_func_signbitf", - "__nv_tanf" : "mc_math_func_tanf", - "__nv_ceilf" : "mc_math_func_ceilf", - "__nv_acosf" : "mc_math_func_acosf", - "__nv_acoshf" : "mc_math_func_acoshf", - "__nv_acos" : "mc_math_func_acos", - "__nv_acosh" : "mc_math_func_acosh", - "__nv_asinf" : "mc_math_func_asinf", - "__nv_asin" : "mc_math_func_asin", - "__nv_asinhf" : "mc_math_func_asinhf", - "__nv_asinh" : "mc_math_func_asinh", - "__nv_atan2f" : "mc_math_func_atan2f", - "__nv_atan2" : "mc_math_func_atan2", - "__nv_atanf" : "mc_math_func_atanf", - "__nv_atan" : "mc_math_func_atan", - "__nv_atanhf" : "mc_math_func_atanhf", - "__nv_atanh" : "mc_math_func_atanh", - "__nv_erf" : "mc_math_func_erf", - "__nv_erfcf" : "mc_math_func_erfcf", - "__nv_copysignf" : "mc_math_func_copysignf", - "__nv_copysign" : "mc_math_func_copysign", - "__nv_cos" : "mc_math_func_cos", - "__nv_coshf" : "mc_math_func_coshf", - "__nv_cosh" : "mc_math_func_cosh", - "__nv_isnanf" : "mc_math_func_isnanf", - "__nv_isnand" : "mc_math_func_isnan", - "__nv_hypotf" : 
"mc_math_func_hypotf", - "__nv_hypot" : "mc_math_func_hypot", - "__nv_sqrt" : "mc_math_func_sqrt", - "__nv_rsqrt" : "mc_math_func_rsqrt", - "__nv_nextafterf" : "mc_math_func_nextafterf", - "__nv_nextafter" : "mc_math_func_nextafter", - "__nv_sin" : "mc_math_func_sin", - "__nv_sinhf" : "mc_math_func_sinhf", - "__nv_sinh" : "mc_math_func_sinh", - "__nv_scalbnf" : "mc_math_func_scalbnf", - "__nv_fdiv_rn" : "mc_math_func_fdiv_rn", - "__nv_fdiv_rz" : "mc_math_func_fdiv_rz", - "__nv_powif" : "mc_math_func_powif_inline", - "__nv_finitef" : "mc_math_func_finitef", - "__nv_isfinited" : "mc_math_func_isfinite", - "__nv_fast_fdividef" : "mc_math_func_fdividef", - "__nv_fast_sinf" : "mc_math_func_sinf", - "__nv_fast_cosf" : "mc_math_func_cosf", - "__nv_fast_log2f" : "mc_math_func_log2f", - "__nv_fast_logf" : "mc_math_func_logf", - "__nv_fast_expf" : "mc_math_func_expf", - "__nv_fast_tanf" : "mc_math_func_tanf", - "__nv_fast_exp10f" : "mc_math_func_exp10f", - "__nv_fast_log10f" : "mc_math_func_log10f", - "__nv_fast_powf" : "mc_math_func_powf_inline", - "__nv_rintf" : "mc_math_func_rintf", - "__nv_roundf" : "mc_math_func_roundf", - "__nv_sqrtf" : "mc_math_func_sqrtf", - "__nv_fmaf" : "mc_math_func_fmaf", - } + USE_MACA = True + # map all the libdevice.10.bc funcs to maca in maca_mathlib.bc, do not modify the specific func in libdevice.py + # TODO(MACA): add python/triton/language/extra/maca/ directory later + nv_to_maca_map = { + "__nv_floorf": "mc_math_func_floorf", + "__nv_floor": "mc_math_func_floor", + "__nv_log2f": "mc_math_func_log2f", + "__nv_log2": "mc_math_func_log2", + "__nv_logf": "mc_math_func_logf", + "__nv_powf": "mc_math_func_powf_inline", + "__nv_pow": "mc_math_func_pow", + "__nv_norm4df": "mc_math_func_norm4df", + "__nv_norm4d": "mc_math_func_norm4d", + "__nv_expf": "mc_math_func_expf", + "__nv_exp": "mc_math_func_exp", + "__nv_ffs": "mc_math_func_ffs", + "__nv_umulhi": "mc_math_func_umulhi", + "__nv_umul64hi": "mc_math_func_umul64hi", + "__nv_rsqrtf": "mc_math_func_rsqrtf", + "__nv_erff": "mc_math_func_erff", + "__nv_tanhf": "mc_math_func_tanhf", + "__nv_max": "mc_math_func_max", + "__nv_umax": "mc_math_func_umax", + "__nv_fmaxf": "mc_math_func_fmaxf", + "__nv_fmax": "mc_math_func_fmax", + "__nv_llmax": "mc_math_func_llmax", + "__nv_ullmax": "mc_math_func_ullmax", + "__nv_min": "mc_math_func_min", + "__nv_umin": "mc_math_func_umin", + "__nv_fminf": "mc_math_func_fminf", + "__nv_fmin": "mc_math_func_fmin", + "__nv_llmin": "mc_math_func_llmin", + "__nv_ullmin": "mc_math_func_ullmin", + "__nv_isinff": "mc_math_func_isinff", + "__nv_log1pf": "mc_math_func_log1pf", + "__nv_truncf": "mc_math_func_truncf", + "__nv_expm1f": "mc_math_func_expm1f", + "__nv_exp2f": "mc_math_func_exp2f", + "__nv_fmodf": "mc_math_func_fmodf", + "__nv_lgammaf": "mc_math_func_lgammaf", + "__nv_log": "mc_math_func_log", + "__nv_nearbyintf": "mc_math_func_nearbyintf", + "__nv_signbitf": "mc_math_func_signbitf", + "__nv_tanf": "mc_math_func_tanf", + "__nv_ceilf": "mc_math_func_ceilf", + "__nv_acosf": "mc_math_func_acosf", + "__nv_acoshf": "mc_math_func_acoshf", + "__nv_acos": "mc_math_func_acos", + "__nv_acosh": "mc_math_func_acosh", + "__nv_asinf": "mc_math_func_asinf", + "__nv_asin": "mc_math_func_asin", + "__nv_asinhf": "mc_math_func_asinhf", + "__nv_asinh": "mc_math_func_asinh", + "__nv_atan2f": "mc_math_func_atan2f", + "__nv_atan2": "mc_math_func_atan2", + "__nv_atanf": "mc_math_func_atanf", + "__nv_atan": "mc_math_func_atan", + "__nv_atanhf": "mc_math_func_atanhf", + "__nv_atanh": "mc_math_func_atanh", + 
"__nv_erf": "mc_math_func_erf", + "__nv_erfcf": "mc_math_func_erfcf", + "__nv_copysignf": "mc_math_func_copysignf", + "__nv_copysign": "mc_math_func_copysign", + "__nv_cos": "mc_math_func_cos", + "__nv_coshf": "mc_math_func_coshf", + "__nv_cosh": "mc_math_func_cosh", + "__nv_isnanf": "mc_math_func_isnanf", + "__nv_isnand": "mc_math_func_isnan", + "__nv_hypotf": "mc_math_func_hypotf", + "__nv_hypot": "mc_math_func_hypot", + "__nv_sqrt": "mc_math_func_sqrt", + "__nv_rsqrt": "mc_math_func_rsqrt", + "__nv_nextafterf": "mc_math_func_nextafterf", + "__nv_nextafter": "mc_math_func_nextafter", + "__nv_sin": "mc_math_func_sin", + "__nv_sinhf": "mc_math_func_sinhf", + "__nv_sinh": "mc_math_func_sinh", + "__nv_scalbnf": "mc_math_func_scalbnf", + "__nv_fdiv_rn": "mc_math_func_fdiv_rn", + "__nv_fdiv_rz": "mc_math_func_fdiv_rz", + "__nv_powif": "mc_math_func_powif_inline", + "__nv_finitef": "mc_math_func_finitef", + "__nv_isfinited": "mc_math_func_isfinite", + "__nv_fast_fdividef": "mc_math_func_fdividef", + "__nv_fast_sinf": "mc_math_func_sinf", + "__nv_fast_cosf": "mc_math_func_cosf", + "__nv_fast_log2f": "mc_math_func_log2f", + "__nv_fast_logf": "mc_math_func_logf", + "__nv_fast_expf": "mc_math_func_expf", + "__nv_fast_tanf": "mc_math_func_tanf", + "__nv_fast_exp10f": "mc_math_func_exp10f", + "__nv_fast_log10f": "mc_math_func_log10f", + "__nv_fast_powf": "mc_math_func_powf_inline", + "__nv_rintf": "mc_math_func_rintf", + "__nv_roundf": "mc_math_func_roundf", + "__nv_sqrtf": "mc_math_func_sqrtf", + "__nv_fmaf": "mc_math_func_fmaf", + } else: - USE_MACA = False - nv_to_maca_map = {} + USE_MACA = False + nv_to_maca_map = {} assert USE_MACA, "Please set MACA_PATH!" diff --git a/third_party/metax/python/triton/language/semantic.py b/third_party/metax/python/triton/language/semantic.py index 8af9790b1..81006fb9f 100644 --- a/third_party/metax/python/triton/language/semantic.py +++ b/third_party/metax/python/triton/language/semantic.py @@ -708,7 +708,8 @@ def _str_to_rounding_mode(rounding_mode: Optional[str]): return ir.ROUNDING_MODE.RTNE_NO_NAN if rounding_mode == 'rtz': return ir.ROUNDING_MODE.RTZ - raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz' and 'rtne_no_nan'.") + raise ValueError( + f"Invalid rounding mode: {rounding_mode}. 
Supported rounding modes are 'rtne' and 'rtz' and 'rtne_no_nan'.") def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: diff --git a/third_party/metax/python/triton/runtime/autotuner.py b/third_party/metax/python/triton/runtime/autotuner.py index 903373ed8..da4c9152b 100644 --- a/third_party/metax/python/triton/runtime/autotuner.py +++ b/third_party/metax/python/triton/runtime/autotuner.py @@ -138,7 +138,7 @@ def kernel_call(): return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) except (OutOfResources, CompileTimeAssertionFailure): return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] - + def _dump_cache(self, path, config): cache_dict = dict(**config.all_kwargs()) os.makedirs(os.path.dirname(path), exist_ok=True) @@ -170,7 +170,7 @@ def _load_cache(self, path): config_dict = json.load(file) config = self._parse_config_dict(config_dict) return config - + def _config_cache_path(self, key): base_path = os.path.join(default_cache_dir(), "configs") user_defined_path = os.environ.get("TRITON_AUTOTUNE_CONFIG_PATH") @@ -205,7 +205,7 @@ def run(self, *args, **kwargs): config = self._load_cache(self._config_cache_path(key)) if config is not None: self.cache[key] = config - if key not in self.cache: + if key not in self.cache: # prune configs used_cached_result = False pruned_configs = self.prune_configs(kwargs) diff --git a/third_party/metax/python/triton/runtime/build.py b/third_party/metax/python/triton/runtime/build.py index 988076697..4c6829d2b 100644 --- a/third_party/metax/python/triton/runtime/build.py +++ b/third_party/metax/python/triton/runtime/build.py @@ -17,10 +17,12 @@ def quiet(): finally: sys.stdout, sys.stderr = old_stdout, old_stderr + # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] + def _build(name, src, srcdir, library_dirs, include_dirs, libraries): suffix = sysconfig.get_config_var('EXT_SUFFIX') so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) @@ -44,7 +46,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): scheme = 'posix_prefix' py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] include_dirs = include_dirs + [srcdir, py_include_dir] - cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] if check_env_flag("TRITON_USE_MACA", "ON"): # Default ON cc_cmd += [f"-D__MACA__", "-lmcruntime"] cc_cmd += [f'-l{lib}' for lib in libraries] diff --git a/third_party/metax/python/triton/testing.py b/third_party/metax/python/triton/testing.py index a54149ffb..e8ef0dc5b 100644 --- a/third_party/metax/python/triton/testing.py +++ b/third_party/metax/python/triton/testing.py @@ -79,6 +79,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"): times = torch.tensor(ret) return getattr(torch, return_mode)(times).item() + # return time (seconds) # currently only support ns, us, ms & s def get_time_s(info): @@ -87,7 +88,7 @@ def get_time_s(info): if "Self CUDA time total:" not in line: return -1 str_val = line[22:].strip() - if "ns" in line : + if "ns" in line: val = float(str_val[:-2]) val = float(val / 1000000000) elif "us" in line: @@ -100,21 +101,25 @@ def get_time_s(info): val = float(str_val[:-1]) else: assert 0 - + return val + def profile(fn, max_retry=10): import 
torch + def run(fn): with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU], - record_shapes=False, - ) as profiler: + activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU], + record_shapes=False, + ) as profiler: fn() - info = profiler.key_averages(group_by_input_shape=False).table(sort_by="cuda_time_total", max_name_column_width=1000, row_limit=-1) + info = profiler.key_averages(group_by_input_shape=False).table(sort_by="cuda_time_total", + max_name_column_width=1000, row_limit=-1) # print(info) - t = get_time_s(info) + t = get_time_s(info) return t + t = -1 retry = 0 while t == -1 and retry < max_retry: @@ -126,10 +131,7 @@ def run(fn): return t -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, - quantiles=None, - fast_flush=True, - return_mode="mean"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): assert return_mode in ["min", "max", "mean", "median"] import torch """ @@ -208,7 +210,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, end_event[i].record() # Record clocks torch.cuda.synchronize() - if use_profile: + if use_profile: times = torch.tensor(tt, dtype=torch.float) * 1000 # mulple 1000 to convert second to ms else: times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) From 5d752a1a7116a73b5350e1980083e86ee4a1df5b Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 30 May 2025 12:26:34 +0000 Subject: [PATCH 5/6] [BACKEND] [TEST] Add metax python unit test --- .github/workflows/metax-build-and-test.yml | 39 +- .../metax/python/test/unit/conftest.py | 12 + .../python/test/unit/language/conftest.py | 5 + .../test/unit/language/test_annotations.py | 49 + .../test/unit/language/test_block_pointer.py | 100 + .../test/unit/language/test_compile_errors.py | 304 + .../test/unit/language/test_conversions.py | 356 ++ .../python/test/unit/language/test_core.py | 5417 +++++++++++++++++ .../test/unit/language/test_decorator.py | 48 + .../test/unit/language/test_line_info.py | 171 + .../python/test/unit/language/test_random.py | 255 + .../test/unit/language/test_reproducer.py | 42 + .../test/unit/language/test_standard.py | 75 + .../test/unit/language/test_subprocess.py | 162 + .../python/test/unit/operators/conftest.py | 5 + .../test/unit/operators/test_blocksparse.py | 237 + .../test/unit/operators/test_cross_entropy.py | 41 + .../unit/operators/test_flash_attention.py | 118 + .../test/unit/operators/test_inductor.py | 198 + .../python/test/unit/operators/test_matmul.py | 202 + .../test/unit/runtime/test_autotuner.py | 132 + .../python/test/unit/runtime/test_bindings.py | 81 + .../python/test/unit/runtime/test_cache.py | 534 ++ .../python/test/unit/runtime/test_driver.py | 14 + .../python/test/unit/runtime/test_jit.py | 42 + .../python/test/unit/runtime/test_launch.py | 134 + .../python/test/unit/runtime/test_subproc.py | 73 + 27 files changed, 8842 insertions(+), 4 deletions(-) create mode 100644 third_party/metax/python/test/unit/conftest.py create mode 100644 third_party/metax/python/test/unit/language/conftest.py create mode 100644 third_party/metax/python/test/unit/language/test_annotations.py create mode 100644 third_party/metax/python/test/unit/language/test_block_pointer.py create mode 100644 third_party/metax/python/test/unit/language/test_compile_errors.py create mode 100644 third_party/metax/python/test/unit/language/test_conversions.py create mode 100644 
third_party/metax/python/test/unit/language/test_core.py create mode 100644 third_party/metax/python/test/unit/language/test_decorator.py create mode 100644 third_party/metax/python/test/unit/language/test_line_info.py create mode 100644 third_party/metax/python/test/unit/language/test_random.py create mode 100644 third_party/metax/python/test/unit/language/test_reproducer.py create mode 100644 third_party/metax/python/test/unit/language/test_standard.py create mode 100644 third_party/metax/python/test/unit/language/test_subprocess.py create mode 100644 third_party/metax/python/test/unit/operators/conftest.py create mode 100644 third_party/metax/python/test/unit/operators/test_blocksparse.py create mode 100644 third_party/metax/python/test/unit/operators/test_cross_entropy.py create mode 100644 third_party/metax/python/test/unit/operators/test_flash_attention.py create mode 100644 third_party/metax/python/test/unit/operators/test_inductor.py create mode 100644 third_party/metax/python/test/unit/operators/test_matmul.py create mode 100644 third_party/metax/python/test/unit/runtime/test_autotuner.py create mode 100644 third_party/metax/python/test/unit/runtime/test_bindings.py create mode 100644 third_party/metax/python/test/unit/runtime/test_cache.py create mode 100644 third_party/metax/python/test/unit/runtime/test_driver.py create mode 100644 third_party/metax/python/test/unit/runtime/test_jit.py create mode 100644 third_party/metax/python/test/unit/runtime/test_launch.py create mode 100644 third_party/metax/python/test/unit/runtime/test_subproc.py diff --git a/.github/workflows/metax-build-and-test.yml b/.github/workflows/metax-build-and-test.yml index 3dc7622f5..c57001eab 100644 --- a/.github/workflows/metax-build-and-test.yml +++ b/.github/workflows/metax-build-and-test.yml @@ -1,7 +1,10 @@ name: Metax-Build-And-Test on: - workflow_call: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -12,8 +15,36 @@ jobs: runs-on: metax if: ${{ github.repository == 'FlagTree/flagtree' }} steps: - - name: Checkout code + - name: Checkout code (attempt 1) + id: checkout1 uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before checkout2 + if: steps.checkout1.outcome == 'failure' + run: | + echo "First checkout attempt failed. Sleeping for 120 seconds before retry..." + sleep 120 + + - name: Checkout code (attempt 2) + id: checkout2 + if: steps.checkout1.outcome == 'failure' + uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before final checkout + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + run: | + echo "Second checkout attempt failed. Sleeping for 180 seconds before final retry..." + sleep 180 + + - name: Checkout code (final attempt) + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + uses: actions/checkout@v4 + + - name: Verify checkout success + if: success() + run: echo "Checkout completed successfully" - name: FlagTree Build on Metax shell: bash @@ -21,9 +52,9 @@ jobs: source ~/env.sh export FLAGTREE_BACKEND=metax cd python - MAX_JOBS=20 pip3 install . --no-build-isolation + MAX_JOBS=32 pip3.10 install . 
--no-build-isolation - name: FlagTree Test on Metax shell: bash run: | - pytest -s python/test/unit + python3.10 -m pytest -s third_party/metax/python/test/unit diff --git a/third_party/metax/python/test/unit/conftest.py b/third_party/metax/python/test/unit/conftest.py new file mode 100644 index 000000000..7a02d322b --- /dev/null +++ b/third_party/metax/python/test/unit/conftest.py @@ -0,0 +1,12 @@ +# content of conftest.py + +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='cuda') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/third_party/metax/python/test/unit/language/conftest.py b/third_party/metax/python/test/unit/language/conftest.py new file mode 100644 index 000000000..091f9ea41 --- /dev/null +++ b/third_party/metax/python/test/unit/language/conftest.py @@ -0,0 +1,5 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") diff --git a/third_party/metax/python/test/unit/language/test_annotations.py b/third_party/metax/python/test/unit/language/test_annotations.py new file mode 100644 index 000000000..0c1f065a1 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_annotations.py @@ -0,0 +1,49 @@ +from __future__ import annotations +import torch +import triton +import triton.language as tl +import pytest + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass diff --git a/third_party/metax/python/test/unit/language/test_block_pointer.py b/third_party/metax/python/test/unit/language/test_block_pointer.py new file mode 100644 index 000000000..c932131c9 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_block_pointer.py @@ -0,0 +1,100 @@ +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): + pid = tl.program_id(0) + # We only copy half of the data to see if the padding works + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + 
tl.store(b_block_ptr, a, boundary_check=(0, )) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtypes_str, n, padding_option", [ # + (dtypes_str, n, padding) + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("float16", "float16"), ("int16", "float16")) + for n in (64, 128, 256, 512, 1024) + for padding in ("zero", "nan") # +]) +def test_block_copy(dtypes_str, n, padding_option, device): + src_dtype_str = dtypes_str[0] + dst_dtype_str = dtypes_str[0] + src_dtype = getattr(torch, src_dtype_str) + dst_dtype = getattr(torch, dst_dtype_str) + if src_dtype_str in ("bool", "int16"): + if padding_option == "nan": + pytest.skip("Padding with NaN is not supported for integer types") + a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) + else: + a = torch.randn((n, ), device=device, dtype=src_dtype) + b = torch.zeros((n, ), device=device, dtype=dst_dtype) + + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) + a.to(dst_dtype) + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + else: + assert torch.all(torch.isnan(b[n // 2:n])) + + +@triton.jit +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # +): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + # Below two lines are just for testing negative offsets for the `advance` API, which could be removed + a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) + a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) + a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero") + b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero") + + c = tl.dot(a, b) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ + [64, 64, 16], + [64, 64, 32], + [64, 64, 64], + ] for num_warps in [4, 8] +]) +def test_block_ptr_matmul_no_scf(shape, num_warps, device): + m, n, k = shape + a = torch.randn((m, k), device=device, dtype=torch.float16) + b = torch.randn((k, n), device=device, dtype=torch.float16) + c = torch.empty((m, n), device=device, dtype=torch.float32) + + grid = lambda META: (1, ) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/third_party/metax/python/test/unit/language/test_compile_errors.py b/third_party/metax/python/test/unit/language/test_compile_errors.py new file mode 100644 index 000000000..0531f8ebc --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_compile_errors.py @@ -0,0 +1,304 @@ +import pytest + 
+import triton +import triton.language as tl +from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure +import traceback + + +def test_err_undefined_variable(): + + @triton.jit + def kernel(): + a += 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "is not defined" in str(e.value), "error should mention the undefined variable" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_operator(): + + @triton.jit + def kernel(): + 0 + "a" + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 0" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_static_assert(): + + @triton.jit + def kernel(): + tl.static_assert(isinstance(0, tl.tensor)) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert isinstance(e.value, CompileTimeAssertionFailure) + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the static_assert call" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_unary_op(): + # Currently Triton can't evaluate `not` of a tuple at compile time. That's + # ok, but the error message needs to point to the correct spot. + @triton.jit + def kernel(): + not (0, 0) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the `not`" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_op(): + + @triton.jit + def kernel(): + 1.0 << 1 + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 1.0" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# This has to be defined as a top-level function; jit'ed functions can't call +# nested functions. +@triton.jit +def nested_call(): + xyz # noqa + + +def test_err_in_nested_call(): + + @triton.jit + def kernel(): + # this is a comment to push nested_call() onto the next line + nested_call() + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + inner = e.value.__cause__ + outer = e.value + assert "at 2:4:" in str(inner), "error should point to xyz" + assert "" not in str(inner) + + assert "at 3:4" in str(outer), "error should point to the nested_call" + assert "" not in str(outer) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_builtin(): + + # The root error here comes from core.py. Make sure the stacktrace reflects + # this. 
+    @triton.jit
+    def kernel():
+        tl.expand_dims(None, -1)
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+
+    try:
+        inner = e.value.__cause__
+        outer = e.value
+        assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py"
+
+        assert "at 2:4:" in str(outer), "error should point to expand_dims call"
+        assert "<source unavailable>" not in str(outer)
+    except AssertionError as assertion_err:
+        raise assertion_err from e.value
+
+
+@triton.jit
+def two_returns():
+    return tl.arange(0, 4)
+    return tl.arange(0, 8)
+
+
+def test_two_returns_no_err():
+    # This program is valid; `a` has shape (4,).
+    @triton.jit
+    def kernel():
+        a = two_returns()
+        a + tl.arange(0, 4)  # only works if we took the first return
+
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+
+
+@triton.jit
+def returns_branched_on_constexpr(N: tl.constexpr):
+    if N == 0:
+        return tl.arange(0, 4)
+    # Ideally this would work even without the `else`, but we're not that smart
+    # yet.
+    else:
+        return tl.arange(0, 8)
+
+
+def test_returns_branched_on_constexpr():
+
+    @triton.jit
+    def kernel1(N: tl.constexpr):
+        a = returns_branched_on_constexpr(N)
+        a + tl.arange(0, 4)
+
+    triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
+
+    @triton.jit
+    def kernel2(N: tl.constexpr):
+        a = returns_branched_on_constexpr(N)
+        a + tl.arange(0, 8)
+
+    triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
+
+
+@triton.jit
+def returns_branched_on_non_constexpr(N: int):
+    if N == 0:
+        return tl.arange(0, 4)
+    else:
+        return tl.arange(0, 8)
+
+
+def test_returns_branched_on_non_constexpr():
+
+    @triton.jit
+    def kernel(N: int):
+        returns_branched_on_non_constexpr(N)
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
+
+    try:
+        assert "at 2:4:" in str(e.value), "error should point to the function call"
+        assert "at 5:8:" in str(e.value.__cause__), "error should point to the second `return`"
+    except AssertionError as assertion_err:
+        raise assertion_err from e.value
+
+
+def test_power_of_two_shapes():
+
+    @triton.jit
+    def kernel():
+        tl.arange(2, 7)
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    assert str(e.value.__cause__) == "arange's range must be a power of 2"
+
+
+def test_power_of_two_shapes_2():
+
+    @triton.jit
+    def kernel():
+        tl.full((33, ), 0, dtype=tl.int64)
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"
+
+
+def test_captured_var_access():
+
+    CAPTURED = 42
+
+    @triton.jit
+    def kernel():
+        a = CAPTURED  # noqa
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    assert "CAPTURED is not defined" in str(e.value)
+
+
+GLOBAL = 42
+
+
+def test_global_var_access():
+
+    @triton.jit
+    def kernel():
+        a = GLOBAL  # noqa
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    assert "global variable" in str(e.value)
+
+
+CONSTEXPR_ANNOTATED_GLOBAL: tl.constexpr = 42
+
+
+def test_constexpr_annotated_global_var_access():
+
+    @triton.jit
+    def kernel():
+        a =
CONSTEXPR_ANNOTATED_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_constexpr_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +TYPE_ALIAS = tl.pointer_type(tl.int32) + + +def test_global_type_alias_access(): + + @triton.jit + def kernel(): + a = TYPE_ALIAS # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +def test_global_access_in_fn_default_arg(): + + @triton.jit + def kernel(a=GLOBAL): + pass + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={})) diff --git a/third_party/metax/python/test/unit/language/test_conversions.py b/third_party/metax/python/test/unit/language/test_conversions.py new file mode 100644 index 000000000..35e807860 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_conversions.py @@ -0,0 +1,356 @@ +# fmt: off + + +import os +import numpy as np +import torch +import pytest +import triton +import triton.language as tl + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + +def is_cuda(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" + +def is_hip(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" + +def is_on_mi300(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') + +def matching_int(dtype): + if dtype.primitive_bitwidth == 8: + return torch.int8 + elif dtype.primitive_bitwidth == 16: + return torch.int16 + elif dtype.primitive_bitwidth == 32: + return torch.int32 + elif dtype.primitive_bitwidth == 64: + return torch.int64 + else: + raise ValueError('unsupported number of bits') + +@triton.jit +def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding) + tl.store(dst + idxs, y) + + +def launch_type_convert_triton(src, src_dtype, dst_dtype, device, rounding=None, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE) + return dst + + +@triton.jit +def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + vals = (idxs + offset).to(tl.uint32) + + # pseudorandom permutation: + multiplier = vals << 1 + multiplier += 3511 + vals *= multiplier + + if force_odd: + vals *= 2 + vals += 1 + + if (output_bits == 8): + vals &= 0xff + avals = vals & 0x7f + elif (output_bits == 16): + vals &= 0xffff + avals = vals & 0x7fff + elif (output_bits == 32): + avals = vals & 0x7fffffff + + vals = tl.where(avals <= max_repr, vals, 0) + + if (output_bits == 8): + vals = vals.to(tl.uint8) + elif (output_bits == 16): + vals = vals.to(tl.uint16) + + vals = vals.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, vals) + + +def launch_exhaustive_populate(dst_dtype, offset, numel, 
force_odd, output_bits, max_repr, device, BLOCK_SIZE=4096): + + assert(numel % BLOCK_SIZE == 0) + dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device=device) + exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. We don't need to have that + # as input to the conversion kernels. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(x.dtype == tl.float32, "input must be float32") + numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16") + + x = x.to(tl.uint32, bitcast=True) + + mantissa = (x & 0x7fffff) + exponent = ((x >> 23) & 0xff).to(tl.int32) + mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32) + exponent = tl.where(exponent == 0, exponent, exponent - 1) + + sign = (x >> 31) + + exponent = exponent + exponent_bias - 127 + adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits) + mantissa = mantissa.to(tl.float32) * adjustment + + # make exponent nonnegative: + mantissa = tl.where(exponent > -16, mantissa, 0.0) # destination has fewer than 16 mantissa bits, so safe + exponent = tl.where(exponent > -16, exponent, 0) + mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625) + exponent = tl.where(exponent > -8, exponent, exponent + 8) + mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625) + exponent = tl.where(exponent > -4, exponent, exponent + 4) + mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25) + exponent = tl.where(exponent > -2, exponent, exponent + 2) + mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5) + exponent = tl.where(exponent > -1, exponent, exponent + 1) + + if rounding == 'rtne': + # Bring the value to the range [2 ** 23, 2 ** 24] + # where the representable floats map exactly to integers. + # Addition has RTNE semantics. + mantissa += 0x800000 + # Bring the value back to the original range. 
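+            # For example, a mantissa of 2.5 becomes 8388610.5 after the
+            # addition; float32 values of that magnitude are spaced 1.0 apart,
+            # so the sum rounds to the nearest even integer, 8388610, and the
+            # subtraction below recovers 2.0, i.e. 2.5 is rounded to 2 with
+            # ties going to even.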
+ mantissa -= 0x800000 + mantissa = mantissa.to(tl.int32) + elif rounding == 'rtz': + mantissa = mantissa.to(tl.int32) + else: + raise ValueError('unrecognized rounding mode') + + # Reassemble output floating-point representation: + exponent = exponent.to(tl.uint32) + y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa + if numbits_dst == 8: + y = y.to(tl.uint8) + elif numbits_dst == 16: + y = y.to(tl.uint16) + return y + + +@triton.jit +def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias) + y = y.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, y) + + +def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + downcast_emulated[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will + # convert -0. in higher precision to 0x80 and thus need to fix the result to 0. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias) + + numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16") + tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + + if numbits_src == 8: + x = x.to(tl.uint8, bitcast=True) + elif numbits_src == 16: + x = x.to(tl.uint16, bitcast=True) + + x = x.to(tl.uint32) + + mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1 + exponent_mask : tl.constexpr = (1 << exponent_bits) - 1 + + mantissa = x & mantissa_mask + exponent = (x >> mantissa_bits) & exponent_mask + sign = (x >> (numbits_src - 1)) + + y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits)) + y = y.to(tl.float32, bitcast=True) + y = y * exponent_compensator + + tl.store(dst + idxs, y) + + +def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=torch.int32, device=device) + upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + return dst + + +def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset, device): + + src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device) + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding) + src = launch_type_convert_triton(src, 
src_dtype, tl.float32, device=device) + + dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device) + + dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device) + dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device) + + if not (torch.equal(dst, dst2)): + print('Error!!!') + + dst = dst.cpu().detach().numpy() + dst2 = dst2.cpu().detach().numpy() + src = src.cpu().detach().numpy() + + print(src[dst != dst2][0]) + print(dst[dst != dst2][0]) + print(dst2[dst != dst2][0]) + print(hex(src.view(np.uint32)[dst != dst2][0])) + print(hex(dst.view(np.uint32)[dst != dst2][0])) + print(hex(dst2.view(np.uint32)[dst != dst2][0])) + print('') + raise ValueError('%d elements mismatch' % (dst != dst2).sum()) + + +def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr, device): + + numbits_src = exponent_bits + mantissa_bits + 1 + + src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) + + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) + dst = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + + dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + + assert(torch.equal(dst, dst2)) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("src_dtype, dst_dtype", [ + ('float16', 'float32'), + ('bfloat16', 'float32'), + + ('float8e5', 'float16'), + ('float8e5', 'bfloat16'), + ('float8e5', 'float32'), + + ('float8e4b15', 'float16'), + # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16 + ('float8e4b15', 'float32'), + + ('float8e4nv', 'float16'), + ('float8e4nv', 'bfloat16'), + ('float8e4nv', 'float32'), + + ('float8e4b8', 'float32'), + ('float8e4b8', 'float16'), + + ('float8e5b16', 'float32'), + ('float8e5b16', 'float16'), +]) +def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + if src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("float8e4nv upcast tests only supported on NVGPU with compute capability 9.0+") + + if src_dtype in ('float8e4nv', 'float8e4b15') and is_hip(): + pytest.skip(f"{src_dtype} upcast tests not supported on ROCm") + + if src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()): + pytest.skip("{src_dtype} upcast tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) + stuff = { + 'float8e4b15': (4, 3, 15, 0x7e), + 'float8e4nv': (4, 3, 7, 0x7e), + 'float8e5': (5, 2, 15, 0x7b), + 'float8e4b8': (4, 3, 8, 0x7f), + 'float8e5b16': (5, 2, 16, 0x7f), + 'float16': (5, 10, 15, 0x7bff), + 'bfloat16': (8, 7, 127, 0x7f7f), + }[src_dtype] + + upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff, device=device) + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [ + ('float32', 'float16', 'rtne', 0x477fe000), + ('float32', 'float16', 'rtz', 0x477fe000), + ('float32', 'bfloat16', 'rtne', 0x7f7f0000), + ('float32', 'bfloat16', 'rtz', 0x7f7f0000), + ('float32', 'float8e5', 'rtne', 0x47600000), + ('float32', 'float8e5', 'rtz', 0x47600000), + ('float32', 'float8e4nv', 'rtne', 0x43e00000), + ('float32', 'float8e4b8', 'rtne', 0x43700000), + ('float32', 'float8e5b16', 'rtne', 0x47600000), + # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), 
# Skip, no HW rtne conversion from f32 to f8e4b15 + + ('bfloat16', 'float8e5', 'rtne', 0x4760), + ('bfloat16', 'float8e4nv', 'rtne', 0x43e0), + + ('float16', 'float8e5', 'rtne', 0x7b00), + ('float16', 'float8e4nv', 'rtne', 0x5f00), + + ('bfloat16', 'float8e5b16', 'rtne', 0x4760), + ('bfloat16', 'float8e4b8', 'rtne', 0x4370), + + ('float16', 'float8e5b16', 'rtne', 0x7b00), + ('float16', 'float8e4b8', 'rtne', 0x5b80), +]) +def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + + if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias) + stuff = { + 'float16': (5, 10, 15), + 'bfloat16': (8, 7, 127), + 'float8e5': (5, 2, 15), + 'float8e4b15': (4, 3, 15), + 'float8e4nv': (4, 3, 7), + 'float8e4b8': (4, 3, 8), + 'float8e5b16': (5, 2, 16), + }[dst_dtype] + + for i in range(256): + downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device) diff --git a/third_party/metax/python/test/unit/language/test_core.py b/third_party/metax/python/test/unit/language/test_core.py new file mode 100644 index 000000000..8c76ca3d0 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_core.py @@ -0,0 +1,5417 @@ +# flake8: noqa: F821,F841 +import itertools +import re +from typing import Optional, Union +import math +import textwrap +import tempfile + +import numpy as np +import pytest +import torch +import os +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.runtime.jit import TensorWrapper, reinterpret + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cuda(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "hip" + + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] +float_dtypes = ['float16', 'float32', 'float64'] +dtypes = int_dtypes + uint_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] +torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] + +# TODO: enable multiple cta cluster testing. 
+# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +GPU_DIALECT = "triton_gpu" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + if dtype_str in int_dtypes + uint_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) + x = rs.randint(low, high, shape, dtype=dtype) + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. + return x + elif dtype_str and 'float8' in dtype_str: + x = rs.randint(20, 40, shape, dtype=np.int8) + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type because the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + if dst_type and 'float8' in dst_type: + return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() + return torch.tensor(x, device=device) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') + + +def to_numpy(x): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + + +def patch_kernel(template, to_replace): + if is_interpreter(): + local_namespace = {} + src = textwrap.dedent(inspect.getsource(template.fn)) + for k, v in to_replace.items(): + src = src.replace(k, v) + exec(src, globals(), local_namespace) + return local_namespace[template.fn.__name__] + else: + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +def check_cuda_or_hip(device): + # CUDA and HIP both use pytorch device 'cuda'. 
Other backends like Intel + # GPU do not. + if device not in ['cuda']: + pytest.skip("Only for cuda") + + +def check_type_supported(dtype, device): + ''' + skip test if dtype is not supported on the current device + ''' + if device in ['cuda']: + cc = torch.cuda.get_device_capability() + if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") + if is_interpreter(): + if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: + pytest.skip("bfloat16 is not supported in the interpreter") + + +class MfmaLayout: + + def __init__(self, version, warps_per_cta, instr_shape, is_transposed): + self.version = version + self.warps_per_cta = warps_per_cta + self.instr_shape = instr_shape + self.is_transposed = is_transposed + + def __str__(self): + return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>" + + +class WmmaLayout: + + def __init__(self, warps_per_cta): + self.warps_per_cta = warps_per_cta + + def __str__(self): + return f"#{GPU_DIALECT}.amd_wmma<{{warpsPerCTA = {self.warps_per_cta}}}>" + + +class MmaLayout: + + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): + self.version = version + self.warps_per_cta = warps_per_cta + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + self.instr_shape = instr_shape + + def __str__(self): + return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class SharedLayout: + + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + self.vec = vec + self.per_phase = per_phase + self.max_phase = max_phase + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +def is_layout_applicable(layout) -> bool: + common_layouts = [BlockedLayout, SharedLayout] + if layout in common_layouts: + return True + elif is_cuda(): + return isinstance(layout, MmaLayout) + elif is_hip(): + target_arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in 
target_arch: + # RDNA 3 + return isinstance(layout, WmmaLayout) + elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch): + # CDNA 1, 2, 3 + return isinstance(layout, MfmaLayout) + else: + return False + else: + return True + + +def filter_layouts(layouts): + return [l for l in layouts if is_layout_applicable(l)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) +def test_empty_kernel(dtype_x, device): + SIZE = 128 + + @triton.jit + def kernel(X, SIZE: tl.constexpr): + pass + + check_type_supported(dtype_x, device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) + kernel[(1, )](x, SIZE=SIZE, num_warps=4) + + +# generic test functions +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x) + kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + # compare + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + + +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. 
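+
+    For example, the table below records that int8 + uint8 stays uint8 under
+    the C-style promotion Triton uses, whereas numpy promotes the pair to
+    int16; likewise, float16 mixed with an integer type stays float16 here but
+    widens in numpy.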
+ """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + y_low=None, y_high=None, test_broadcast=True): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + check_type_supported(dtype_y, device) + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + replacements = {'GENERATE_TEST_HERE': expr} + kernel = patch_kernel(kernel, replacements) + kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) + kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') + + def do_test(x, y, kernel_fn): + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if dtype_z is not None: + z_ref = z_ref.astype(dtype_z) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + err_msg = f"{expr}, {kernel_fn.__name__}" + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=1e-3, rtol=0.01) + + do_test(x, y, kernel) + if test_broadcast: + do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) + do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) + + +def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: + # The result of x % y is ill-conditioned if x % y is much smaller than x. + # pytorch/CUDA has slightly different (probably better) rounding on + # remainders than stock LLVM. We currently don't expect to match it + # bit-for-bit. 
+ return (dtype_x, dtype_y) in [ + ('int32', 'bfloat16'), + ('int32', 'float16'), + ('int32', 'float32'), + ('int64', 'bfloat16'), + ('int64', 'float16'), + ('int64', 'float32'), + ('int64', 'float64'), + ('uint16', 'bfloat16'), + ('uint16', 'float16'), + ('uint16', 'float32'), + ('uint32', 'bfloat16'), + ('uint32', 'float16'), + ('uint32', 'float32'), + ('uint64', 'bfloat16'), + ('uint64', 'float16'), + ('uint64', 'float32'), + ('uint64', 'float64'), + ] + + +def test_dtype_codegen(): + for dtype in dtypes_with_bfloat16: + full_name = f"triton.language.{dtype}" + assert repr(eval(full_name)) == full_name + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f' x {op} y' + if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = 'np.fmod(x, y)' + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', + 'bfloat16'): + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): + with pytest.raises(AssertionError, match="Not equal to tolerance"): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + else: + _test_binary( + dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, + # fails with values where fmod(x, y) is roughly zero, but happens to + # pass with the random values chosen for non-broadcast tests + test_broadcast=(op != "%")) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) + 
np.testing.assert_allclose(y, to_numpy(y_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device): + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() + + +# test bitwise ops +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + # The CompilationError must have been caused by a C++ exception with this text. 
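+        # (Bitwise operators are defined only for integer operands, so any
+        # operand with a floating-point dtype is expected to be rejected at
+        # compile time.)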
+ with pytest.raises(triton.TritonError, match='invalid operands of type'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) + else: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['<<', '>>'] + for dtype_x in int_dtypes + uint_dtypes + for dtype_y in int_dtypes + uint_dtypes +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + if dtype_x.startswith('int'): + dtype_z = f'int{bw}' + else: + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw) + + +# --------------- +# test compare ops +# --------------- +ops = ['==', '!=', '>', '<', '>=', '<='] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype, device): + check_type_supported(dtype, device) + + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + +# ---------- +# test slice +# ---------- + + +@pytest.mark.interpreter +def test_slice(device): + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 
1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1, )](XBLOCK=32) + + +# ------------------ +# test invalid slice +# ------------------ + + +@pytest.mark.interpreter +def test_invalid_slice(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + dst[10:] + + with pytest.raises(triton.TritonError, match='unsupported tensor index'): + _kernel[(1, )](dst=dst) + + +# ---------------- +# test expand_dims +# ---------------- +@pytest.mark.interpreter +def test_expand_dims(device): + + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + + # N is a scalar that's not even a tl.tensor -- this should work too. + t = tl.expand_dims(N, -1) + tl.static_assert(t.shape == [1]) + + N = 32 + dummy_tensor = torch.empty((), device=device) + expand_dims_kernel[(1, )](dummy_tensor, N) + + +@pytest.mark.interpreter +def test_expand_dims_error_cases(device): + + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device=device) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range1[(1, )](dummy_tensor, N) + assert "invalid axis -3" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range2[(1, )](dummy_tensor, N) + assert "invalid axis 2" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range3[(1, )](dummy_tensor, N) + assert "invalid axis 1" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim1[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim2[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = 
\[0, 0\]", str(exc_info.value.__cause__)) + + +# ---------------------------- +# test invalid program id axis +# ---------------------------- +@pytest.mark.interpreter +def test_invalid_pid_axis(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pid = tl.program_id(20) + + with pytest.raises(triton.TritonError) as exc_info: + _kernel[(1, )](dst) + assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", str(exc_info.value.__cause__)) + + +# --------------- +# test where +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype, device) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_SCALAR_POINTERS: + ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr) + output = tl.load(ptr + offsets, mask=mask) + else: + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + assert (z == to_numpy(z_tri)).all() + if select_ptrs: + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) + z = np.where(cond[0], x, y) + assert (z == to_numpy(z_tri)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): + + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = 0 + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) 
+ tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() + + +# --------------- +# test unary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + + +# ---------------- +# test math ops +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr, x", + [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for expr in ['exp', 'log', 'cos', 'sin', 'exp2', 'log2', 'sqrt', 'floor', 'ceil'] + for x in ['x', '3.0']]) +def test_math_op(dtype_x, expr, x, device): + _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_erf_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.math.erf(x) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = torch.erf(x) + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_fma_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, Y, W, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + w = tl.load(W + off) + z = tl.math.fma(x, y, w) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + y = torch.randn(SIZE, dtype=torch_dtype, device=device) + w = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = x * y + w + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, y, w, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_divide_op(expr, num_ctas, device): + numpy_expr = "x / y" + dtype = "float32" + _test_binary(dtype, dtype, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +# ------------- +# test precise math +# ------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("expr_prec, expr_ref", + [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), + ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / 
y.to(tl.float64)).to(tl.float32)')]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_precise_math(expr_prec, expr_ref, num_ctas, device): + + @triton.jit + def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + prec = PREC_CALC + ref = REF_CALC + tl.store(OUT + tl.arange(0, BLOCK), prec) + tl.store(OUT_REF + tl.arange(0, BLOCK), ref) + + shape = (128, ) + out = torch.zeros(shape, dtype=torch.float32, device=device) + out_ref = torch.zeros(shape, dtype=torch.float32, device=device) + + x = torch.randn(shape, dtype=torch.float32, device=device) + y = torch.randn(shape, dtype=torch.float32, device=device) + + if (expr_prec.count('sqrt') > 0): + x = torch.abs(x) + + if (expr_prec.count('div') > 0): + y += 1e-6 + + kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref}) + + kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas) + assert torch.all(out == out_ref) # bitwise exact + + +# ---------------- +# test abs +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_abs(dtype_x, device): + _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) +def test_abs_fp8(in_dtype, device): + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') + + @triton.jit + def abs_kernel(X, Z, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.abs(x) + tl.store(Z + off, z) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device=device) + # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan + all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width + f8_tensor[all_exp_ones] = 0 + f8 = triton.reinterpret(f8_tensor, in_dtype) + n_elements = f8_tensor.numel() + out_f8 = torch.empty_like(f8_tensor) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + + f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) + expect = f32_tensor.abs() + actual_f8 = convert_float_to_float32(out_f8, in_dtype) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) + + +# ---------------- +# test passing shapes as individual params rather than tuples +# ---------------- + + +@pytest.mark.interpreter +def test_shapes_as_params(device): + + @triton.jit + def kernel(): + a = tl.arange(0, 32).expand_dims(-1).broadcast_to(32, 32) + tl.static_assert(a.shape == [tl.constexpr(32), tl.constexpr(32)]) + + a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).reshape(32) + tl.static_assert(a.shape == [tl.constexpr(32)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).view(2, 4, 8) + tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) + + kernel[(1, )]() + + +# ---------------- +# test transpose +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_transpose(dtype_x, device): + check_type_supported(dtype_x, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + off2d = off[None, :] + 
(tl.arange(0, 2) * SIZE)[:, None] + x = tl.load(X + off2d) + z = x.T + tl.store(Z + off2d.T, z) + + x = numpy_random([SIZE, 2], dtype_str=dtype_x) + z_ref = x.T + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE) + np.testing.assert_allclose(z_ref, to_numpy(z_tri)) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +# TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() + + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + + +# --------------- +# test tuples +# --------------- + + +@triton.jit +def tuples_fn(a, b): + return a + b, \ + a - b, \ + a * b + + +@pytest.mark.interpreter +def test_tuples(device): + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = tuples_fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = 
torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode, device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1, )](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + +# --------------- +# test atomics +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ + ('add', 'float16', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), + ] + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) +def test_atomic_rmw(op, dtype_x_str, mode, sem, device): + if is_interpreter(): + if dtype_x_str == 'float16': + pytest.skip("Only test atomic float16 ops on GPU") + + n_programs = 5 + + # triton kernel + @triton.jit + def kernel(X, Z): + 
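# each program loads one element of X and atomically combines it into the single-element Z (GENERATE_TEST_HERE is patched below to tl.atomic_add/max/min) + 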
pid = tl.program_id(0) + x = tl.load(X + pid) + old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) + + sem_arg = sem if sem is None else f'"{sem}"' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max + neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] + + # triton result + rs = RandomState(17) + x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str)) + if mode == 'all_neg': + x = -np.abs(x) + if mode == 'all_pos': + x = np.abs(x) + if mode == 'min_neg': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 + if mode == 'max_pos': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = to_triton(x, device=device) + + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) + h = kernel[(n_programs, )](x_tri, z_tri) + # torch result + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # compare + exact = op not in ['add'] + if exact: + assert z_ref.item() == to_numpy(z_tri).item() + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + sem_str = "acq_rel" if sem is None else sem + if not is_cuda(): + return + + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): + + @triton.jit + def kernel(X): + val = tl.program_id(0) + if val < 64: + tl.atomic_max(X, val) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) + assert x.item() == 63 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, device): + shape0, shape1 = shape + # triton kernel + + @triton.jit + def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + z = tl.sum(x, axis=AXIS) + if AXIS == 1: + tl.atomic_add(Z + off0, z) + else: + tl.atomic_add(Z + off1, z) + + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) + # reference result + z_ref = np.sum(x, axis=axis, keepdims=False) + # triton result + x_tri = to_triton(x, device=device) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) + kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): + shape = (8, 8) + + @triton.jit + def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + offs = off0[:, None] * SHAPE1 + off1[None, :] + val = offs.to(tl.float32) + x = X + offs + tl.atomic_min(x, val) + + x = torch.ones((8, 8), device=device, dtype=torch.float32) + kernel[(2, )](x, shape[0], 
shape[1], num_ctas=num_ctas) + assert torch.min(x).item() == 0.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_cas(sem, num_ctas, device): + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock): + tl.atomic_cas(Lock, 0, 1) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) + + assert (Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock, SEM: tl.constexpr): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1, SEM) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 2000.0) + h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas) + sem_str = "acq_rel" if sem is None else sem + np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if not is_cuda(): + return + assert f"atom.global.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_cas(sem, num_ctas, device): + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64) + tl.atomic_cas(X + offsets, t1, t2, sem=sem) + + X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64) + Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64) + + change_value[(2, )](X, 4, sem) + assert (torch.equal(X, Y)) + + +# --------------- +# test cast +# --------------- + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] # + + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): + # CUDA: bfloat16 on cc < 80 will not be tested + # Interpreter: Only bfloat16 <-> float32 is supported + if not is_interpreter() or \ + (is_interpreter() and not ((dtype_z == 'bfloat16' and dtype_x == 'float32') + or (dtype_z == 'float32' and dtype_x == 'bfloat16'))): + check_type_supported(dtype_x, device) + check_type_supported(dtype_z, device) + + if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): + 
pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + + torch.manual_seed(0) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + if dtype_x.startswith('bfloat'): + x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) + else: + x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 + # Triton clamps negative values to zero, while numpy wraps around + # intmax, so avoid negatives for now. + # TODO: figure out which one should actually be happening, and test it + if dtype_z in uint_dtypes: + x = np.absolute(x) + x_tri = to_triton(x, device=device) + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) + # triton kernel + + @triton.jit + def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + + # Depending on the value of ARG_HASH (a "random" number determined by + # the test parameters), spell the cast one of three different ways. + if ARG_HASH % 3 == 0: + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 3 == 1: + z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + + tl.store(z_ptr, z) + + # "Random" number used inside the kernel to determine how we spell the cast. + # This way we don't have to increase the number of tests. + arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) + + dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + else: + z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + if dtype_z.startswith('float8') and device not in ['cuda']: + t = z_ref.byte() ^ z_tri.byte() + torch.testing.assert_close(torch.zeros_like(t, dtype=torch.uint8), t) + else: + torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +def test_cat(dtype_str, num_warps, device): + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) + y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) + z_ref = torch.cat([x, y], 
dim=0).sum() + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](x, y, z, N=128, num_warps=num_warps) + assert z.sum() == z_ref + # check if there's no duplicate value in z + assert z.unique().size(0) == z.size(0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(dtype_str, num_ctas, device): + check_type_supported(dtype_str, device) + """Tests that boolean True is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + output = GENERATE_TEST_HERE + tl.store(output_ptr + offsets, output, mask=mask) + + triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) + + +@pytest.mark.skip(reason="metax todo") +def test_load_store_same_ptr(device): + + @triton.jit() + def kernel(in_out_ptr): + pid = tl.program_id(axis=0) + x = tl.load(in_out_ptr + pid) + out = x * 2 + tl.store(in_out_ptr + pid, out) + + for _ in range(1000): + x = torch.ones((65536, ), device=device, dtype=torch.float32) + if is_hip(): + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536, )](x, num_warps=32) + assert torch.all(x == 2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['int32']) +def test_umulhi(dtype_str, device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + def umulhi32(a, b): + # Convert to 64-bit unsigned integers to prevent overflow + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + + # Perform the multiplication in 64-bit + product_64 = a_64 * b_64 + + # Shift right by 32 bits to get the high part of the product + result_high_32 = product_64 >> 32 + return result_high_32 + + rs = RandomState(17) + N = 128 + x = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + x_tri = to_triton(x, device=device) + y = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + y_tri = to_triton(y, device=device) + z_tri = torch.zeros_like(x_tri) + kernel[(1, )](x_tri, y_tri, z_tri, N=N) + + z_ref = umulhi32(x, y) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +@pytest.mark.interpreter +def test_join(device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.join(x, y) + tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(-128, 0, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, y, z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + z = tl.join(x, y) + tl.static_assert(z.shape == [2]) + 
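# tl.join of two scalars yields a 1-D tensor of shape [2], which is stored to Z below + 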
tl.store(Z + tl.arange(0, 2), z) + + x = torch.full([1], 42, device=device).to(torch.int32) + y = torch.full([1], 100, device=device).to(torch.int32) + z = torch.zeros([2], device=device) + kernel[(1, )](x, y, z) + + np.testing.assert_equal([42, 100], to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_with_mma(device): + + @triton.jit + def kernel(X, Z): + x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16) + x2 = tl.join(x, 2 * x) # (32,16,2) + x3 = tl.reshape(x2, (32, 32)) + z = tl.dot(x3, x3) # (32,32) + tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z) + + x = torch.arange(0, 32 * 16, device=device, dtype=torch.float32).reshape((32, 16)) + r = torch.stack([x, 2 * x], dim=-1).reshape((32, 32)) + z_ref = torch.matmul(r, r) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, z) + + torch.testing.assert_close(z, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("debug", [False, True]) +def test_interleave(device, debug): + + @triton.jit(debug=debug) + def kernel(Z, N: tl.constexpr): + z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N)) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(128, 256, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1).reshape(256) + z = torch.zeros_like(z_ref) + kernel[(1, )](z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_interleave_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + z = tl.interleave(X, Y) + tl.static_assert(z.shape == [tl.constexpr(2)]) + tl.store(Z + tl.arange(0, 2), z) + + z = torch.zeros(2, device=device) + kernel[(1, )](10, 20, z) + + np.testing.assert_equal([10, 20], to_numpy(z)) + + +@pytest.mark.interpreter +def test_split(device): + + @triton.jit + def kernel(X, Z1, Z2, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + x1 = tl.reshape(x, (N // 2, 2)) + z1, z2 = tl.split(x1) + tl.store(Z1 + tl.arange(0, N // 2), z1) + tl.store(Z2 + tl.arange(0, N // 2), z2) + + x = torch.arange(0, 256, device=device).to(torch.int32).reshape((128, 2)) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2, N=256) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +@pytest.mark.interpreter +def test_split_to_scalar(device): + + @triton.jit + def kernel(X, Z1, Z2): + offs = tl.arange(0, 2) + x = tl.load(X + offs) + z1, z2 = tl.split(x) + tl.static_assert(isinstance(z1, tl.tensor)) + tl.static_assert(isinstance(z2, tl.tensor)) + tl.static_assert(z1.shape == []) + tl.static_assert(z2.shape == []) + tl.store(Z1, z1) + tl.store(Z2, z2) + + N = 2 + x = torch.arange(0, N, device=device).reshape(N // 2, 2) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +def convert_float_to_float32(fp: torch.tensor, dtype=None): + if not dtype: + dtype = getattr(tl, torch_dtype_name(fp.dtype)) + + fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}")) + exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1 + exp_bias = dtype.exponent_bias + sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int() + exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 
1)).int() + frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() + + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() + + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + # special cases, exp is 0b11..1 + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities + output[fp == 0b01111111] = torch.nan + output[fp == 0b11111111] = torch.nan + else: + output = torch.where(exp == (1 << exp_width) - 1, + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp + | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) + return output + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +def test_convert_float16_to_float32(in_dtype, device): + """Tests that check convert_float_to_float32 function""" + check_type_supported(in_dtype, device) + + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) + f32_output = convert_float_to_float32(f16_input) + + nan = f16_input.isnan() + assert torch.all(f32_output[nan].isnan()) + inf = f16_input.isinf() + assert torch.all(f32_output[inf].isinf()) + other = torch.logical_not(torch.logical_or(nan, inf)) + assert torch.all(f16_input[other] == f32_output[other]) + + +def serialize_fp8(np_data, in_dtype): + return np_data + + +# inverse of `serialize_fp8` + + +def deserialize_fp8(np_data, in_dtype): + return np_data + + +# --------------- +# test reduce +# --------------- + + +@pytest.mark.interpreter +def test_max_returns_zero(device): + # Simple test with a tl.max call that returns 0. The interpreter had a bug + # where it didn't handle this correctly. 
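+ # The kernel below reduces an all-zero block with tl.max; the stored result must be exactly 0.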
+ @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + z = tl.max(x) + tl.store(Z, z) + + BLOCK = 128 + x = torch.zeros((BLOCK, ), device=device) + z = torch.ones((1, ), device=device) + + kernel[(1, )](x, z, BLOCK=BLOCK) + assert z[0] == 0 + + +def get_reduced_dtype(dtype_str, op): + if op in ('argmin', 'argmax'): + return 'int32' + if dtype_str == 'bfloat16': + return 'float32' + return dtype_str + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + # triton kernel + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + GENERATE_TEST_HERE + tl.store(Z, z) + + if 'with-indices' in op: + patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)' + elif 'arg' in op: + tie_break_left = 'tie-break-left' in op + patch = f'z = tl.{op.split("-")[0]}(x, axis=0, tie_break_left={tie_break_left})' + else: + patch = f'z = tl.{op}(x, axis=0)' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-fast': np.argmin, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-fast': np.argmax, + 'argmax-tie-break-left': np.argmax, + }[op] + if 'tie-break-left' in op: + x[3:10] = numpy_op(x) + x_tri = to_triton(x, device=device) + # numpy result + z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str + z_tri_dtype_str = z_dtype_str + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # triton result + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_numpy(z_tri) + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. 
+ # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# TODO: [Qingyi] Fix argmin / argmax +reduce_configs1 = [(op, dtype, (1, 1024), axis, False) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] + +# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory +# exceeds the limit of 99KB +reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] +# TODO: fix and uncomment +# , (32, 64), (64, 128)] +if is_cuda() and 'V100' in torch.cuda.get_device_name(0): + reduce2d_shapes += [(128, 256) and (32, 1024)] + +reduce_configs2 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None, False) for op in ['min', 'max', 'sum']] + +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] +invalid_config = [('sum', 'float32', (32, 32), axis, False) for axis in [2, 3]] +negative_config = [('sum', 'float32', (32, 32), -1, False)] +keep_dims_2d_configs = [(op, 'float32', (32, 32), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1]] + [(op, 'float32', (32, 32), None, True) for op in ['min', 'max', 'sum']] +keep_dims_3d_configs = [(op, 'float32', (32, 2, 16), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) + for op in ['min', 'max', 'sum']] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + + negative_config + keep_dims_2d_configs + keep_dims_3d_configs) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) + else: + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + z = GENERATE_TEST_HERE + + z_ptr = Z + if KEEP_DIMS and AXIS is None: + if IS_3D: + z_ptr = z_ptr[None, None, None, :] + else: + z_ptr = z_ptr[None, None, :] + if IS_3D: + if AXIS == 0: + z_ptr = Z + range_n[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 1 or AXIS == -2: + z_ptr = Z + range_m[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 2 or AXIS == -1: + z_ptr = Z + range_m[:, None] * BLOCK_N + range_n[None, :] + else: + if AXIS == 0: + z_ptr = Z + range_n + elif AXIS == 1 or AXIS == -1: + z_ptr = Z + range_m + if KEEP_DIMS and AXIS is not None: + z_ptr = tl.expand_dims(z_ptr, axis=AXIS) + tl.store(z_ptr, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, 
rs=rs) + x_tri = to_triton(x, device=device) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] + z_dtype_str = get_reduced_dtype(dtype_str, op) + z_tri_dtype_str = z_dtype_str + + # numpy result + # Silence numpy error on axis out of bounds, to give triton a chance to fail + np_axis = axis if axis is not None and axis < len(shape) else None + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + + # triton result + z_shape = z_ref.shape + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) + if axis is not None and axis >= len(shape): + with pytest.raises(triton.TritonError): + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + return + else: + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + + z_tri = to_numpy(z_tri) + + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + z_ref_index = z_ref + z_tri_index = z_tri + if not keep_dims: + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] + +scan_configs = [(op, type, shape, axis, reverse, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32', 'bfloat16'] + for axis in [1, 0] + for reverse in [True, False] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element', 'linear_recurrence', 'cummax', 'roll']] +negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)] + + +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + +# Compute x_i = a_i * x_{i-1} + b_i +@triton.jit +def linear_recurrence(a1, b1, a2, b2): + return a1 * a2, b1 * a2 + b2 + + +@triton.jit +def cummax(v0, i0, v1, i1): + gt = v0 > v1 + return tl.where(gt, v0, v1), tl.where(gt, i0, i1) + + +@triton.jit +def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): + return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) +def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): + check_type_supported(dtype_str, device) + if dtype_str == 'bfloat16': + if op == 'cummax': + pytest.skip("bfloat16 compare not suppoted before sm90") + if op == 'linear_recurrence': + pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy 
issues") + numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + + # triton kernel + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + y = tl.load(Y + range_m[:, None] * BLOCK_N + range_n[None, :]) + GENERATE_TEST_HERE + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'z = tl.{op}(x, axis={axis}, reverse={reverse})'}) + elif op == 'get_first_element': + kernel = patch_kernel( + kernel, + {'GENERATE_TEST_HERE': f'z = tl.associative_scan(x, axis={axis}, combine_fn={op}, reverse={reverse})'}) + elif op == 'cummax': + rg = "range_m[:, None]" if axis == 0 else "range_n[None, :]" + rg = f"tl.broadcast_to({rg}.to(tl.int64), [BLOCK_M, BLOCK_N])" + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, {rg}), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + elif op == 'roll': + assert op == 'roll' + kernel = patch_kernel( + kernel, { + 'GENERATE_TEST_HERE': + f'_, z, _ = tl.associative_scan((1 + 0* x, 0 * x, x), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + else: + assert op == 'linear_recurrence' + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, y), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + # input + rs = RandomState(17) + if op == 'linear_recurrence' and dtype_str in int_dtypes: + # If the numbers are too large the op will overflow + # We sample numbers in -1, 0, 1 + x = rs.randint(-1, 2, shape, dtype=dtype_str) + y = rs.randint(-1, 2, shape, dtype=dtype_str) + else: + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + # y is just used in linear_recurrence + y = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_in = x + if reverse: + x_in = np.flip(x, axis) + z = np.empty_like(x) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + y_tri = to_triton(y, device=device, dst_type=dtype_str) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_ref = numpy_op(x_in, axis=axis).astype(getattr(np, numpy_dtype_str)) + if reverse: + z_ref = np.flip(z_ref, axis) + + elif op == 'cummax': + # NumPy does not have cummax + z = z.astype(np.int64) + z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy() + if reverse: + z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1 + elif op == 'roll': + ROLL = 1 + z_ref = np.roll(x_in.copy(), ROLL, axis=axis) + if axis == 0: + z_ref[:ROLL] = 0 + else: + z_ref[:, :ROLL] = 0 + + if reverse: + z_ref = np.flip(z_ref, axis) + elif op == 'linear_recurrence': + # Simplify to the axis=1 case + x_ref = x.T if axis == 0 else x + y_ref = y.T if axis == 0 else y + if reverse: + x_ref = np.flip(x_ref, 1) + y_ref = np.flip(y_ref, 1) + + result = [] + for x_refi, y_refi in zip(x_ref, y_ref): + li = [] + acc = 0 + for xi, yi in zip(x_refi, y_refi): + acc = xi * acc + yi + li.append(acc) + result.append(li) + z_ref = np.array(result) + if reverse: + z_ref = np.flip(z_ref, 1) + + if axis == 0: + z_ref = z_ref.T + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + if reverse: + z_ref[:-1] = x[-1] + else: + z_ref[1:] = x[0] + else: + if reverse: + z_ref[:, :-1] = x[:, -1:] + else: + z_ref[:, 1:] = x[:, 0:1] + + # triton result + # we don't cast the `fp32 = bf16 op 
bf16` result to bfloat16 to alleviate accuracy issues + z_tri = to_triton(z, device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + + z_tri = to_numpy(z_tri) + # compare + if dtype_str not in int_dtypes: + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan_layouts = [ + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +] + +# --------------- +# test histogram +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram(M, N, device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N) + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x = torch.randint(0, N, (M, ), device=device, dtype=torch.int32) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) +@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device): + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), NUM_PID_N): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * NUM_PID_N + pid_n, local) + # the following segfaults AMD backend following #3492 + # really unclear why; the llvm-ir and kernel arguments are + # identical ! 
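+ # (the commented-out store below is the tl.num_programs(1) variant referenced above; the kernel uses the NUM_PID_N constexpr instead)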
+ # tl.store(Y + off_m * tl.num_programs(1) + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n) + if not is_interpreter(): + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("src_layout", scan_layouts) +@pytest.mark.parametrize("axis", [0, 1]) +def test_scan_layouts(M, N, src_layout, axis, device): + + ir = f""" + #blocked = {src_layout} + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> + %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ + ^bb0(%arg2: i32, %arg3: i32): + %16 = arith.addi %arg2, %arg3 : i32 + tt.scan.return %16 : i32 + }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + tt.store %15, %11 
: tensor<{M}x{N}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + rs = RandomState(17) + x = rs.randint(-100, 100, (M, N)).astype('int32') + + z = np.zeros((M, N)).astype('int32') + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + kernel[(1, 1, 1)](x_tri, z_tri) + + z_ref = np.cumsum(x, axis=axis) + + np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + + +layouts = [ + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [THREADS_PER_WARP // 16, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 32, 16]), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), + WmmaLayout(warps_per_cta=[2, 2]), + WmmaLayout(warps_per_cta=[4, 1]), + WmmaLayout(warps_per_cta=[1, 4]), +] + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [64, 64], [32, 128], [32, 32], [16, 16]]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) +@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): + if isinstance(src_layout, + (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): + pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") + if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): + pytest.skip("Skipping test because it runs out of shared memory") + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + if epilogue_kind == 'expand_reduce2d' and isinstance(src_layout, MmaLayout): + pytest.skip( + "Currently MmaLayout combined with slice encoding and reduce op 
trigger device illegal memory access") + + if isinstance(src_layout, MmaLayout) and src_layout.version == 3: + src_layout[2] = 16 if dtype_str == "float16" else 8 + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] + rdims_1d = f"{N}" if axis == 0 else f"{M}" + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" + store_range = "%7" if axis == 0 else "%1" + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + num_warps = src_layout.warps_per_cta[0] * src_layout.warps_per_cta[1] + if num_warps == 8: + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 2], [0, 1], [1, 1], [1, 1], [0, 1]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [4], [0], [1], [1], [0]) + + expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" + other_axis = 1 - axis + epilogue = { + "reduce1d": + f""" + %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + tt.return + }} + }} + """, "reduce2d": + f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 : !tt.ptr<{ty}> + tt.return + }} + }} + """, "expand_reduce2d": + f""" + %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> + %15 = "tt.reduce"(%14) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) + %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.return + }} + }} + """ + }[epilogue_kind] + + ir = f""" + #blocked = {blocked} + #src = {src_layout} + #one_d_layout = {one_d_layout} + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> + 
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + """ + epilogue + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + reduce2d = 'reduce2d' in epilogue_kind + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) + + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("M", [32, 64, 128, 256]) +@pytest.mark.parametrize("src_layout", layouts) +def test_store_op(M, src_layout, device): + + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = 
#src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + tt.store %8, %4 : tensor<{M}x1x!tt.ptr, #src> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + store_kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, 1)).astype('float32') + y = np.zeros((M, 1), dtype='float32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + + pgm = store_kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +layouts = [ + # TODO (lixun): Add MfmaLayout + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("src_dim", [0, 1]) +@pytest.mark.parametrize("dst_dim", [0, 1]) +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): + + ir = f""" + #dst = {dst_layout} + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %7 = {GPU_DIALECT}.convert_layout %3 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = 
RandomState(17) + x = rs.randint(0, 4, (M, )).astype('int32') + y = np.zeros((M, ), dtype='int32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + pgm = kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@triton.jit +def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = weight_2 / new_weight + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + # [HIP] TO DO: some tests are flaky with the layout, so turn off them for now. + # BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) +] + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]]) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("op", ["sum", "max"]) +@pytest.mark.parametrize("first_axis", [0, 1]) +def test_chain_reduce(M, N, src_layout, op, device, first_axis): + + op_str = "" + if op == "sum": + op_str = """ + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32""" + elif op == "max": + op_str = """ + %13 = arith.cmpi "sgt", %arg2, %arg3 : i32 + %14 = arith.select %13, %arg2, %arg3 : i32 + tt.reduce.return %14 : i32""" + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #src> + %11 = "tt.reduce"(%10) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>> + %12 = "tt.reduce"(%11) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, 
#{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32 + tt.store %arg1, %12 : !tt.ptr + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, N)).astype('int32') + + z = np.zeros((1, )).astype('int32') + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, z_tri) + if op == "sum": + z_ref = np.sum(x) + elif op == "max": + z_ref = np.max(x) + + np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@pytest.mark.interpreter +def test_generic_reduction(device): + + @triton.jit + def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): + xindex = tl.arange(0, BLOCK) + x = tl.load(X + xindex) + mean = x + m2 = tl.zeros_like(x) + weight = tl.full(x.shape, 1, x.dtype) + (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine) + tl.store(out_mean, mean) + tl.store(out_var, m2 / weight) + + SIZE = 512 + x = torch.rand(SIZE, device=device) + out_mean = torch.empty((), device=device) + out_var = torch.empty((), device=device) + + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) + + expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) + torch.testing.assert_close(out_mean, expect_mean) + torch.testing.assert_close(out_var, expect_var) + + +# --------------- +# test permute +# --------------- + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_hip() and shape == (128, 128) and dtype_str == 'float32': + pytest.skip("TODO Out of LDS for float32 with shape 128x128") + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + # numpy result + if dtype_str == 'float8e4b15': + ty = tl.float8e4b15 + z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty) + z_tri = z_tri.base + z_tri_contiguous = z_tri_contiguous.base + else: + z_ref = x.transpose(*perm) + # compare + np.testing.assert_allclose(to_numpy(z_tri), z_ref) + np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if not is_cuda(): + return + + # parse ptx to make 
sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1]))) +def test_trans_2d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, trans1: tl.constexpr, trans2: tl.constexpr): + in_offs = tl.arange(0, in_shape1)[:, None] * in_shape2 + tl.arange(0, in_shape2)[None, :] + ou_offs = tl.arange(0, ou_shape1)[:, None] * ou_shape2 + tl.arange(0, ou_shape2)[None, :] + tl.store(Out + ou_offs, tl.permute(tl.load(In + in_offs), (trans1, trans2))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3]))) +def test_trans_4d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, # + in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr, + ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr, + trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr): + in_ptr = tl.make_block_ptr( + base=In, + shape=(in_shape1, in_shape2, in_shape3, in_shape4), + strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(in_shape1, in_shape2, in_shape3, in_shape4), + order=(3, 2, 1, 0), + ) + out_ptr = tl.make_block_ptr( + base=Out, + shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + order=(3, 2, 1, 0), + ) + tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. 
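+    # (torch.zeros_like(expected) would keep expected's permuted strides, while the kernel
+    # writes through a row-major block pointer, so a freshly allocated contiguous buffer is used.)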
+ actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm, num_warps=8) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# --------------- +# test dot +# --------------- + + +def convert_fp8_to_fp32(x, device, dtype_str): + if dtype_str == 'float8e4nv': + return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32) + elif dtype_str == 'float8e5': + return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32) + assert "Unsupported float8 dtype" + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack", + [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for input_precision in ['tf32', 'tf32x3', 'ieee'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for input_precision in ["ieee" if is_hip() else "tf32"] + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', + 'float32')] + for kpack in [1, 2 if is_hip() else 1]] + [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1) + for col_a in [True, False] + for col_b in [True, False]] + + [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1)] + + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) + for float8_type in ["float8e5", "float8e4nv"]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): + if is_interpreter(): + if in_dtype == 'bfloat16': + pytest.skip("bfloat16 is not supported in the interpreter") + else: + if is_cuda(): + capability = torch.cuda.get_device_capability() + + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") + if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'): + pytest.skip("float8e4nv and float8e5 not supported on HIP") + if is_hip() and (input_precision != "ieee"): + pytest.skip(f"{input_precision} not supported on HIP") + if is_hip() and (kpack == 2 and in_dtype == 
'int8' and K < 64): + pytest.skip("kpack too large for K") + if not is_hip() and kpack == 2: + pytest.skip("Skip duplicated tests on nv path") + + torch.backends.cuda.matmul.allow_tf32 = input_precision == "tf32" + + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, INPUT_PRECISION: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(Ws) + z = tl.dot(z.to(w.dtype), w, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + tl.store(Zs, z) + + # input + rs = RandomState(17) + if col_a: + x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T + else: + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + if col_b: + y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T + else: + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) + if 'int' not in in_dtype and 'float8' not in in_dtype: + x *= .1 + y *= .1 + if in_dtype == 'float32' and input_precision == "tf32": + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) + # triton result + if out_dtype == 'int8': + z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) + else: + z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1 + + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) + + if out_dtype == 'int8': + out_dtype = tl.int8 + elif out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + kern_kwargs = { + 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': + epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': + epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': 
input_precision, 'num_warps': + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + } + + if is_hip(): + kern_kwargs['kpack'] = kpack + + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) + + if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"): + if not is_cuda(): + pass + else: + ptx = pgm.asm["ptx"] + start = ptx.find("shfl.sync.bfly") + end = ptx.find("cvt.rn.f16.f32") + red_code = ptx[start:end] + assert len(red_code) > 0 + + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. + # TODO: we should eliminate these unused functions in ptx code. + if not (capability[0] >= 9): + assert "shared" not in red_code + assert "bar.sync" not in red_code + # torch result + if in_dtype == 'int8': + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) + else: + z_ref = np.matmul(x, y) + + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + if 'float8' in in_dtype: + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) + z_ref = np.matmul(z_ref, w) + # compare + if in_dtype == 'float32': + # XXX: Somehow there's a larger difference when we use float32 + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + if not is_cuda(): + return + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): + # XXX: skip small sizes because they are not vectorized + assert 'ld.global.v4' in ptx + if 'float8' in in_dtype: + assert 'st.global.v2' in ptx + else: + assert 'st.global.v4' in ptx + if in_dtype == 'float32' and input_precision != "ieee": + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) + elif in_dtype == 'int8': + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + elif in_dtype == 
"float8e5" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx + elif in_dtype == "float8e4nv" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize("B", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("M, N, K", [(64, 64, 64), (32, 32, 32)]) +@pytest.mark.parametrize("in_dtype_str, out_dtype_str", [('int8', 'int8'), ('float16', 'float16'), + ('float16', 'float32'), ('float32', 'float32')]) +def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device): + if is_hip(): + # hip does not support tf32 precision, so use ieee for all tests + input_precision = "ieee" + if "gfx11" in triton.runtime.driver.active.get_current_target().arch: + if in_dtype_str == "float32": + pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") + if out_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + else: + input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" + + @triton.jit + def kernel( + q_ptr, + k_ptr, + o_ptr, + stride_qb, + stride_qm, + stride_qk, + stride_kb, + stride_kk, + stride_kn, + stride_ob, + stride_om, + stride_on, + BLOCK_B: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + INPUT_PRECISION: tl.constexpr, + out_dtype: tl.constexpr = tl.float32, + ): + startm = tl.program_id(0) * BLOCK_M + startn = tl.program_id(1) * BLOCK_N + offs_b = tl.arange(0, BLOCK_B) + offs_m = startm + tl.arange(0, BLOCK_M) + offs_n = startn + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + q_ptrs = q_ptr + offs_b[:, None, None] * stride_qb + offs_m[None, :, None] * stride_qm + offs_k[ + None, None, :] * stride_qk + k_ptrs = k_ptr + offs_b[:, None, None] * stride_kb + offs_k[None, :, None] * stride_kk + offs_n[ + None, None, :] * stride_kn + q = tl.load(q_ptrs) + k = tl.load(k_ptrs) + qk = tl.dot(q, k, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + o_ptrs = o_ptr + offs_b[:, None, None] * stride_ob + offs_m[None, :, None] * stride_om + offs_n[ + None, None, :] * stride_on + tl.store(o_ptrs, qk) + + if out_dtype_str == 'int8': + out_dtype = tl.int8 + elif out_dtype_str == 'float16': + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + rs = RandomState(17) + x = numpy_random((B, M, K), dtype_str=in_dtype_str, rs=rs) + y = numpy_random((B, K, N), dtype_str=in_dtype_str, rs=rs) + if in_dtype_str == 'int8': + out = numpy_random((B, M, N), dtype_str='int32', rs=rs) + else: + out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) + + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + out_tri = to_triton(out, device=device) + + BLOCK_B = B + BLOCK_M, BLOCK_N = 32, 32 + BLOCK_K = K + + grid = ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + kernel[grid]( + x_tri, + y_tri, + out_tri, + x_tri.stride(0), + x_tri.stride(1), + x_tri.stride(2), + y_tri.stride(0), + y_tri.stride(1), + y_tri.stride(2), + out_tri.stride(0), + out_tri.stride(1), + out_tri.stride(2), + BLOCK_B=BLOCK_B, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + INPUT_PRECISION=input_precision, + out_dtype=out_dtype, + num_warps=num_warps, + ) + + if in_dtype_str == 'int8': + out_ref = np.matmul(x.astype(np.float32), y.astype(np.float32)).astype(np.int32) + else: + 
out_ref = np.matmul(x, y) + np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +def test_max_num_imprecise_acc(device): + + if not hasattr(torch, 'float8_e5m2'): + pytest.skip(f"torch {torch.__version__} does not support float8_e5m2") + + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability != (9, 0): + return + + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + MAX_NUM_IMPRECISE_ACC: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :]) + y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :]) + z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :]) + z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC) + tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z) + + M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64 + x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device) + y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device) + z = torch.zeros((M, N), dtype=torch.float32, device=device) + h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps) + if not is_cuda(): + return + assert h.asm["ptx"].count("add.f32") == (M * N) // (32 * num_warps) * (K / MAX_NUM_IMPRECISE_ACC) + + +@pytest.mark.parametrize('in_dtype', ['float32']) +def test_dot_mulbroadcasted(in_dtype, device): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + + @triton.jit + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): + pidn = tl.program_id(1) + pidm = tl.program_id(0) + offm = tl.arange(0, BM)[:, None] + offn = tl.arange(0, BN)[None, :] + offak = tl.arange(0, BK)[None, :] + offbk = tl.arange(0, BK)[:, None] + acc = tl.full((BM, BN), 0.0, tl.float32) + for ridx5 in range(0, K // BK): + x = tl.load(X + ((pidm * K * BM) + (offm * K) + (ridx5 * BK) + offak)) + y = tl.load(Y + ((pidn * BN) + (offbk * N) + (ridx5 * N * BK) + offn)) + x = tl.expand_dims(x, axis=2) + y = tl.expand_dims(y, axis=0) + t = tl.sum(x * y, axis=1) + acc = t + acc + tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + + M, N, K = 256, 192, 160 + BM, BN, BK = 128, 32, 32 + rs = RandomState(17) + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + x = x * 0.1 + y = y * 0.1 + z = numpy_random((M, N), dtype_str=in_dtype, rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(z, device=device) + grid = M // BM, N // BN + h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) + z_ref = np.matmul(x, y) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if not is_cuda(): + return + assert "tt.dot" in h.asm['ttir'] + # When using MMAv3, we will not pipeline the load op for Y, as the loaded + # value is in rowmajor. But MMAv3 requires its second operand is in colmajor + # because transpose is not supported for MMAv3 with float32 input. 
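+    # The {num = ...} on the async_wait below tracks how many loads end up pipelined:
+    # only the X load on Hopper (MMAv3), both X and Y otherwise.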
+ if capability[0] >= 9: + assert re.search(r"triton_gpu.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None + else: + assert re.search(r"triton_gpu.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) +def test_full(dtype_str, shape, device): + if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): + # PyTorch only has unsigned 8, but not 16, 32, or 64 + dtype = getattr(torch, dtype_str[1:]) # uintx -> intx + else: + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel_static(out): + a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + @triton.jit + def kernel_dynamic(out, val, dtype: tl.constexpr): + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) + out_static = torch.zeros((128), dtype=dtype, device=device) + kernel_static_patched[(1, )](out_static) + assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) + assert torch.all(out_dynamic == 2) + + +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) +def test_constexpr(literal, dtype_str, device): + + @triton.jit + def kernel(out_ptr): + val = GENERATE_TEST_HERE + tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) + + kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) + assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None + + +@triton.jit +def pass_const(a, b, choose_b): + if choose_b: + return b + else: + return a + + +@pytest.mark.parametrize("choose_const", [True, False]) +@pytest.mark.parametrize("constexpr", [True, False]) +@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) +def test_const(device, choose_const, constexpr, mode): + + @triton.jit(do_not_specialize=["choose_const"]) + def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + @triton.jit + def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32, + BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + if mode == "direct": + if choose_const: + LOSE_TAIL = "final_out = c_out" + else: + LOSE_TAIL = "final_out = out" + elif mode == "call": + LOSE_TAIL = "final_out = pass_const(out, c_out, 
choose_const)" + elif mode == "ternary": + LOSE_TAIL = "final_out = c_out if choose_const else out" + elif mode == "if": + LOSE_TAIL = """ + if choose_const: + final_out = c_out + else: + final_out = out +""" + + SIZE = 128 + input = torch.randn((SIZE, ), dtype=torch.float32, device=device) + output = torch.zeros((SIZE, ), dtype=torch.float32, device=device) + patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''}) + + expect_fail = (not constexpr and mode != "direct") or choose_const + if expect_fail: + with pytest.raises(triton.CompilationError) as exc_info: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + else: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + assert torch.all(input == output) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['float32', 'float16']) +def test_dot_without_load(dtype_str, device): + + @triton.jit + def _kernel(out): + a = GENERATE_TEST_HERE + b = GENERATE_TEST_HERE + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) + a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + out_ref = torch.matmul(a, b) + out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](out) + assert torch.all(out == out_ref) + + +# --------------- +# test arange +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("start", [0, 1, 7, 16]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): + BLOCK = 128 + z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) + np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + + +# --------------- +# test load +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4] + for other in [0, 1]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device): + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + input_size = size - size_diff + output_size = size + if dtype_str == 'bool': + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) + elif dtype_str in int_dtypes or dtype_str in uint_dtypes: + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) + else: + input = torch.rand(input_size, dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) + + @triton.jit + def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): + in_offsets = tl.arange(0, out_size) + # Load inputs. 
+ x = GENERATE_TEST_HERE + # Store output + output_offsets = tl.arange(0, out_size) + tl.store(out_ptr + output_offsets, x) + + mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) + + reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device))) + torch.testing.assert_close(output, reference_out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("mask_val", [True, False]) +@pytest.mark.parametrize("other_val", [0, 1]) +def test_masked_load_scalar(num_ctas, mask_val, other_val, device): + input_val = 4.0 + size = 128 + dtype = torch.float32 + input = torch.full((size, ), input_val, dtype=dtype, device=device) + output = torch.zeros((size, ), dtype=dtype, device=device) + + @triton.jit + def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr): + offsets = tl.arange(0, size) + x = tl.load(in_ptr + offsets, mask=mask, other=other) + tl.store(out_ptr + offsets, x) + + kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas) + + if mask_val: + reference_out = torch.full((size, ), input_val, dtype=dtype, device=device) + else: + reference_out = torch.full((size, ), other_val, dtype=dtype, device=device) + + torch.testing.assert_close(output, reference_out) + + +# Testing masked loads with an intermate copy to shared memory run. +# FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_masked_load_shared_memory(dtype, device): + + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + M = 32 + N = 32 + K = 16 + + in1 = torch.rand((M, K), dtype=dtype, device=device) + in2 = torch.rand((K, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=dtype, device=device) + + @triton.jit + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + + M_offsets = tl.arange(0, M) + N_offsets = tl.arange(0, N) + K_offsets = tl.arange(0, K) + + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] + + # Load inputs. + x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K) + w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N) + + # Without a dot product the memory doesn't get promoted to shared. 
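+        # (tl.dot forces its operands through shared memory, so the masked loads above
+        # exercise the global -> shared copy path this test is about.)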
+ o = tl.dot(x, w, out_dtype=tl.float32) + + # Store output + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] + tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) + + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) + + reference_out = torch.matmul(in1, in2) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) +def test_load_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + if not is_cuda(): + return + + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + block_size = 1024 * num_ctas + src = torch.empty(block_size, device=device) + dst = torch.empty(block_size, device=device) + + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) + + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx + # np.testing.assert_allclose(dst, src[:N]) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("has_hints", [False, True]) +def test_vectorization_hints(has_hints, device): + src = torch.empty(1024, device=device) + dst = torch.empty(1024, device=device) + off = torch.zeros(1, device=device, dtype=torch.int32) + + @triton.jit + def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = offsets + tl.load(off) + if HINT: + tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if has_hints: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.v4.b32" not in ptx + + +# --------------- +# test store +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) +def test_store_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, cache_modifier=CACHE) + + if not is_cuda(): + return + pgm = _kernel[(1, )](dst, src, CACHE=cache) + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in 
ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +# --------------- +# test default +# --------------- +# TODO: can't be local to test_default + + +@triton.jit +def _impl(value=10): + return value + + +@pytest.mark.interpreter +def test_default(device): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device=device) + ret1 = torch.zeros(1, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(ret0, ret1, value=3): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1, )](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + + _kernel[(1, )](ret0, ret1) + assert ret0.item() == 10 + assert ret1.item() == 3 + + +# --------------- +# test noop +# ---------------- + + +@pytest.mark.interpreter +def test_noop(device): + + @triton.jit + def kernel(x): + pass + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) + kernel[(1, )](x) + + +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) +def test_pointer_arguments(device): + + @triton.jit + def kernel(x): + pass + + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1, )](x) + else: + kernel[(1, )](x) + + +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) +def test_value_specialization(value: int, value_type: str, device) -> None: + + def repr(specialization): + spec_type = specialization.signature["VALUE"] + return f"kernel_{spec_type}" + + @triton.jit(repr=repr) + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + h = kernel[(1, )](value, x) + assert value_type in h.name + + +# -------------------- +# value specialization +# -------------------- + + +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) +def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + + if overflow: + with pytest.raises(OverflowError): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) + + +# ---------------- +# test constexpr +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + if op in ['<<', '>>', '&', '^', '|']: # int op + x_str = "3" if 
is_lhs_constexpr else "x" + y_str = "4" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="int32") + + # NOTE: bitshifting beyond bitwidth can lead to undefined behavior + if op in ['<<', '>>']: + y = numpy_random((1, ), dtype_str="int32", low=0, high=_bitwidth("int32")) + else: + y = numpy_random((1, ), dtype_str="int32") + else: + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) + + +@pytest.mark.interpreter +def test_constexpr_shape(device): + + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + + +@pytest.mark.interpreter +def test_constexpr_scalar_shape(device): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + +reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("formats", reshape_list) +def test_reshape(formats, device): + in_format, out_format = formats + + @triton.jit + def kernel(Z, X, out_tuple: tl.constexpr): + x = tl.load(X_PTR_EXPR) + z = tl.reshape(x, out_tuple) + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + } + return patch_kernel(kernel, to_replace) + + x = numpy_random(in_format, dtype_str="int32") + z = x.reshape(out_format) + x_tri = to_triton(x, device=device) + patched_kernel = generate_kernel(in_format, out_format) + z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device) + patched_kernel[(1, )](z_tri, x_tri, out_format) + np.testing.assert_equal(z, to_numpy(z_tri)) + + +def test_reshape_err(device): + + @triton.jit + def kernel(): + x = tl.arange(0, 8 * 8) + y = tl.reshape(x, (8 * 4, )) + + with pytest.raises(triton.CompilationError) as exc_info: + kernel[(1, )]() + + assert "reshape" in str(exc_info.value) + + +def test_trans_reshape(device): + + @triton.jit + def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr): + + in_block_ptr = tl.make_block_ptr( + base=in_base_ptr, + shape=(IN_SHAPE0, IN_SHAPE1), + strides=(IN_SHAPE1, 1), + offsets=(0, 0), + block_shape=(IN_SHAPE0, IN_SHAPE1), + order=(1, 0), + ) + x = tl.load(in_block_ptr) + x = tl.reshape(x, (32, 4, 4, 2)) + x = tl.permute(x, (1, 2, 3, 0)) + x = tl.reshape(x, (IN_SHAPE0 * IN_SHAPE1, )) + tl.store(out_base_ptr + tl.arange(0, IN_SHAPE0 * IN_SHAPE1), x) + + shape = (32, 32) + input = torch.arange(math.prod(shape), dtype=torch.int32, device=device).reshape(shape) + expected = torch.permute(input, (1, 0)) + # Don't do zeros_like -- that copies the layout, which we don't want. 
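+    # (torch.zeros gives a contiguous buffer; the kernel writes it back flat via
+    # tl.arange(0, IN_SHAPE0 * IN_SHAPE1).)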
+ actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) + + k = kernel[(1, )](input, actual, shape[0], shape[1]) + assert k.asm['ttgir'].count( + 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + if type == "inline": + vec = val_multiplier(vec, i) + else: + vec = val_multiplier_noinline(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("type", ["inline", "noinline"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): + + @triton.jit + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) + + size = 1024 + rand_val = numpy_random((size, ), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device=device) + err_msg = "" + try: + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + except Exception as e: + err_msg = str(e) + + if type == "noinline" and not is_interpreter(): + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) + + +# ------------- +# test if +# ------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) +def test_if(if_type, device): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): + pid = tl.program_id(0) + cond = tl.load(Cond) + if IfType == "if": + if pid % 2 == 0: # eq + tl.store(Ret, tl.load(XTrue)) + elif 1 == pid % 2: # req + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticVaue != 0 and StaticVaue != 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) + + +def test_num_warps_pow2(device): + dst = torch.empty(128, device=device) + + 
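+    # num_warps must be a power of two: 3 should be rejected at launch time,
+    # while 1, 2 and 4 launch normally below.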
@triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) +def test_unary_math(func_str, device): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.FUNC_STR(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + kernel = patch_kernel(kernel, {'FUNC_STR': func_str}) + + shape = (128, ) + x = torch.randn(shape, dtype=torch.float32, device=device) + if func_str in ['sqrt', 'rsqrt']: + x = torch.abs(x) + if func_str in ['log', 'log2']: + x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device)) + y = torch.zeros(shape, dtype=torch.float32, device=device) + + kernel[(1, )](x, y, BLOCK=shape[0]) + torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3) + + +# ----------------------- +# test inline asm +# ----------------------- + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. 
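+        # pack=4 hands the asm four uint8 lanes in one 32-bit register: and.b32 with 0x1F1F1F1F
+        # clears the top three bits of every byte so the shl.b32 by 3 cannot carry across byte
+        # boundaries, i.e. each byte becomes (x << 3) & 0xFF, matching y_ref = x << 3 below.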
+ y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.parametrize('num_ctas', num_ctas_list) +def test_inline_asm_with_pointers(num_ctas, device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x_ptrs = X + tl.arange(0, BLOCK) + y_ptrs = Y + tl.arange(0, BLOCK) + tl.inline_asm_elementwise( + "ld.global.b8 $0, [$1]; \ + shl.b32 $0, $0, 3; \ + st.global.b8 [$2], $0;", "=r,l,l", [x_ptrs, y_ptrs], dtype=tl.int8, is_pure=False, + pack=1) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +def test_inline_asm_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # C = A - B + # D = B - A + (c, d) = tl.inline_asm_elementwise( + asm=""" + sub.u32 $0, $2, $3; // C = A - B + sub.u32 $1, $3, $2; // D = B - A + """, + constraints=( + # 2 output registers: $0=C and $1=D. + "=r,=r," + # 2 input registers: $2=A and $3=B. + "r,r"), + args=[a, b], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A - B + D_ref = B - A + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +def test_inline_asm_packed_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. 
+ cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint8', rs=rs) + B = numpy_random(shape, dtype_str='float32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='float32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A.astype(np.int32) + D_ref = np.maximum(A.astype(np.float32), B) + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test control flow +# ----------------------- + + +@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) +def test_for_iv(lo, hi, iv, device): + + @triton.jit + def kernel(Out, lo, hi, iv: tl.constexpr): + acc = 0 + acc = acc.to(tl.int64) + for i in range(lo, hi, iv): + acc += i + tl.store(Out, acc) + + lo = 2**35 + hi = 2**35 + 20 + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) + assert out[0] == sum(range(lo, hi, iv)) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +def test_if_else(device): + + @triton.jit + def kernel(Cond, TrueVal, FalseVal, Out): + if tl.load(Cond): + val = tl.load(TrueVal) + else: + val = tl.load(FalseVal) + tl.store(Out, val) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # True + cond[0] = True + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == true_val[0] + # False + cond[0] = False + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == false_val[0] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode, device): + + @triton.jit + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return + tl.store(Out, 1) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # exit early path taken + exit_early[0] = 1 + kernel[(1, )](exit_early, out, True, mode) + assert to_numpy(out)[0] == 0 + # exit early path not taken + exit_early[0] = 0 + kernel[(1, )](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) 
+def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) +def test_if_call(call_type, device): + + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if call_type == "attribute": + # call attribute + if pid == 0: + a = o + a = a.to(tl.int32).to(tl.int32) + 1 + o = a + elif call_type == "attribute_jit": + # call attribute and jit function + if pid == 0: + a = o + a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1 + o = a + elif call_type == "jit": + if pid == 0: + # regular function call + a = o + a = add_fn(a) + o = a + elif call_type == "jit_if": + # function without end_if block + if pid == 0: + a = o + a = add_fn_return(a, pid) + o = a + elif call_type == "jit_if_exp": + # ifexp expression + if pid == 0: + a = o + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + o = a + elif call_type == "jit_expr": + # call without return + if pid == 0: + a = o + 1 + add_fn_expr(Out, a) + o = a + elif call_type == "jit_static_cond": + if pid == 0: + a = o + 1 + add_fn_static_cond(o, call_type) + o = a + elif call_type == "jit_noinline": + if pid == 0: + a = o + 1 + add_fn_noinline(a) + o = a + elif call_type == "jit_extern": + if pid == 0: + a = o + 1 + tl.cdiv(a, a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) + assert to_numpy(out)[0] == 1 + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize("_cond1", [True, False]) +@pytest.mark.parametrize("_cond2", [True, False]) +@pytest.mark.parametrize("_cond3", [True, False]) +def test_nested_if_else_return(_cond1, _cond2, _cond3, device): + + @triton.jit + def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): + val = 0 + if tl.load(Cond1): + if tl.load(Cond2): + val = tl.load(Val1) + else: + return + else: + if tl.load(Cond3): + val = tl.load(Val2) + else: + val = tl.load(Val3) + tl.store(Out, val) + + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) + targets = { + (True, True, True): val1[0], + (True, True, False): val1[0], + (True, False, True): out[0], + (True, False, False): out[0], + (False, True, True): val2[0], + (False, True, False): val3[0], + (False, False, True): val2[0], + (False, False, False): val3[0], + } + assert out[0] == targets[(_cond1, _cond2, _cond3)] + + +@pytest.mark.interpreter +def test_while(device): + + @triton.jit + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): + init_i = tl.load(InitI) + curr_i = init_i + j = 0 + # Check that init_i is not updated by the loop + while j < tl.load(Bound): + curr_i = curr_i 
+ (j == tl.load(CutOff)) + j += 1 + tl.store(OutInitI, init_i) + tl.store(OutI, curr_i) + tl.store(OutJ, j) + + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) + assert out_init_i[0] == init_i[0] + assert out_i[0] == init_i[0] + 1 + assert out_j[0] == bound[0] + + +@pytest.mark.interpreter +def test_nested_while(device): + + @triton.jit + def nested_while(data, countPtr): + for i in range(10): + count = tl.load(countPtr) + while count > 0: + tl.store(data, tl.load(data) + 1.0) + count = count - 2 + + counter = torch.tensor([8], dtype=torch.int32, device=device) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) + assert data[0] == 40 + + +# ----------------------- +# test extra +# ----------------------- + + +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + + +@pytest.mark.skip(reason="metax todo") +def test_globaltimer(device): + if is_hip(): + pytest.skip("test_globaltimer is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out1, Out2): + start = tl.extra.cuda.globaltimer() + off = tl.arange(0, 128) + for i in range(10000): + tl.store(Out1 + off, tl.load(Out1 + off) + 1) + end = tl.extra.cuda.globaltimer() + tl.store(Out2, end - start) + + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) + assert out2[0] > 0 + assert h.asm["ptx"].count("%globaltimer") == 2 + + +@pytest.mark.skip(reason="metax todo") +def test_smid(device): + if is_hip(): + pytest.skip("test_smid is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out): + tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) + + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) + assert out.sort()[0].unique().shape[0] > 0 + assert h.asm["ptx"].count("%smid") == 1 + + +# ----------------------- +# test layout conversions +# ----------------------- +# TODO: backend should be tested separately + +layouts = [ + BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([8, 1], [16, THREADS_PER_WARP // 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + 
BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), +] + +intermediate_layouts = [ + None, + SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +def compute_rep_shape(layout): + if type(layout) is BlockedLayout: + warp_shape = np.multiply(layout.sz_per_thread, layout.threads_per_warp) + rep_shape = np.multiply(warp_shape, layout.warps_per_cta) + return rep_shape + else: + assert False, "TODO: support compute_rep_shape for layout " + str(type(layout)) + + +# This function gives a lower bound approximation of scratch buffer shape for convert_layout operation +def compute_scratch_buffer_shape(src_layout, dst_layout, shape): + src_rep_shape = compute_rep_shape(src_layout) + dst_rep_shape = compute_rep_shape(dst_layout) + full_scratch_shape = np.maximum(src_rep_shape, dst_rep_shape) + return np.minimum(full_scratch_shape, shape) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("interm_layout", intermediate_layouts) +@pytest.mark.parametrize("dst_layout", layouts) +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): + if (M == 1 or N == 1) and interm_layout: + # TODO(jlebar): These OOB accesses don't even hit an assert in the + # compiler, and some of them return the wrong result instead of + # crashing! + pytest.skip("Out of bound access when maxPhase > 1") + if str(src_layout) == str(dst_layout): + pytest.skip() + if is_hip(): + try: + scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) + except AssertionError: + pytest.skip("Can't compute scratch buffer size") + lds_size = 65536 + # consider int32 dtype in scratch buffer size, + # because it is the largest dtype used in convert_layout in this test + int32_size = 4 + # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding + if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size: + pytest.skip("Scratch buffer is too large") + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ if interm_layout is None else f""" + #src = {src_layout} + #interm = {interm_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ if interm_layout is None else f""" + %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm> + %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm> -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm> + %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm> -> tensor<{M}x{N}xf16, #src> + + %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func 
public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + +mma_pairs = [ + [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + # Mma -> mma support is TODO on Hopper (and Volta) + # [ + # MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], +] + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("mma_pair", mma_pairs) +def test_convertmma2mma(M, N, mma_pair, dtype, device): + if is_hip(): + pytest.skip("test_mma2mma is not supported in HIP") + + src_layout, _ = mma_pair + num_warps = np.cumprod(src_layout.warps_per_cta)[-1] + + def do_test(src_layout, dst_layout): + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : 
tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + do_test(mma_pair[0], mma_pair[1]) + do_test(mma_pair[1], mma_pair[0]) + + +@pytest.mark.interpreter +def test_load_scalar_with_mask(device): + + @triton.jit + def kernel(Input, Index, Out, N: int): + index = tl.load(Index) + scalar = tl.load(Input + index, mask=index < N, other=0) + tl.store(Out, scalar, mask=index < N) + + Index = torch.tensor([0], dtype=torch.int32, device=device) + Input = torch.tensor([0], dtype=torch.int32, device=device) + Out = torch.empty_like(Index, device=device) + kernel[(1, )](Input, Index, Out, Index.numel()) + assert Out.data[0] == 0 + + +# This test is used to test our own PTX codegen for float16 and int16 conversions +# maybe delete it later after ptxas has been fixed +@pytest.mark.parametrize("dtype_str", ['float16', 'int16']) +def test_ptx_cast(dtype_str, device): + + @triton.jit + def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex + _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype) + tmp1 = 2 + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(dtype) + tmp5 = _tmp4 < tmp3 + _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4) + 
tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask) + + torch.manual_seed(123) + if dtype_str == 'int16': + torch_dtype = torch.int16 + triton_dtype = tl.int32 + else: + torch_dtype = torch.float16 + triton_dtype = tl.float32 + + s0 = 4 + buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) + buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + assert buf14.to(torch.float32).mean() == -2.0 + + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr, # + num_pipeline_stages: tl.constexpr = 3 # +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_fp8_dot_acc(in_type_str, low_precision_acc, device): + if is_hip(): + pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] >= 9 and in_type_str == "float8e4b15": + pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + check_type_supported(in_type_str, device) + M, N, K = 128, 256, 256 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + C = torch.empty((M, N), dtype=torch.float32, device=device) + num_warps = 8 + a = to_triton(A, device=device, dst_type=in_type_str) + b = to_triton(B, device=device, dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M), 1) + 
matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) + torch_a = torch.from_numpy(A).to(device=device) + th_a = f8_to_f16(torch_a, in_type_str) + torch_b = torch.from_numpy(B).to(device=device) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + elif low_precision_acc > 32: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(ref_out, C) + + +# ----------------------- +# test enable_fp_fusion +# ----------------------- + + +@pytest.mark.parametrize("enable_fp_fusion", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, device): + if is_hip(): + pytest.skip( + 'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton' + ) + + # Sequential multiply add can be fused by backend + @triton.jit + def mul_add(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) + + if not is_cuda(): + return + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + + +# ----------------------- +# test propagate_nan +# ----------------------- + + +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + + @triton.jit + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + + for mode in ['A', 'B', 'both']: + if func == 'clamp' and mode == 'B': + # clamp does not guarantee propagation from 'min' and 'max' args + continue + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'A' or mode == 'both': A[0] = torch.nan + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'B' or mode == 'both': B[0] = torch.nan + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) + + if mode == 'both' or propagate_nan == 'ALL': + assert torch.isnan(C[0]) + else: + assert not torch.isnan(C[0]) + + +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + min = tl.load(min_ptr + off, mask=mask) + max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, min, max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, min), max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) 
+ a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + min = torch.min(a, b) + max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, min, max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp_symmetric(dtype, device): + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# ----------------------- +# test iterators +# ----------------------- + + +@pytest.mark.interpreter +def test_static_range(device): + + @triton.jit + def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): + acc = 0 + for i in tl.static_range(0, N, step=step): + acc += i + tl.store(Z, acc) + + N = 100 + step = 7 + Out = torch.empty(1, dtype=torch.int32, device=device) + loop_kernel[(1, )](Out, N, step) + Acc = torch.tensor([0], dtype=torch.int32, device=device) + for i in range(0, N, step): + Acc += i + assert (Out == Acc).all(), (Out, Acc) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +def test_tl_range(device): + if is_hip(): + pytest.skip("test_tl_range is not supported in HIP") + M, N, K = 64, 64, 512 + BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), dtype=torch.float32, device=device) + pgm = matmul_kernel[ + 1, + ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, 0, num_pipeline_stages=5) + ref_out = torch.matmul(a, b).to(torch.float32) + if is_interpreter(): + # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. + # Thus we use a higher tolerance + torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) + else: + torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. 
+ assert 'cp.async.wait_group 0x6' in ptx + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +@pytest.mark.skip(reason="metax todo") +def test_maxnreg(device): + assert not is_interpreter(), "this test won't work with the interpreter" + if is_hip(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # re-use the temp variable and make sure to check that it isn't creating incorrect IR. + temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() diff --git a/third_party/metax/python/test/unit/language/test_decorator.py b/third_party/metax/python/test/unit/language/test_decorator.py new file mode 100644 index 000000000..66371ba60 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_decorator.py @@ -0,0 +1,48 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +def test_decorator_with_def(device): + + def triton_heuristics_pointwise(**kwargs): + + def decorator(func): + return func + + return decorator + + # "def" might appear in a decorator call, e.g. a hash string argument. + # This test makes sure the compiler can find the right position of function + # definition. 
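+ # Here the 'backend_hash' string starts with "def", so a naive textual search for "def"
+ # would match inside the decorator call before reaching the real "def kernel" below.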
+ @triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, ) + @triton.jit + def kernel(): + pass + + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + except Exception as e: + pytest.fail(f"triton compile failed with error: {e}") + + +def test_triton_heuristic(device): + N = 1023 + src = torch.empty(N, device=device) + dst = torch.zeros(N, device=device) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1) + @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs + @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr): + tl.store(dst, EVEN_N) + tl.store(dst + 1, EVEN_src) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + assert dst[0].item() == 0.0 + assert dst[1].item() == 1.0 + assert _kernel.base_fn.__name__ == "_kernel" diff --git a/third_party/metax/python/test/unit/language/test_line_info.py b/third_party/metax/python/test/unit/language/test_line_info.py new file mode 100644 index 000000000..6421c7309 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_line_info.py @@ -0,0 +1,171 @@ +import subprocess +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def kernel_single(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def device_inline(x): + return x + x + + +@triton.jit +def kernel_call(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = device_inline(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit(noinline=True) +def device_noinline(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = x + x + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_call_noinline(X, Y, BLOCK: tl.constexpr): + device_noinline(X, Y, BLOCK) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK": 128}, num_warps=4), + ], + key=[], +) +@triton.jit +def kernel_autotune(X, Y, SIZE: tl.constexpr, BLOCK: tl.constexpr): + for i in range(0, SIZE, BLOCK): + x = tl.load(X + i + tl.arange(0, BLOCK)) + tl.store(Y + i + tl.arange(0, BLOCK), x) + + +# AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +# Since the + symbol will take effect in the dot op after combination, +# it seems sensible to annotate it with the same line as the dot. +@triton.jit +def kernel_dot_combine(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + a = (tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :]).to(tl.int8) + d = tl.dot(a, a) + d = d + c + tl.device_print("", d) + + +def get_disassembler_command_and_debug_line_format(): + """Gets backend-specific disassembler information. + + Returns a tuple: (object file kind, disassembler tool command, + debug line anchor, debug line file and line number separator). + """ + backend = triton.runtime.driver.active.get_current_target().backend + + if backend == "cuda": + from triton.backends.nvidia.compiler import _path_to_binary + nvdisasm, _ = _path_to_binary("nvdisasm") + return ("cubin", [nvdisasm, "-g"], "## File", ",") + + if backend == "hip": + import shutil + # Try to find llvm-objdump from the current PATH to disassemble hsaco.
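+ # (hsaco is AMD's HSA code object format; llvm-objdump's -l flag interleaves the debug line info that extract_file_lines parses below.)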
+ tool = shutil.which("llvm-objdump") + if tool is not None: + return ("hsaco", [tool, "-D", "-l", "--arch=amdgcn"], ";", ":") + raise RuntimeError("llvm-objdump not found in PATH") + + raise RuntimeError(f"unknown backend {backend}") + + +def extract_file_lines(command, anchor, separator, asm): + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as cubin: + cubin.write(asm) + asm = subprocess.check_output(command + [path]).decode("utf-8") + file_lines = [] + lines = asm.splitlines() + for line in lines: + # We are looking for an anchor string and a separator between the file name and line number. + if anchor in line and separator in line: + entries = line[line.index(anchor):].split(separator) + if len(entries) == 2 and all(len(e) != 0 for e in entries): + file_lines.append((entries[0].strip(), entries[1].strip())) + return file_lines + + +def check_file_lines(file_lines, file_name, lineno, should_contain=True): + """ + Check if the file name and line number is in the file_lines + + Args: + file_lines: list of (file_name, line_number) + file_name: file name + lineno: line number, -1 means do not check line number + should_contain: whether the file name and line number should be in the file_lines + """ + for file, line in file_lines: + if lineno == -1: + if file_name in file: + return True + if file_name in file and str(lineno) in line: + return should_contain + return not should_contain + + +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine"] + + +@pytest.mark.parametrize("func", func_types) +def test_line_info(func: str): + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + kernel_info = {} + if func == "single": + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call": + kernel_info = kernel_call.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call_noinline": + kernel_info = kernel_call_noinline.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "autotune": + kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0] + elif func == "dot_combine": + kernel_info = kernel_dot_combine.warmup(20, grid=(1,)) + + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + if func == "single": + assert (check_file_lines(file_lines, "test_line_info.py", 15)) + assert (check_file_lines(file_lines, "test_line_info.py", 16)) + elif func == "call": + assert (check_file_lines(file_lines, "test_line_info.py", 28)) + assert (check_file_lines(file_lines, "test_line_info.py", 21)) + assert (check_file_lines(file_lines, "test_line_info.py", 30)) + elif func == "call_noinline": + assert (check_file_lines(file_lines, "test_line_info.py", 42)) + assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 36)) + assert (check_file_lines(file_lines, "test_line_info.py", 37)) + elif func == "autotune": + assert (check_file_lines(file_lines, "test_line_info.py", 53)) + assert (check_file_lines(file_lines, "test_line_info.py", 54)) + assert (check_file_lines(file_lines, "test_line_info.py", 55)) + elif func == "dot_combine": + assert (check_file_lines(file_lines, "test_line_info.py", 65)) + assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) diff --git 
a/third_party/metax/python/test/unit/language/test_random.py b/third_party/metax/python/test/unit/language/test_random.py new file mode 100644 index 000000000..e0e59b069 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_random.py @@ -0,0 +1,255 @@ +import numpy as np +import pytest +import scipy.stats +import torch + +import triton +import triton.language as tl + +##################################### +# Reference Philox Implementation +##################################### + + +class PhiloxConfig: + + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + while len(res) < pad: + res.append(np.array(n, dtype=self._dtype)) + n >>= (np.dtype(self._dtype).itemsize * 8) + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +# Unit Tests +##################################### + +BLOCK: tl.constexpr = 1024 + +# test generation of random uint32 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in ['10', '4,53', '400'] + for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def 
test_randint(size, seed, device, dtype, const_seed): + size = list(map(int, size.split(','))) + torch_dtype = getattr(torch, dtype) + numpy_dtype = getattr(np, f"u{dtype}") + config = {'int32': PHILOX_32, 'int64': PHILOX_64}[dtype] + + @triton.jit + def kernel(X, N, seed): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch_dtype, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed) + else: + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=config) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + + +# test uniform PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_rand(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert all((x >= 0) & (x <= 1)) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + + +# test normal PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randn(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('dtype', ['int32', 'int64']) +def 
test_rand_limits(dtype, device): + + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint_to_uniform_float(x) + tl.store(output + idx, y) + + torch_dtype = getattr(torch, dtype) + min_max_int = torch.tensor([ + torch.iinfo(torch_dtype).min, + torch.iinfo(torch_dtype).max, + ], dtype=torch_dtype, device=device) + output = torch.empty(2, dtype=torch.float32, device=device) + kernel[(1, )](min_max_int, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/third_party/metax/python/test/unit/language/test_reproducer.py b/third_party/metax/python/test/unit/language/test_reproducer.py new file mode 100644 index 000000000..a045e8f30 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_reproducer.py @@ -0,0 +1,42 @@ +import os +import shutil + +import pytest + +import torch +import triton +import re + + +@triton.jit +def triton_(): + return + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") +def test_reproducer(): + tmpdir = ".tmp" + reproducer = 'triton-reproducer.mlir' + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) + os.environ["TRITON_CACHE_DIR"] = tmpdir + os.environ["TRITON_REPRODUCER_PATH"] = reproducer + triton_[(1, )]() + foundPipeline = "" + with open(reproducer, 'r') as f: + line = f.read() + if 'pipeline:' in line: + foundPipeline = line + if 0 == len(foundPipeline): + raise Exception("Failed to find pipeline info in reproducer file.") + + ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm") + if ttgir_to_llvm_pass.search(foundPipeline): + raise Exception("Failed to find triton passes in pipeline") + # cleanup + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) diff --git a/third_party/metax/python/test/unit/language/test_standard.py b/third_party/metax/python/test/unit/language/test_standard.py new file mode 100644 index 000000000..017ff36f8 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_standard.py @@ -0,0 +1,75 @@ +import triton +import pytest +import torch +import triton.language as tl + +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random + +# --------------- +# test maximum/minimum ops +# --------------- + + +# TODO: Tests with unsigned integers failed at compilation stage. 
+@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("op", ["maximum", "minimum"]) +def test_maximum_minium(dtype, op, device): + expr = f'tl.{op}(x, y)' + numpy_expr = f'np.{op}(x, y)' + _test_binary(dtype, dtype, expr, numpy_expr, device=device) + + +# --------------- +# test sort op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_sort(M, N, descending, dtype_str, device): + + @triton.jit + def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.sort(x, descending=descending)[0] + z = torch.empty_like(x) + sort_kernel[(1, )](x, z, N, M, descending, num_warps=8) + assert (y == z).all(), (y, z) + + +# --------------- +# test flip op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_flip(M, N, dtype_str, device): + + @triton.jit + def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.flip(x) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.flip(x, (1, )) + z = torch.empty_like(x, device=device) + flip_kernel[(1, )](x, z, N, M, num_warps=8) + assert (y == z).all(), (y, z) diff --git a/third_party/metax/python/test/unit/language/test_subprocess.py b/third_party/metax/python/test/unit/language/test_subprocess.py new file mode 100644 index 000000000..b309e74b4 --- /dev/null +++ b/third_party/metax/python/test/unit/language/test_subprocess.py @@ -0,0 +1,162 @@ +import itertools +import os +import subprocess +import sys +from collections import Counter + +import pytest + +dir_path = os.path.dirname(os.path.realpath(__file__)) +print_path = os.path.join(dir_path, "print_helper.py") +assert_path = os.path.join(dir_path, "assert_helper.py") + +# TODO: bfloat16 after LLVM-15 +assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"] +nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]] +torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +# TODO: Print with multiple operands + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), + ("device_print_hex", "int16"), + ("device_print_hex", "int32"), + ("device_print_hex", "int64"), + ("device_print_pointer", 
"int32"), +]) +def test_print(func_type: str, data_type: str): + proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + outs, err = proc.communicate() + assert proc.returncode == 0 + + if is_interpreter() and func_type != "static_assert": + # Interpreter uses a different format for device_print + # Only check if there's no error + assert err == b'' + return + + outs = [line for line in outs.decode("UTF-8").split("\n") if line] + # The total number of elements in the 1-D tensor to print. + N = 128 + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type == "print" or func_type == "device_print": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = 1 + elif func_type == "device_print_hex": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: 0x" + if data_type == "int16": + line += f"{i:04x}" + if data_type == "int32": + line += f"{i:08x}" + if data_type == "int64": + line += f"{i:016x}" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[f" int32[constexpr[{N}]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = N + elif func_type == "print_no_arg": + expected_lines["pid (0, 0, 0) no arg"] = N + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(N)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + elif func_type == "device_print_pointer": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1 + + actual_lines = Counter() + for line in outs: + # Trim the exact pointer address in the output--they can change per run. + line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line + actual_lines[line] += 1 + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("func_type", assert_types) +def test_assert(func_type: str): + # The total number of elements in the 1-D tensor to assert on. + N = 128 + + os.environ["TRITON_DEBUG"] = "1" + proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + shell=False) + _, errs = proc.communicate() + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + + # Check for segfaults. + assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs) + + os.environ["TRITON_DEBUG"] = "0" + if func_type == "static_assert" or func_type == "device_assert_passes": + assert num_errs == 0 + else: + assert num_errs == N - 1 + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize("caller_type, callee_type", nested_types) +def test_assert_nested(caller_type, callee_type): + # The total number of elements in the 1-D tensor to assert on. 
+ N = 128 + + proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + _, errs = proc.communicate() + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + if caller_type == "none": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 + elif caller_type == "true": + if callee_type == "false": + assert num_errs == 0 + else: + assert num_errs == N - 1 + elif caller_type == "false": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 diff --git a/third_party/metax/python/test/unit/operators/conftest.py b/third_party/metax/python/test/unit/operators/conftest.py new file mode 100644 index 000000000..091f9ea41 --- /dev/null +++ b/third_party/metax/python/test/unit/operators/conftest.py @@ -0,0 +1,5 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") diff --git a/third_party/metax/python/test/unit/operators/test_blocksparse.py b/third_party/metax/python/test/unit/operators/test_blocksparse.py new file mode 100644 index 000000000..0980ca14e --- /dev/null +++ b/third_party/metax/python/test/unit/operators/test_blocksparse.py @@ -0,0 +1,237 @@ +import pytest +import torch + +import triton +import triton.ops + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def sparsify_tensor(x, mask, block): + ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] + return ret + + +def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32): + if data is None: + data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device) + ref_ret = data + ref_ret = ref_ret * alpha + beta + ref_ret = ref_ret.half().to(dtype) + if trans: + ref_ret = ref_ret.t().requires_grad_() + ref_ret = ref_ret.detach().requires_grad_() + tri_ret = ref_ret.clone().detach().requires_grad_() + return ref_ret, tri_ret + + +def mask_tensor(x, mask, block, value=0): + ret = x.clone() + for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): + ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value + return ret + + +@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) +@pytest.mark.parametrize("TRANS_A", [False, True]) +@pytest.mark.parametrize("TRANS_B", [False, True]) +@pytest.mark.parametrize("BLOCK", [16, 32, 64]) +@pytest.mark.parametrize("DTYPE", [torch.float16]) +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device, Z=3, H=2, M=512, N=384, K=256): + seed = 0 + torch.manual_seed(seed) + is_sdd = MODE == "sdd" + is_dsd = MODE == "dsd" + is_dds = MODE == "dds" + do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK) + do_mask = lambda x: mask_tensor(x, layout, BLOCK) + # create inputs + # create op + a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) + b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) + c_shape = (Z, H, M, N) + shape = { + "sdd": (M, N), + "dsd": (a_shape[2], a_shape[3]), + "dds": (b_shape[2], b_shape[3]), + }[MODE] + layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + layout[1, 2, :] = 0 + 
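+    # together with the next line this clears a block-row and a block-column of
+    # head 1, so the randomly generated layout is guaranteed to contain empty blocks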
layout[1, :, 1] = 0 + # create data + a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE) + b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE) + dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE) + # compute [torch] + dc_ref = do_mask(dc_ref) if is_sdd else dc_ref + a_ref = do_mask(a_ref) if is_dsd else a_ref + b_ref = do_mask(b_ref) if is_dds else b_ref + a_ref.retain_grad() + b_ref.retain_grad() + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref.backward(dc_ref) + c_ref = do_sparsify(c_ref) if is_sdd else c_ref + da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad + db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad + # triton result + dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri + a_tri = do_sparsify(a_tri) if is_dsd else a_tri + b_tri = do_sparsify(b_tri) if is_dds else b_tri + a_tri.retain_grad() + b_tri.retain_grad() + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device=device) + c_tri = op(a_tri, b_tri) + c_tri.backward(dc_tri) + da_tri = a_tri.grad + db_tri = b_tri.grad + + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + tol = {'atol': 1e-3, 'rtol': 0} if is_hip_mi200() else {} + + # compare + torch.testing.assert_close(c_ref, c_tri, **tol) + torch.testing.assert_close(da_ref, da_tri, **tol) + torch.testing.assert_close(db_ref, db_tri, **tol) + + +configs = [ + (16, 256), + (32, 576), + (64, 1871), + (128, 2511), +] + + +@pytest.mark.parametrize("is_dense", [False, True]) +@pytest.mark.parametrize("BLOCK, WIDTH", configs) +def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale=0.4): + # set seed + torch.random.manual_seed(0) + Z, H, M, N = 2, 3, WIDTH, WIDTH + # initialize layout + # make sure each row has at least one non-zero element + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + if is_dense: + layout[:] = 1 + else: + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # initialize data + a_shape = (Z, H, M, N) + a_ref, a_tri = make_pair(a_shape) + dout_ref, dout_tri = make_pair(a_shape) + # compute [torch] + a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) + a_ref.retain_grad() + at_mask = torch.ones((M, N), device=device) + if is_causal: + at_mask = torch.tril(at_mask) + M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) + a_ref[M == 0] = float("-inf") + out_ref = torch.softmax(a_ref * scale, -1) + out_ref.backward(dout_ref) + out_ref = sparsify_tensor(out_ref, layout, BLOCK) + da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK) + # compute [triton] + a_tri = sparsify_tensor(a_tri, layout, BLOCK) + a_tri.retain_grad() + dout_tri = sparsify_tensor(dout_tri, layout, BLOCK) + op = triton.ops.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense) + out_tri = op(a_tri, scale=scale, is_causal=is_causal) + out_tri.backward(dout_tri) + da_tri = a_tri.grad + # compare + torch.testing.assert_close(out_tri, out_ref, equal_nan=True) + torch.testing.assert_close(da_tri, da_ref, equal_nan=True) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_attention_fwd_bwd( + block, + dtype, + device, + input_scale=1.0, + scale=1 / 8.0, + n_ctx=256, + 
batch_size=2, + n_heads=2, +): + capability = torch.cuda.get_device_capability() + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + + # inputs + qkv_shape = (batch_size, n_heads, n_ctx, 64) + qkvs = [ + torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) + ] + + # Triton: + n_blocks = n_ctx // block + layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) + query, key, value = [x.clone() for x in qkvs] + query.retain_grad() + key.retain_grad() + value.retain_grad() + attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) + # ad hoc loss + loss = (attn_out**2).mean() + loss.backward() + grads = [query.grad, key.grad, value.grad] + + # Torch version: + torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = torch.ones([n_ctx, n_ctx], device=device, dtype=dtype) + attn_mask = torch.tril(attn_mask, diagonal=0) + attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) + torch_q.retain_grad() + torch_k.retain_grad() + torch_v.retain_grad() + scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) + # ad hoc loss + torch_loss = (torch_attn_out**2).mean() + torch_loss.backward() + torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] + + # comparison + # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") + torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0) + + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. 
Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + tol = {'atol': 1e-3, 'rtol': 0} if is_hip_mi200() else {} + for g1, g2 in zip(grads, torch_grads): + torch.testing.assert_close(g1, g2, **tol) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +def triton_attention( + layout, + block: int, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +): + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, + device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, + device=value.device) + sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) + + w = sparse_dot_sdd_nt(query, key) + w = sparse_softmax(w, scale=scale, is_causal=True) + a = sparse_dot_dsd_nn(w, value) + return a diff --git a/third_party/metax/python/test/unit/operators/test_cross_entropy.py b/third_party/metax/python/test/unit/operators/test_cross_entropy.py new file mode 100644 index 000000000..7033549ff --- /dev/null +++ b/third_party/metax/python/test/unit/operators/test_cross_entropy.py @@ -0,0 +1,41 @@ +import pytest +import torch + +import triton +import triton.ops + + +@pytest.mark.parametrize("M, N, dtype, mode", [ # + (M, N, dtype, mode) + for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] +]) +def test_op(M, N, dtype, mode, device): + capability = torch.cuda.get_device_capability() + if capability[0] < 8 and dtype == "bfloat16": + pytest.skip("Only test bfloat16 on devices with sm >= 80") + dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype] + # create inputs + x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True) + idx = 4 + torch.ones(M, dtype=torch.int64, device=device) + # forward pass + tt_y = triton.ops.cross_entropy(x, idx) + th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) + if mode == 'forward': + torch.testing.assert_close(th_y, tt_y) + # backward pass + elif mode == 'backward': + dy = torch.randn_like(tt_y) + # triton backward + tt_y.backward(dy) + tt_dx = x.grad.clone() + # torch backward + x.grad = None + th_y.backward(dy) + th_dx = x.grad.clone() + if dtype == torch.float16: + torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) + else: + torch.testing.assert_close(th_dx, tt_dx) diff --git a/third_party/metax/python/test/unit/operators/test_flash_attention.py b/third_party/metax/python/test/unit/operators/test_flash_attention.py new file mode 100644 index 000000000..f7c9081d6 --- /dev/null +++ b/third_party/metax/python/test/unit/operators/test_flash_attention.py @@ -0,0 +1,118 @@ +import pytest +import torch +import os + +import triton +import triton.ops + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.interpreter +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ # + (2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128), +]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('seq_par', [True, False]) +def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Flash attention only supported for compute capability >= 80") + if 
dtype == torch.bfloat16 and os.environ.get("TRITON_INTERPRET", "0") == "1": + pytest.skip("Flash attention bfloat16 not supported in interpreter mode") + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device=device)) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).to(dtype) + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # # triton implementation + tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_out), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_out), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dv), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dv), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dk), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dk), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dq), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dq), dim=0), atol=atol, rtol=0) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], x_vals=[2**i for i in range(10, 14)], line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}', args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'casual': casual, + 'seq_par': seq_par, + }) for mode in ['fwd', 'bwd'] for casual in [True, False] for seq_par in [True, False] +] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + sm_scale = 1.3 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if provider == "triton": + fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, 
retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=casual) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/third_party/metax/python/test/unit/operators/test_inductor.py b/third_party/metax/python/test/unit/operators/test_inductor.py new file mode 100644 index 000000000..a638cb633 --- /dev/null +++ b/third_party/metax/python/test/unit/operators/test_inductor.py @@ -0,0 +1,198 @@ +import pytest +import torch + +import triton +import triton.language as tl + + +def test_normalization_with_remat(device): + + @triton.jit + def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr): + xnumel = 512 + rnumel = 4096 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x3 = xindex + x0 = xindex % 64 + tmp1 = tl.load(in_ptr0 + (x0), xmask) + tmp3 = tl.load(in_ptr1 + (x0), xmask) + tmp11 = tl.load(in_ptr2 + (x0), xmask) + tmp13 = tl.load(in_ptr3 + (x0), xmask) + _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r2 = rindex + tmp0 = tl.load(in_out_ptr0 + (r2 + (4096 * x3)), rmask & xmask, eviction_policy='evict_last', other=0) + tmp2 = tmp0 - tmp1 + tmp4 = 1e-05 + tmp5 = tmp3 + tmp4 + tmp6 = tl.sqrt(tmp5) + tmp7 = 1 / tmp6 + tmp8 = 1.0 + tmp9 = tmp7 * tmp8 + tmp10 = tmp2 * tmp9 + tmp12 = tmp10 * tmp11 + tmp14 = tmp12 + tmp13 + _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17) + tl.store(in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp14, rmask & xmask) + tmp17 = tl.sum(_tmp17, 1)[:, None] + tmp18 = 4096.0 + tmp19 = tmp17 / tmp18 + tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask) + + torch.manual_seed(123) + + buf14 = torch.rand(8, 64, 64, 64, device=device) + buf16 = torch.rand(8, 1, 64, device=device) + arg114_1 = torch.rand(64, device=device) + arg115_1 = torch.rand(64, device=device) + arg8_1 = torch.rand(64, device=device) + arg9_1 = torch.rand(64, device=device) + triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) + torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) + + +def test_avg_pool_bw(device): + + @triton.jit + def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + x1 = (xindex // 8) % 8 + x0 = xindex % 8 + x2 = (xindex // 64) + x5 = xindex + tmp0 = (-1) + x1 + tmp1 = (-1) + x0 + tmp2 = 2 + x1 + tmp3 = 2 + x0 + tmp4 = 0 + tmp5 = tl.where(tmp0 != tmp0, tmp0, tl.where(tmp0 > tmp4, tmp0, tmp4)) + tmp6 = tl.where(tmp1 != tmp1, tmp1, tl.where(tmp1 > tmp4, tmp1, tmp4)) + tmp7 = 8 + tmp8 = tl.where(tmp2 != tmp2, tmp2, tl.where(tmp2 < tmp7, tmp2, tmp7)) + tmp9 = tl.where(tmp3 != tmp3, tmp3, tl.where(tmp3 < tmp7, tmp3, tmp7)) + tmp10 = 
tmp5 + tmp4 + tmp11 = tmp6 + tmp4 + tmp12 = 1 + tmp13 = tmp8 - tmp12 + tmp14 = tl.where(tmp10 != tmp10, tmp10, tl.where(tmp10 < tmp13, tmp10, tmp13)) + tmp15 = tmp9 - tmp12 + tmp16 = tl.where(tmp11 != tmp11, tmp11, tl.where(tmp11 < tmp15, tmp11, tmp15)) + tmp17 = tl.load(in_ptr0 + (tmp16 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp18 = tmp17 / 9 + tmp19 = tmp10 < tmp8 + tmp20 = tmp11 < tmp9 + tmp21 = tmp19 & tmp20 + tmp22 = 0.0 + tmp23 = tl.where(tmp21, tmp18, tmp22) + tmp24 = tmp6 + tmp12 + tmp25 = tl.where(tmp24 != tmp24, tmp24, tl.where(tmp24 < tmp15, tmp24, tmp15)) + tmp26 = tl.load(in_ptr0 + (tmp25 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp27 = tmp26 / 9 + tmp28 = tmp24 < tmp9 + tmp29 = tmp19 & tmp28 + tmp30 = tmp23 + tmp27 + tmp31 = tl.where(tmp29, tmp30, tmp23) + tmp32 = 2 + tmp33 = tmp6 + tmp32 + tmp34 = tl.where(tmp33 != tmp33, tmp33, tl.where(tmp33 < tmp15, tmp33, tmp15)) + tmp35 = tl.load(in_ptr0 + (tmp34 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp36 = tmp35 / 9 + tmp37 = tmp33 < tmp9 + tmp38 = tmp19 & tmp37 + tmp39 = tmp31 + tmp36 + tmp40 = tl.where(tmp38, tmp39, tmp31) + tmp41 = tmp5 + tmp12 + tmp42 = tl.where(tmp41 != tmp41, tmp41, tl.where(tmp41 < tmp13, tmp41, tmp13)) + tmp43 = tl.load(in_ptr0 + (tmp16 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp44 = tmp43 / 9 + tmp45 = tmp41 < tmp8 + tmp46 = tmp45 & tmp20 + tmp47 = tmp40 + tmp44 + tmp48 = tl.where(tmp46, tmp47, tmp40) + tmp49 = tl.load(in_ptr0 + (tmp25 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp50 = tmp49 / 9 + tmp51 = tmp45 & tmp28 + tmp52 = tmp48 + tmp50 + tmp53 = tl.where(tmp51, tmp52, tmp48) + tmp54 = tl.load(in_ptr0 + (tmp34 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp55 = tmp54 / 9 + tmp56 = tmp45 & tmp37 + tmp57 = tmp53 + tmp55 + tmp58 = tl.where(tmp56, tmp57, tmp53) + tmp59 = tmp5 + tmp32 + tmp60 = tl.where(tmp59 != tmp59, tmp59, tl.where(tmp59 < tmp13, tmp59, tmp13)) + tmp61 = tl.load(in_ptr0 + (tmp16 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp62 = tmp61 / 9 + tmp63 = tmp59 < tmp8 + tmp64 = tmp63 & tmp20 + tmp65 = tmp58 + tmp62 + tmp66 = tl.where(tmp64, tmp65, tmp58) + tmp67 = tl.load(in_ptr0 + (tmp25 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp68 = tmp67 / 9 + tmp69 = tmp63 & tmp28 + tmp70 = tmp66 + tmp68 + tmp71 = tl.where(tmp69, tmp70, tmp66) + tmp72 = tl.load(in_ptr0 + (tmp34 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp73 = tmp72 / 9 + tmp74 = tmp63 & tmp37 + tmp75 = tmp71 + tmp73 + tmp76 = tl.where(tmp74, tmp75, tmp71) + tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None) + + inp = torch.ones(8, 2048, 8, 8, device=device, dtype=torch.half) + out = torch.ones_like(inp) * 3 + numel = inp.numel() + triton_[(numel // 1024, )](inp, out, 1024) + out_ref = torch.ones_like(inp) + out_ref[:, :, 1:7, 0::7] = 2 / 3 + out_ref[:, :, 0::7, 1:7] = 2 / 3 + out_ref[:, :, 0::7, 0::7] = 4 / 9 + torch.testing.assert_close(out, out_ref) + + +@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) +@pytest.mark.parametrize("num_warps", [1, 4]) +def test_scan2d_broadcast(RBLOCK, num_warps, device): + + @triton.jit(debug=True) + def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + rindex = tl.arange(0, RBLOCK)[None, :] + xindex = tl.arange(0, XBLOCK)[:, None] + data = tl.load(in_ptr + rindex) + scan = tl.cumsum(data, 1) + expected_max = tl.sum(data, 1) + tl.device_assert(scan <= expected_max) + tl.store(out_ptr + xindex * RBLOCK + rindex, scan) + + XBLOCK = 4 + input = torch.randint(0, 10, (1, RBLOCK), 
dtype=torch.int64, device=device) + output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device=device) + fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps) + ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) + torch.testing.assert_close(output, ref) + + +def test_scan2d_for(device): + + @triton.jit + def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): + rbase = tl.arange(0, RBLOCK)[None, :] + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + tmp3 = tl.where(rmask, 1, 0) + tmp6 = tl.cumsum(tmp3, 1) + tl.store(out_ptr0 + rindex, tmp6, rmask) + + RBLOCK = 8 + out0 = torch.empty(RBLOCK, device=device, dtype=torch.int64) + fn[(1, )](out0, RBLOCK, RBLOCK) + ref = torch.arange(RBLOCK, device=device, dtype=torch.int64) + 1 + torch.testing.assert_close(out0, ref) diff --git a/third_party/metax/python/test/unit/operators/test_matmul.py b/third_party/metax/python/test/unit/operators/test_matmul.py new file mode 100644 index 000000000..315a5411d --- /dev/null +++ b/third_party/metax/python/test/unit/operators/test_matmul.py @@ -0,0 +1,202 @@ +import itertools + +import pytest +import torch + +import triton +import triton.language as tl +import triton.ops + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@pytest.mark.skip(reason="metax todo") +@pytest.mark.parametrize( + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE", + itertools.chain( + *[[ + # 1 warp + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 2 warp + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 4 warp + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 8 warp + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, 
True, None, None), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # variable input + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]], + # n-stage + *[[ + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + ] + for DTYPE in ["float16", "bfloat16", "float32"] + for AT in [False, True] + for BT in [False, True] + for STAGES in [4]], + # tf32x3 + *[[ + (16, 16, 16, 1, 1, 2, 32, 32, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (64, 32, 64, 1, 2, 2, 128, 64, 128, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (128, 64, 16, 1, 4, 2, 256, 128, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (256, 128, 32, 1, 8, 2, 512, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (128, 128, 32, 1, 4, 2, 256, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None), + ] for AT in [False, True] for BT in [False, True]], + # mixed-precision + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float8e5"), + ("float8e4nv", "float8e4nv"), + ("float8e5", "float8e4nv"), + ("float8e5", "float8e5"), + ("float8e4b15", "float8e4b15"), + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("int8", "bfloat16"), + ("float16", "int8"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]], + # mixed-precision block layout + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True]], + # acc-out-dtype and output_dtype + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE, + OUTPUT_DTYPE), + (128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE, + OUTPUT_DTYPE), + ] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]], 
+ ), +) +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, + F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE): + capability = torch.cuda.get_device_capability() + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"): + pytest.skip("Only test bfloat16 on devices with sm >= 80") + if capability[0] < 9 and (ADTYPE == "float8e4nv" or BDTYPE == "float8e4nv"): + pytest.skip("Only test float8e4nv on devices with sm >= 90") + if (ADTYPE == "bfloat16" or BDTYPE == "bfloat16") and SPLIT_K != 1: + pytest.skip("bfloat16 matmuls don't allow split_k for now") + torch.manual_seed(0) + # nuke kernel decorators -- will set meta-parameters manually + kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} + pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() + configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] + kernel = triton.ops._matmul.kernel + kernel.configs = configs + # kernel.run = kernel.run.run.run + + # get matrix shape + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K * SPLIT_K if K is None else K + + def is_fp8(dtype): + return "float8" in dtype + + def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty_strided(x.shape, x.stride(), dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + def upcast_if_fp8(x, dtype): + if is_fp8(dtype): + return f8_to_f16(x, dtype) + return x + + def init_input(m, n, dtype, acc_dtype): + if 'float8' in dtype: + ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype] + sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128 + val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth + return sign | val + if dtype == "int8": + return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) + # Use small range of values to prevent numerical issues. 
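+        # The entries are powers of two (2**e with e drawn from [min_exp, 0)), so
+        # individual products in the matmul are exact; the narrower exponent range
+        # used when accumulating in float16 keeps the partial sums within fp16
+        # precision, which lets torch.testing.assert_close pass with default
+        # tolerances.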
+ min_exp = -4 if acc_dtype == "float16" else -10 + exponents = torch.randint(min_exp, 0, size=(m, n)) + ret = (2.**exponents).to(getattr(torch, dtype)).to("cuda") + return ret + + if is_hip(): + if INPUT_PRECISION == 'tf32x3' or is_fp8(ADTYPE) or is_fp8(BDTYPE): + pytest.skip("fp8 inputs or tf32x3 precison does not have native support on hip") + # allocate/transpose inputs + a = init_input(M, K, ADTYPE, ACC_DTYPE) + b = init_input(K, N, BDTYPE, ACC_DTYPE) + a = a if not AT else a.T.contiguous().T + b = b if not BT else b.T.contiguous().T + # run test + th_a = upcast_if_fp8(a, ADTYPE) + th_b = upcast_if_fp8(b, BDTYPE) + ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) + acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype + output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype + th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype)) + try: + if is_fp8(ADTYPE): + a = triton.reinterpret(a, getattr(tl, ADTYPE)) + if is_fp8(BDTYPE): + b = triton.reinterpret(b, getattr(tl, BDTYPE)) + tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, INPUT_PRECISION, F8_FASTACCUM, output_dtype) + torch.testing.assert_close(th_c, tt_c) + except triton.OutOfResources as e: + pytest.skip(str(e)) diff --git a/third_party/metax/python/test/unit/runtime/test_autotuner.py b/third_party/metax/python/test/unit/runtime/test_autotuner.py new file mode 100644 index 000000000..6bbff1227 --- /dev/null +++ b/third_party/metax/python/test/unit/runtime/test_autotuner.py @@ -0,0 +1,132 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +@pytest.mark.parametrize('use_cuda_graph', [False, True]) +def test_kwargs(use_cuda_graph: bool): + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N) + _kernel[grid](dst=dst, src=src, N=N) + + +def test_restore(): + N = 1024 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) + + +def test_hooks(): + # Autotuner's pre- and post- hooks should be called the same number of times + N = 4096 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 4096}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + + values = {"counter": 0, "has_exception": False} + + def _pre_hook(*args, **kwargs): + values["counter"] += 1 + + def _post_hook(*args, exception): + values["counter"] -= 1 + if exception is not None: + values["has_exception"] = True + assert values["counter"] == 0 + + 
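+    # The hooks are invoked in pairs around every autotuning run, so "counter" must
+    # be back at zero inside each _post_hook call; "has_exception" records whether
+    # any candidate config raised.  The N_STAGES=100 heuristic below is meant to
+    # provoke such a failure on CUDA by overflowing shared memory.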
@triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook) + @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4}) + @triton.jit + def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + max_iters = tl.cdiv(N, BLOCK_SIZE) + for _ in tl.range(max_iters, num_stages=N_STAGES): + x = tl.load(src + offsets, mask=offsets < N) + tl.store(src + offsets, x, mask=offsets < N) + offsets += BLOCK_SIZE + + _kernel[(1, )](src, N) + + # On NVIDIA GPUs: + # The tunning knob `num_stages` can be set by users. + # This will cause out of resources when N_STAGES = 100 + # shared memory bytes = N_STAGES * BLOCK_SIZE * sizeof(float) + # On AMD GPUs: + # `num_stages` is a fixed value of 2, so it won't cause out of resources + if triton.runtime.driver.active.get_current_target().backend == "cuda": + assert values["has_exception"] is True + else: + assert values["has_exception"] is False + + +@pytest.mark.parametrize('with_perf_model', [False, True]) +def test_prune_configs(with_perf_model: bool): + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + records = {} + + def early_config_prune(configs, named_args, **kwargs): + records['run_early_config_prune'] = True + if "N" in kwargs and kwargs["N"] == 1024: + records['capture_kwargs'] = True + if "dst" in named_args and "src" in named_args and len(named_args) == 2: + records['capture_named_args'] = True + return [configs[0]] + + def perf_model(*args, **kwargs): + records['run_perf_model'] = True + return kwargs['BLOCK_SIZE'] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + if with_perf_model: + prune_configs_by = {'perf_model': perf_model, 'top_k': 1} + else: + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + torch.testing.assert_close(src, dst) + if with_perf_model: + assert len(records) == 1 + assert records['run_perf_model'] + else: + assert len(records) == 3 + assert records['run_early_config_prune'] + assert records['capture_kwargs'] + assert records['capture_named_args'] diff --git a/third_party/metax/python/test/unit/runtime/test_bindings.py b/third_party/metax/python/test/unit/runtime/test_bindings.py new file mode 100644 index 000000000..c48ba9b4a --- /dev/null +++ b/third_party/metax/python/test/unit/runtime/test_bindings.py @@ -0,0 +1,81 @@ +import triton +import triton.language as tl + +import torch + + +@triton.jit +def add_helper(x, y): + return x + y + + +@triton.jit +def add_kernel( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = add_helper(x, y) + tl.store(out_ptr + offsets, output, mask=mask) + + +def test_module_walk(): + """ + Test the MLIR bindings exposed for the out-ot-tree walk. 
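+
+    walk_fn below visits every operation in the TTIR module and exercises the
+    result / operand / region / block accessors as well as the symbol attributes
+    of tt.func and tt.call.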
+ """ + + def walk_fn(op): + name = op.get_name() + for i in range(op.get_num_results()): + op.get_result(i).id() + for i in range(op.get_num_operands()): + op.get_operand(i).id() + for i in range(op.get_num_regions()): + op.get_region(i).id() + block = op.get_block() + if block is not None: + block.id() + for i in range(block.get_num_arguments()): + block.get_argument(i) + if name == "tt.func": + op.get_str_attr("sym_name") + if name == "tt.call": + op.get_flat_symbol_ref_attr("callee") + + kernel = add_kernel + args = [ + torch.empty((32, 32), device="cuda"), # in_ptr0 + torch.empty((32, 32), device="cuda"), # in_ptr1 + 1024, # n_elements + torch.empty((32, 32), device="cuda"), # out_ptr + 16, # BLOCK_SIZE + ] + src = triton.compiler.compiler.ASTSource( + fn=kernel, + signature={i: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs}, + constants={i: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=kernel._get_config(*args, ), + ) + + context = triton._C.libtriton.ir.context() + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options(dict()) + codegen_fns = dict() + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module.walk(walk_fn) diff --git a/third_party/metax/python/test/unit/runtime/test_cache.py b/third_party/metax/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..eddfe06e0 --- /dev/null +++ b/third_party/metax/python/test/unit/runtime/test_cache.py @@ -0,0 +1,534 @@ +import importlib.util +import itertools +import os +import shutil +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl +from triton.runtime.jit import JITFunction + +tmpdir = ".tmp" + + +@triton.jit +def function_1(i): + i = i + 1 + i = function_2(i) + return i + + +@triton.jit +def function_2(i): + i = i + 1 + return i + + +@triton.jit +def combine_fn(a, b): + return COMBINE_OP # noqa: F821 + + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit +def kernel_with_combine_fn(X, BLOCK: tl.constexpr): + i = tl.arange(0, BLOCK) + i = REDUCE_OR_SCAN(i, 0, combine_fn) # noqa: F821 + tl.store(X, i) + + +def apply_src_change(target, old, new): + kernel.hash = None + function_1.hash = None + function_2.hash = None + function_1.src = function_1.src.replace(old, new) + target.src = target.src.replace(old, new) + ret = target.cache_key + target.src = target.src.replace(new, old) + return ret + + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1') + assert baseline == updated + + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(function_1, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_combine_fn_change(): + # Test that tl.reduce and associative_scan calls include + # the combine_fn in the hash + + orig_combine_fn_src = combine_fn.src + orig_kernel_src = kernel_with_combine_fn.src + seen_keys = set() + + for reduce_or_scan, combine_op in 
itertools.product( + ["tl.reduce", "tl.associative_scan"], + ["a + b", "a * b"], + ): + combine_fn.src = orig_combine_fn_src.replace("COMBINE_OP", combine_op) + kernel_with_combine_fn.src = orig_kernel_src.replace("REDUCE_OR_SCAN", reduce_or_scan) + try: + key = kernel_with_combine_fn.cache_key + finally: + combine_fn.src = orig_combine_fn_src + kernel_with_combine_fn.src = orig_kernel_src + + kernel_with_combine_fn.hash = None + combine_fn.hash = None + + assert key not in seen_keys + seen_keys.add(key) + + +def write_and_load_module(code, num_extra_lines): + with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: + f.write(('# extra line\n' * num_extra_lines) + code) + f.flush() + spec = importlib.util.spec_from_file_location("module.name", f.name) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_changed_line_numbers_invalidate_cache(): + from textwrap import dedent + code = dedent(""" + import triton + @triton.jit + def test_kernel(i): + i = i + 1 + """) + orig_mod = write_and_load_module(code, 0) + orig_cache_key = orig_mod.test_kernel.cache_key + + updated_mod = write_and_load_module(code, 1) + updated_cache_key = updated_mod.test_kernel.cache_key + assert orig_cache_key != updated_cache_key + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + # https://stackoverflow.com/questions/303200/how-do-i-remove-delete-a-folder-that-is-not-empty + shutil.rmtree(tmpdir, ignore_errors=True) + + +def test_reuse(): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device='cuda') + for i in range(10): + kernel[(1, )](x, 1, BLOCK=1024) + assert counter == 1 + + +@pytest.mark.parametrize('mode', ['enable', 'disable']) +def test_specialize(mode): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device='cuda') + function = {'enable': kernel, 'disable': kernel_nospec}[mode] + target = {'enable': 3, 'disable': 1}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1, )](x, i, BLOCK=512) + assert counter == target + + +def test_annotation(): + + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + + device = torch.cuda.current_device() + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) + assert len(kernel.cache[device]) == 3 + + +GLOBAL_DEFAULT_ARG = 1 + + +def test_kernel_default_arg(): + global GLOBAL_DEFAULT_ARG + + @triton.jit + def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + kernel[(1, )](x) + assert x == torch.ones_like(x) + + # Changing the global variable should not change the default argument in + # `kernel`. That value gets set at the time the function is declared. 
+ GLOBAL_DEFAULT_ARG = 2 + kernel[(1, )](x) + assert x == torch.ones_like(x) + + device = torch.cuda.current_device() + assert len(kernel.cache[device]) == 1 + + +GLOBAL_VAR: tl.constexpr = 1 + + +def test_kernel_global_var_change(): + global GLOBAL_VAR + + @triton.jit + def kernel(X): + tl.store(X, GLOBAL_VAR) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + kernel[(1, )](x) + assert x == torch.ones_like(x) + + GLOBAL_VAR = 2 + with pytest.raises(RuntimeError) as e: + kernel[(1, )](x) + + assert "global variable" in str(e.value).lower() + + +GLOBAL = 42 # noqa + + +def test_local_shadows_global(): + global GLOBAL + + @triton.jit + def kernel(): + _, GLOBAL = 0, 0 # noqa + a = GLOBAL # noqa + + # No error because the `GLOBAL` we're modifying is not the same `GLOBAL` as + # inside the kernel. + GLOBAL = 42 + kernel[(1, )]() + GLOBAL = 43 + kernel[(1, )]() + + +CONSTEXPR_GLOBAL: tl.constexpr = 42 + + +def test_local_does_not_shadow_global(): + global CONSTEXPR_GLOBAL + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + _, CONSTEXPR_GLOBAL = 0, 0 # noqa + + CONSTEXPR_GLOBAL = 42 + kernel[(1, )]() + CONSTEXPR_GLOBAL = 43 + + # Error because the `CONSTEXPR_GLOBAL` we're modifying is the same + # `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could + # make this kernel an error altogether, as it is if it's a pure Python + # function -- the fact that we store to `CONSTEXPR_GLOBAL` inside the kernel + # makes the first read a read of the local variable, which doesn't exist + # yet.) + with pytest.raises(RuntimeError): + kernel[(1, )]() + + +CONFLICTING_GLOBAL: tl.constexpr = 0 + + +@triton.jit +def conflicting_global_inner(): + a = CONFLICTING_GLOBAL # noqa + + +def test_conflicting_global_in_inner_function(): + global CONFLICTING_GLOBAL + + @triton.jit + def kernel1(): + a = CONFLICTING_GLOBAL # noqa + conflicting_global_inner() + + @triton.jit + def kernel2(): + a = CONFLICTING_GLOBAL #noqa + conflicting_global_inner() + + kernel1[(1, )]() + + # This should be an error because kernel2 calls conflicting_global_inner, + # which saw a value for 42 for the global when it was first compiled. + CONFLICTING_GLOBAL = 1 + + with pytest.raises(RuntimeError) as e: + kernel2[(1, )]() + + assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value) + + +def test_use_builtin(): + + @triton.jit + def kernel(): + a = float(0) # noqa + + # No error about the value of `float` changing. + kernel[(1, )]() + kernel[(1, )]() + + +def test_no_cache_module_as_global(): + + @triton.jit + def kernel(): + tl.arange(0, 16) + + kernel[(1, )]() + # `tl` should not be entered into used_global_vals + assert not kernel.used_global_vals + + +BUILTIN_AS_GLOBAL = tl.int32 + + +def test_cache_builtin_as_global(): + global BUILTIN_AS_GLOBAL + + @triton.jit + def kernel(): + x = BUILTIN_AS_GLOBAL # noqa + + kernel[(1, )]() + + BUILTIN_AS_GLOBAL = tl.int64 + with pytest.raises(RuntimeError) as e: + kernel[(1, )]() + + assert "global variable" in str(e.value).lower() + + +@triton.jit +def no_cache_callable_inner(): + pass + + +def test_no_cache_callable(): + + @triton.jit + def kernel(): + no_cache_callable_inner() + + kernel[(1, )]() + # `no_cache_callable_inner` should not be entered into used_global_vals. 
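+    # (calls to other @triton.jit functions are accounted for through the caller's
+    # hash rather than as captured global values, so the set stays empty)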
+ assert not kernel.used_global_vals + + +def test_constexpr_not_callable() -> None: + + @triton.jit + def kernel(X, c: tl.constexpr): + tl.store(X, 2) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + error = False + try: + kernel[(1, )](x, c="str") + except BaseException: + error = True + assert error is False + # try and catch + try: + kernel[(1, )](x, c=tl.abs) + except BaseException: + error = True + assert error is True + + +def test_jit_warmup_cache() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + 32, + ] + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + + +def test_jit_debug() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.debug = False + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 2 + kernel_add.debug = True + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 3 + bins = list(kernel_add.cache[device].values()) + assert bins[2].asm['ttir'] != bins[1].asm['ttir'] + + +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +def test_jit_noinline() -> None: + + @triton.jit + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) + + device = torch.cuda.current_device() + assert len(kernel_add_device.cache[device]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.cache[device].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + +def test_preload() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, 
tl.load(a + idx) + tl.load(b + idx)) + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) + + device = torch.cuda.current_device() + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + pre_compile = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + hash = pre_compile.hash + assert specialization_data is not None + + # clear the cache + reset_tmp_dir() + kernel_add.cache[device].clear() + + # preload the kernel + kernel_preload = kernel_add.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel_add.cache[device]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + JITFunction.cache_hook = None + assert counter == 0 + assert len(kernel_add.cache[device]) == 1 + assert final_kernel.hash == hash + + # test that we can't preload a mismatched kernel + with pytest.raises(RuntimeError, match="Specialization data is for"): + kernel_sub.preload(specialization_data) diff --git a/third_party/metax/python/test/unit/runtime/test_driver.py b/third_party/metax/python/test/unit/runtime/test_driver.py new file mode 100644 index 000000000..de00082f5 --- /dev/null +++ b/third_party/metax/python/test/unit/runtime/test_driver.py @@ -0,0 +1,14 @@ +import sys + +import triton + + +def test_is_lazy(): + from importlib import reload + reload(sys.modules["triton.runtime.driver"]) + reload(sys.modules["triton.runtime"]) + mod = sys.modules[triton.runtime.driver.__module__] + assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy")) + assert triton.runtime.driver.active._obj is None + utils = triton.runtime.driver.active.utils # noqa: F841 + assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) diff --git a/third_party/metax/python/test/unit/runtime/test_jit.py b/third_party/metax/python/test/unit/runtime/test_jit.py new file mode 100644 index 000000000..5892494c4 --- /dev/null +++ b/third_party/metax/python/test/unit/runtime/test_jit.py @@ -0,0 +1,42 @@ +import itertools +import pytest +import torch + +import triton +import triton.language as tl + + +def test_pre_call_hooks(device): + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + class MyTensor(torch.Tensor): + pass + + def my_hook(*args, **kwargs): + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, MyTensor): + raise Exception("MyTensor is not allowed") + + add_kernel.add_pre_run_hook(my_hook) + + x = torch.randn(4, device=device) + y = MyTensor(x) + out = torch.zeros_like(x) + with pytest.raises(Exception): + add_kernel[(4, )](x, y, out, 4, 4) diff --git 
a/third_party/metax/python/test/unit/runtime/test_launch.py b/third_party/metax/python/test/unit/runtime/test_launch.py new file mode 100644 index 000000000..f17c05674 --- /dev/null +++ b/third_party/metax/python/test/unit/runtime/test_launch.py @@ -0,0 +1,134 @@ +import gc +# import importlib +# import os +# import sys +# import tempfile +# import textwrap +# import time +import tracemalloc + +import torch + +import triton +import triton.language as tl + +# from typing import Tuple + + +def test_metadata() -> None: + + used_hook = False + + def _launch_metadata(grid, kernel, args): + ret = dict() + ret["grid"] = grid + ret["value"] = args["x"] + return ret + + def hook(launch_metadata): + nonlocal used_hook + metadata = launch_metadata.get() + assert metadata["grid"] == (1, 3, 2) + assert metadata["value"] == 6 + used_hook = True + + @triton.jit(launch_metadata=_launch_metadata) + def kernel(x): + pass + + # launch kernel + triton.compiler.CompiledKernel.launch_enter_hook = hook + kernel[(1, 3, 2)](6) + triton.compiler.CompiledKernel.launch_enter_hook = None + assert used_hook + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + tracemalloc.start() + try: + inp = torch.randn(10, device='cuda') + out = torch.randn(10, device='cuda') + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + begin, _ = tracemalloc.get_traced_memory() + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + end, _ = tracemalloc.get_traced_memory() + assert end - begin < 30000 + finally: + tracemalloc.stop() + + +# LATENCY_THRESHOLD_US = 46 + +# def test_kernel_launch_latency() -> None: +# def define_kernel(kernel_name: str, num_tensor_args: int) -> str: +# arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)]) +# arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr" +# func_str = f""" +# import torch + +# import triton +# import triton.language as tl + +# @triton.jit +# def {kernel_name}({arg_str}): +# pass +# """ +# with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file: +# temp_file.write(textwrap.dedent(func_str)) +# temp_file_path = temp_file.name + +# return temp_file_path + +# def import_kernel(file_path, kernel_name): +# directory, filename = os.path.split(file_path) +# module_name, _ = os.path.splitext(filename) +# sys.path.insert(0, directory) + +# module = importlib.import_module(module_name) +# kernel = getattr(module, kernel_name) +# return kernel + +# def empty(*kernel_args: Tuple[torch.Tensor]): +# first_arg = kernel_args[0] +# n_elements = first_arg.numel() +# grid = (triton.cdiv(n_elements, 1024),) +# device = torch.cuda.current_device() +# # Warmup +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# torch.cuda.synchronize() +# # Measure launch overhead at steady state +# num_runs = 1000 +# start_time = time.time() +# for i in range(num_runs): +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# end_time = time.time() +# latency_us = (end_time - start_time) / num_runs * 1e6 + +# assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!" 
+ +# num_tensor_args = 40 +# kernel_name = 'empty_kernel' +# file_path = define_kernel(kernel_name, num_tensor_args) +# empty_kernel = import_kernel(file_path, kernel_name) + +# # Initialize random tensors for the empty_kernel +# torch.manual_seed(0) +# size = 1024 +# kernel_args = (torch.rand(size, device='cuda') for i in range(num_tensor_args)) + +# # Run empty, which would run empty_kernel internally +# empty(*kernel_args) diff --git a/third_party/metax/python/test/unit/runtime/test_subproc.py b/third_party/metax/python/test/unit/runtime/test_subproc.py new file mode 100644 index 000000000..333d1f929 --- /dev/null +++ b/third_party/metax/python/test/unit/runtime/test_subproc.py @@ -0,0 +1,73 @@ +import multiprocessing +import os +import shutil + +import torch + +import triton +import triton.language as tl +from triton.compiler import ASTSource + +tmpdir = ".tmp" + +target = triton.runtime.driver.active.get_current_target() + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + + +def compile_fn(attrs, capability): + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + + src = ASTSource( + fn=kernel_sub, + constants={3: 32}, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + attrs=attrs, + ) + triton.compile(src=src, target=target) + + +def test_compile_in_subproc() -> None: + major, minor = torch.cuda.get_device_capability(0) + cc = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) + + multiprocessing.set_start_method('fork') + proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_fn_dot(attrs, capability): + + @triton.jit + def kernel_dot(Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + tl.store(Z + offs, z) + + src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + triton.compile(src=src, target=target) + + +def test_compile_in_forked_subproc() -> None: + reset_tmp_dir() + major, minor = torch.cuda.get_device_capability(0) + capability = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) + + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) + proc.start() + proc.join() + assert proc.exitcode == 0 From 04879fca21e2bf0b3101fc4ffd48aca79abd2cf3 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 30 May 2025 12:34:35 +0000 Subject: [PATCH 6/6] [BACKEND] [TEST] Add metax python unit test --- .github/workflows/metax-build-and-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/metax-build-and-test.yml b/.github/workflows/metax-build-and-test.yml index c57001eab..f5ed4ff53 100644 --- a/.github/workflows/metax-build-and-test.yml +++ b/.github/workflows/metax-build-and-test.yml @@ -52,7 +52,7 @@ jobs: source ~/env.sh export FLAGTREE_BACKEND=metax cd python - MAX_JOBS=32 pip3.10 install . --no-build-isolation + MAX_JOBS=32 python3.10 -m pip install . --no-build-isolation - name: FlagTree Test on Metax shell: bash