diff --git a/.github/workflows/tsingmicro-build-and-test.yml b/.github/workflows/tsingmicro-build-and-test.yml index 171a93e6f..935d4ac1b 100644 --- a/.github/workflows/tsingmicro-build-and-test.yml +++ b/.github/workflows/tsingmicro-build-and-test.yml @@ -48,6 +48,7 @@ jobs: - name: FlagTree Build on Tsingmicro shell: bash run: | + pip uninstall -y triton source ~/env.sh export FLAGTREE_BACKEND=tsingmicro cd python @@ -59,4 +60,3 @@ jobs: source ~/env.sh python3.11 -c 'import triton; print(triton.__path__)' /usr/local/lib/python3.11/dist-packages/triton/backends/tsingmicro/bin/tsingmicro-opt --version - /usr/local/lib/python3.11/dist-packages/triton/backends/tsingmicro/bin/tsingmicro-llvm-opt --version diff --git a/python/setup_tools/setup_helper.py b/python/setup_tools/setup_helper.py index f9a40a0b7..8fa5c1263 100644 --- a/python/setup_tools/setup_helper.py +++ b/python/setup_tools/setup_helper.py @@ -439,10 +439,10 @@ def check_env(env_val): # tsingmicro cache.store( - file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64", + file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-python3.11-x64", condition=("tsingmicro" == flagtree_backend), url= - "https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64.tar.gz", + "https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-python3.11-x64.tar.gz", pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) diff --git a/third_party/tsingmicro/CMakeLists.txt b/third_party/tsingmicro/CMakeLists.txt index 1e32c0b57..c6b2388a9 100644 --- a/third_party/tsingmicro/CMakeLists.txt +++ b/third_party/tsingmicro/CMakeLists.txt @@ -1,18 +1,19 @@ -if(NOT DEFINED TX8_HOME) - if(DEFINED ENV{TX8_HOME}) - set(TX8_HOME $ENV{TX8_HOME}) +if(NOT DEFINED TX8_DEPS_ROOT) + if(DEFINED ENV{TX8_DEPS_ROOT}) + set(TX8_DEPS_ROOT $ENV{TX8_DEPS_ROOT}) else() - message(FATAL_ERROR "TX8_HOME environment variable is not defined") + message(FATAL_ERROR "TX8_DEPS_ROOT environment variable is not defined") endif() endif() -set(TSM_BACKEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/backend) set(XUANTIE_NAME Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2) +set(INSTALL_TSINGMICRO_DIR ${CMAKE_INSTALL_PREFIX}/triton/backends/tsingmicro/) +install(CODE "file(MAKE_DIRECTORY \"${INSTALL_TSINGMICRO_DIR}\")") include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/crt/include) -include_directories(${TX8_HOME}/include) +include_directories(${TX8_DEPS_ROOT}/include) add_subdirectory(include) add_subdirectory(lib) add_subdirectory(bin) @@ -23,7 +24,7 @@ if(TRITON_BUILD_PYTHON_MODULE) LINK_LIBS ZTCAnalysis ZTCAnalysisStructured MagicKernelIR Tx81IR TritonTilingExtIR TritonStructuredIR TritonToCoreDialects TritonToLinalg TritonToStructured StructuredToMemref LinalgToMagicKernel - TritonArithToLinalg CoreDialectsToMK Tx81ToLLVM Tx81MemrefToLLVM MKToTx81) + TritonArithToLinalg CoreDialectsToMK Tx81ToLLVM Tx81MemrefToLLVM MKToTx81 LLVMRISCVCodeGen LLVMRISCVAsmParser) target_link_libraries(TritonTsingMicro PRIVATE Python3::Module pybind11::headers) endif() #if(TRITON_BUILD_UT) diff --git a/third_party/tsingmicro/backend/compiler.py b/third_party/tsingmicro/backend/compiler.py index 5a60dd12d..805783542 100644 --- a/third_party/tsingmicro/backend/compiler.py +++ b/third_party/tsingmicro/backend/compiler.py @@ -25,10 +25,10 @@ def _get_llvm_bin_path(bin_name: str) -> str: return 
os.path.join(path, "bin", bin_name) -def _get_tx8_path(sub_name: str) -> str: - path = os.getenv("TX8_HOME", "") +def _get_tx8_deps_path(sub_name: str) -> str: + path = os.getenv("TX8_DEPS_ROOT", "") if path == "": - raise Exception("TX8_HOME is not set.") + raise Exception("TX8_DEPS_ROOT is not set.") return os.path.join(path, sub_name) @@ -55,22 +55,36 @@ def compile_accelerator(): # FIXME: Hardcoded path #dst_path = os.path.join(tmpdir, f"{name}.so") dst_path = "/tmp/kernel.so" - libc_lib = os.path.join(_get_tx8_path("Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2"), "riscv64-unknown-elf", + libc_lib = os.path.join(_get_tx8_deps_path("Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2"), "riscv64-unknown-elf", "lib", "rv64imafdc", "lp64d") + # libvr_path = os.path.join(os.path.dirname(__file__), "lib") libvr_path = os.path.join(os.path.dirname(__file__), "lib") clang_path = _get_llvm_bin_path("clang") lld_path = _get_llvm_bin_path("ld.lld") - tx8_lib = _get_tx8_path("lib") - subprocess.check_call([ - clang_path, "-shared", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-O2", - f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles", "-Wl,--allow-shlib-undefined", "-mabi=lp64d", - "-Wl,--no-dynamic-linker", - # FIXME: Hardcoded path - "/tmp/kernel.o", f"-L{libvr_path}", f"-L{libc_lib}", f"-L{tx8_lib}", "-Wl,--whole-archive", - "-linstr_tx81", # Tx81 intrinsic API - "-lvr", # Wrapper API of Tx81 intrinsic - "-Wl,--no-whole-archive", "-lm", "-o", dst_path - ]) + + tx8_lib = _get_tx8_deps_path("lib") + # Build shared library for simulator or hardware + if (os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")): + subprocess.check_call([ + clang_path, "-shared", "-O2", f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles", + "-Wl,--allow-shlib-undefined", "-Wl,--no-dynamic-linker", + # FIXME: Hardcoded path + "/tmp/kernel.o", f"-L{libvr_path}", f"-L{tx8_lib}", "-Wl,--whole-archive", + "-lvr", # Wrapper API of Tx81 intrinsic + "-ltriton_cmodel", "-ltx8be_op_cmodel", "-Wl,--no-whole-archive", "-lm", "-o", dst_path + ]) + else: + # Link wrapper, kernel with Tx81 crt and intrinsics(libinstr_tx81.a) + subprocess.check_call([ + clang_path, "-shared", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-O2", + f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles", "-Wl,--allow-shlib-undefined", "-mabi=lp64d", + "-Wl,--no-dynamic-linker", + # FIXME: Hardcoded path + "/tmp/kernel.o", f"-L{libvr_path}", f"-L{libc_lib}", f"-L{tx8_lib}", "-Wl,--whole-archive", + "-linstr_tx81", # Tx81 intrinsic API + "-lvr", # Wrapper API of Tx81 intrinsic + "-Wl,--no-whole-archive", "-lm", "-o", dst_path + ]) _dump_ir_if_needed([dst_path]) with open(dst_path, 'rb') as f: @@ -85,10 +99,11 @@ def _ttir_to_coreir(mod): src_path = os.path.join(tmpdir, "tt.mlir") dst_path = os.path.join(tmpdir, "core.mlir") Path(src_path).write_text(ttir_code) - tsm_opt_path = _get_tsm_opt_path() + triton_opt_path = _get_tsm_opt_path() _dump_ir_if_needed([src_path]) subprocess.check_call([ - tsm_opt_path, src_path, "--triton-to-core-dialects", "--one-shot-bufferize=allow-return-allocs-from-loops", + triton_opt_path, src_path, "--triton-to-core-dialects", "--core-dialects-to-mk", + "--one-shot-bufferize=allow-return-allocs-from-loops", #"--mlir-print-debuginfo", "-o", dst_path ]) @@ -107,10 +122,10 @@ def _coreir_to_mkir(mod): src_path = os.path.join(tmpdir, "core.mlir") dst_path = os.path.join(tmpdir, "mk.mlir") Path(src_path).write_text(coreir_code) - tsm_opt_path = _get_tsm_opt_path() + triton_opt_path = _get_tsm_opt_path() 
_dump_ir_if_needed([src_path]) subprocess.check_call([ - tsm_opt_path, src_path, "--core-dialects-to-mk", + triton_opt_path, src_path, "--core-dialects-to-mk", #"--mlir-print-debuginfo", "-o", dst_path ]) @@ -129,10 +144,10 @@ def _coreir_to_txir(mod): src_path = os.path.join(tmpdir, "core.mlir") dst_path = os.path.join(tmpdir, "tx.mlir") Path(src_path).write_text(coreir_code) - tsm_opt_path = _get_tsm_opt_path() + triton_opt_path = _get_tsm_opt_path() _dump_ir_if_needed([src_path]) subprocess.check_call([ - tsm_opt_path, src_path, "--expand-strided-metadata", + triton_opt_path, src_path, "--expand-strided-metadata", "--lower-affine", # convert affine.load to memref.load, need exec before tx81-to-llvm since we will support spm offset to memref.load "--mk-to-tx81", "--cse", # unused memref.subview/memref.reinterpret #"--mlir-print-debuginfo", @@ -153,20 +168,23 @@ def _txir_to_llir(mod, metadata): llvmir_path = os.path.join(tmpdir, "ll.mlir") llir_path = os.path.join(tmpdir, "ll.ir") Path(src_path).write_text(txir_code) - tsm_opt_path = _get_tsm_opt_path() + triton_opt_path = _get_tsm_opt_path() _dump_ir_if_needed([src_path]) # Tx81 and core dialects to LLVM-MLIR args = [ - tsm_opt_path, src_path, + triton_opt_path, src_path, # Use tx81-memref-to-llvm to replace "--finalize-memref-to-llvm". - "--tx81-memref-to-llvm", "--convert-scf-to-cf", "--convert-math-to-llvm", - "--convert-cf-to-llvm", # need exec before "convert-func-to-llvm" + "--tx81-memref-to-llvm", "--convert-scf-to-cf", "--test-math-polynomial-approximation", + "--convert-math-to-llvm", "--convert-cf-to-llvm", # need exec before "convert-func-to-llvm" "--convert-func-to-llvm", # need exec before "kernel-arg-buffer", otherwise un-rank memref will translate to int(rank) + ptr + # Other unconverted memref ops, eg: memref.global from scan op conversion + "--finalize-memref-to-llvm" ] - - args.append( - "--kernel-arg-buffer" - ) # need exec before "tx81-to-llvm" which will declare other func. We want only replace the triton kernel + # WORKAROUND: To replace function signature to "kernel(ptr)" + if os.getenv("VENDOR_VERSION", "") != "": + args.append( + "--kernel-arg-buffer" + ) # need exec before "tx81-to-llvm" which will declare other func. 
We only want to replace the Triton kernel

     # other pass
     args += [
@@ -181,14 +199,12 @@ def _txir_to_llir(mod, metadata):

     _dump_ir_if_needed([llvmir_path])

-    llvm_file = os.getenv("CUSTOMIZED_IR", "")
-    if (llvm_file != ""):
-        llvmir_path = os.getenv("TRITON_DUMP_PATH", "")
-
-        if not llvmir_path:
-            return
-
-        llvmir_path = os.path.join(llvmir_path, llvm_file)
+    # Get the SPM memory usage metadata
+    from mlir.ir import Context, Module
+    with Context() as ctx:
+        llvmir_str = Path(llvmir_path).read_text()
+        llvmir_module = Module.parse(llvmir_str)
+        metadata["shared"] = llvmir_module.operation.attributes["triton_tsm.spm_use"].value

     # LLVM-MLIR to LLVM-IR
     mlir_translate_path = _get_llvm_bin_path("mlir-translate")
@@ -244,6 +260,8 @@ def _llir_to_bin(llir: str, metadata):
     matches = re.findall(pattern, llir)
     assert len(matches) == 1
     metadata["name"] = matches[0]
+    # Build kernel for simulator or hardware
+    sim_mode = os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")
     with tempfile.TemporaryDirectory() as tmpdir:
         src_path = os.path.join(tmpdir, "kernel.ll")
         # FIXME: Hardcoded path
@@ -251,10 +269,14 @@
         dst_path = "/tmp/kernel.o"
         Path(src_path).write_text(llir)
         clang_path = _get_llvm_bin_path("clang++")
-        subprocess.check_call([
-            clang_path, src_path, "-O2", "-c", "-fPIC", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-o",
-            dst_path
-        ])
+
+        compile_args = [clang_path, src_path, "-O2", "-c", "-fPIC", "-o", dst_path]
+
+        # Add RISC-V specific flags when not in simulation mode
+        if not sim_mode:
+            compile_args.extend(["--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc"])
+
+        subprocess.check_call(compile_args)

         _dump_ir_if_needed([dst_path])

@@ -269,6 +291,10 @@ class TXDAOptions:
     num_warps: int = 0
     num_ctas: int = 0
     num_stages: int = 1
+    num_buffers_warp_spec: int = 0
+    num_consumer_groups: int = 0
+    reg_dec_producer: int = 0
+    reg_inc_consumer: int = 0
     enable_warp_specialization: bool = False
     enable_fp_fusion: bool = False
     extern_libs = None
@@ -313,6 +339,7 @@ def pack_metadata(self, metadata):
                 metadata.cluster_dims[1], metadata.cluster_dims[2], metadata.name)

     # Our compilation pipeline isn't in python like nvidia or amd, no need to load
+    # dialects. See `ztc.cc`
     def load_dialects(self, ctx):
         return
diff --git a/third_party/tsingmicro/backend/cpu_driver.py b/third_party/tsingmicro/backend/cpu_driver.py
new file mode 100644
index 000000000..a111d9420
--- /dev/null
+++ b/third_party/tsingmicro/backend/cpu_driver.py
@@ -0,0 +1,387 @@
+import hashlib
+import tempfile
+import sysconfig
+
+import os
+import subprocess
+import importlib.util
+
+from pathlib import Path
+
+from triton.runtime.cache import get_cache_manager
+from triton.backends.driver import DriverBase
+from triton.backends.compiler import GPUTarget
+
+
+# The riscv compiler
+def _get_llvm_bin_path(bin_name: str) -> str:
+    path = os.getenv("LLVM_SYSPATH", "")
+    if path == "":
+        raise Exception("LLVM_SYSPATH is not set.")
+    return os.path.join(path, "bin", bin_name)
+
+
+# The riscv C header files and libraries path.
+def _get_libc_root() -> str: + path = os.getenv("LIB_C_ROOT", "") + if path == "": + raise Exception("LIB_C_ROOT is not set.") + return path + + +# -------------------- Launcher ---------------------------- +def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return _ty_to_cpp(ty) + + +def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + +def _generate_launcher(constants, signature, kernel_name): + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + kernel_arg_decls = ', '.join( + _ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants) + kernel_arg_decls += ', ' if kernel_arg_decls else '' + + kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" + for i, ty in signature.items() + if i not in constants) + kernel_parameters += ', ' if kernel_parameters else '' + + return f""" +#include +#include +#include +#include "ExecutionEngine/CRunnerUtils.h" +#include "ExecutionEngine/CRunnerUtils.cpp" + +extern "C" {{ + // Pointer type (=Memref) becomes int64_t + MemRef struct + // FIXME: understand what this int64_t is used for. + void {kernel_name}({kernel_arg_decls} + int, int, int, int, int, int); +}} + +static void _launch(int gridX, int gridY, int gridZ, {arg_decls}) {{ + if (gridX*gridY*gridZ > 0) {{ + // Cast "function" to the real function type. + for(int x = 0; x < gridX; x++) {{ + for(int y = 0; y < gridY; y++) {{ + for(int z = 0; z < gridZ; z++) {{ + // Use some random type "char" here. 
+          {' '.join(f'StridedMemRefType<char, 0> ptr_arg{i} = {{static_cast<char *>(arg{i}), static_cast<char *>(arg{i}), 0}};' for i, ty in signature.items() if i not in constants and ty[0] == "*")}
+          {kernel_name}({kernel_parameters}
+                        gridX, gridY, gridZ, x, y, z);
+        }}
+      }}
+    }}
+  }}
+}}
+
+typedef struct _DevicePtrInfo {{
+  void *dev_ptr;
+  bool valid;
+}} DevicePtrInfo;
+
+static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
+  DevicePtrInfo ptr_info;
+  ptr_info.dev_ptr = 0;
+  ptr_info.valid = true;
+  if (PyLong_Check(obj)) {{
+    ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(obj));
+    return ptr_info;
+  }}
+  if (obj == Py_None) {{
+    // valid nullptr
+    return ptr_info;
+  }}
+  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
+  if(ptr){{
+    PyObject *empty_tuple = PyTuple_New(0);
+    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
+    Py_DECREF(empty_tuple);
+    Py_DECREF(ptr);
+    if (!PyLong_Check(ret)) {{
+      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
+      ptr_info.valid = false;
+      return ptr_info;
+    }}
+    ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(ret));
+    if(!ptr_info.dev_ptr)
+      return ptr_info;
+    Py_DECREF(ret); // Thanks ChatGPT!
+    return ptr_info;
+  }}
+  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
+  ptr_info.valid = false;
+  return ptr_info;
+}}
+
+static PyObject* launch(PyObject* self, PyObject* args) {{
+  int gridX, gridY, gridZ;
+  PyObject *launch_enter_hook = NULL;
+  PyObject *launch_exit_hook = NULL;
+  PyObject *kernel_metadata = NULL;
+  PyObject *launch_metadata = NULL;
+  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
+                       &kernel_metadata, &launch_metadata,
+                       &launch_enter_hook, &launch_exit_hook {args_list})) {{
+    return NULL;
+  }}
+
+  // [CPULauncher-specific]: We don't need the metadata below but just put them
+  // here anyway to be consistent with others.
+  // This will make updating the driver easier in the future.
+ + // int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + // if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + // PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + // return NULL; + // }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + if (PyErr_Occurred()) {{ + return NULL; + }} + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__ztc_ref_cpu_kernel_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___ztc_ref_cpu_kernel_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + + +def compile_module(launcher_src, kernel_placeholder_name): + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + py_lib_dir = sysconfig.get_config_var("LIBDIR") + py_version = sysconfig.get_config_var("LDVERSION") + py_lib = '{name}{py_version}'.format(name="python", py_version=py_version) + cpu_backend_path = Path(__file__).resolve().parent + clang = _get_llvm_bin_path("clang++") + libc_inc = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "include") + libc_lib = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "lib") + include_dir = os.path.join(cpu_backend_path, "include") + + def launch(gridX, gridY, gridZ, stream, cu_function, kernel_metadata, launch_metadata, launch_enter_hook, + launch_exit_hook, *args): + # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries. + # Let's compile a kernel every time. + # The cu_function parameter actually contains our assembly source code. + # See CPUUtils.load_binary method. 
+        asm_src = cu_function
+        kernel_name = kernel_metadata[6]  # see pack_metadata in compiler.py
+        src = launcher_src.replace(kernel_placeholder_name, kernel_name)
+
+        key = hashlib.md5(src.encode("utf-8") + asm_src).hexdigest()
+        cache = get_cache_manager(key)
+        name = "__ztc_ref_cpu_kernel_launcher"
+        filename = f"{name}.so"
+        cache_path = cache.get_file(filename)
+
+        if cache_path is None:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                asm_src_path = os.path.join(tmpdir, "kernel.s")
+                launcher_src_path = os.path.join(tmpdir, "main.cxx")
+                so_path = os.path.join(tmpdir, "kernel.so")
+                Path(asm_src_path).write_bytes(asm_src)
+                Path(launcher_src_path).write_text(src)
+                # Compile it together.
+                subprocess.check_call([
+                    clang, "-std=c++17", "--target=riscv64-unknown-elf", launcher_src_path, asm_src_path,
+                    f"-I{libc_inc}", f"-I{py_include_dir}", f"-I{include_dir}", f"-I{libc_lib}", f"-L{py_lib_dir}",
+                    "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path
+                ])
+
+                with open(so_path, "rb") as f:
+                    cache_path = cache.put(f.read(), filename, binary=True)
+
+        # Load and launch the compiled kernel.
+        spec = importlib.util.spec_from_file_location(name, cache_path)
+        mod = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(mod)
+        return mod.launch(gridX, gridY, gridZ, kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook,
+                          *args)
+
+    return launch
+
+
+class CPULauncher(object):
+
+    def __init__(self, src, metadata):
+        kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER"
+
+        constants = src.constants if hasattr(src, "constants") else dict()
+        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
+        constants = {cst_key(key): value for key, value in constants.items()}
+        signature = {cst_key(key): value for key, value in src.signature.items()}
+        launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name)
+        # Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name
+        # in the following launch function.
+        self.launch = compile_module(launcher_src, kernel_placeholder_name)
+
+    def __call__(self, *args, **kwargs):
+        self.launch(*args, **kwargs)
+
+
+class CPUUtils(object):
+
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(CPUUtils, cls).__new__(cls)
+        return cls.instance
+
+    # Note:
+    # nvidia and amd backends have their corresponding driver.c file that exposes
+    # get_device_properties and load_binary using python bindings.
+    # (see third_party/nvidia/backend/driver.c)
+    # These methods are then used in compiler.py to initialize handles before running
+    # the triton kernels.
+    # Since we recompile the kernel every time (see compile_module above),
+    # and the metadata generated by these functions aren't applicable to the cpu
+    # backend, just define the same functions with dummy implementations.
+    @staticmethod
+    def get_device_properties(device):
+        return {
+            "max_shared_mem": 2**20, "multiprocessor_count": None, "sm_clock_rate": None, "mem_clock_rate": None,
+            "mem_bus_width": None
+        }
+
+    # Important note:
+    # Since we cannot easily pass function pointers around, we pass along the
+    # assembly source code so that compile_module above can recompile the
+    # module every time.
+ @staticmethod + def load_binary(name, kernel_asm, shared, device): + return (None, # module + kernel_asm, # function + None, # n_regs + None # n_spills + ) + + +class CPUDriver(DriverBase): + + def __init__(self): + super().__init__() + self.utils = CPUUtils() + self.launcher_cls = CPULauncher + self.binary_ext = "cpuasm" + + # CPU driver won't be automatically chosen unless explicitly set through + # triton.runtime.driver.set_active(CPUDriver()) + @staticmethod + def is_active(): + return False + + def get_device_capability(self): + return ("cpu", 0) + + def get_current_stream(self, device): + return None + + def get_current_device(self): + # CPU doesn't have a device to return. Return something. + return "cpu" + + def set_current_device(self, device): + # CPU doesn't have a device to set + assert device == "cpu" + return + + def get_current_target(self): + return GPUTarget("cpu", 0, 0) + + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args diff --git a/third_party/tsingmicro/backend/driver.py b/third_party/tsingmicro/backend/driver.py index dc646291b..6f37a1963 100644 --- a/third_party/tsingmicro/backend/driver.py +++ b/third_party/tsingmicro/backend/driver.py @@ -18,17 +18,47 @@ from triton.backends.compiler import GPUTarget +def _get_tx8_deps_path(sub_name: str) -> str: + path = os.getenv("TX8_DEPS_ROOT", "") + if path == "": + raise Exception("TX8_DEPS_ROOT is not set.") + return os.path.join(path, sub_name) + + +dirname = os.path.dirname(os.path.realpath(__file__)) +if (os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")): + scheme = sysconfig.get_default_scheme() + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + include_dirs = [_get_tx8_deps_path("include"), py_include_dir] + library_dirs = [_get_tx8_deps_path("lib")] + libraries = ["triton_cmodel", "tx8be_op_cmodel", "neuralcore_qemu"] +else: + include_dirs = [ + os.path.join(dirname, "include"), + _get_tx8_deps_path("include"), + os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"), + os.path.join(sysconfig.get_path('platlib'), "torch", "include"), + os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"), + os.path.join(sysconfig.get_path('platlib'), "numpy", "_core", "include") + ] + library_dirs = [ + os.path.join(dirname, "lib"), + _get_tx8_deps_path("lib"), + os.path.join(sysconfig.get_path('platlib'), "torch", "lib") + ] + libraries = ['tx8_runtime', 'torch', 'torch_cpu', 'torch_python', 'c10'] + + def extend_torch(): import torch from torch.utils import cpp_extension, rename_privateuse1_backend, generate_methods_for_privateuse1_backend module = cpp_extension.load( name="txda", sources=[os.path.dirname(__file__) + "/txda_device.cpp"], - #runtime include path - extra_include_paths=[""], - #runtime *.so path - extra_ldflags=[""], - extra_cflags=["-g"], + extra_include_paths=[os.path.realpath(_get_tx8_deps_path("include"))], + extra_ldflags=["-L" + os.path.realpath(_get_tx8_deps_path("lib")), "-ltx8_runtime"], + # extra_cflags=["-g"], verbose=True, ) torch.utils.rename_privateuse1_backend("txda") @@ -36,30 +66,6 @@ def extend_torch(): generate_methods_for_privateuse1_backend(for_storage=True) -def _get_tx8_path(bin_name: str) -> str: - path = os.getenv("TX8_HOME", "") - if path == "": - raise Exception("TX8_HOME is not set.") - return os.path.join(path, bin_name) - - -dirname = os.path.dirname(os.path.realpath(__file__)) -include_dirs = [ - os.path.join(dirname, "include"), - 
os.path.realpath(_get_tx8_path("include")),
-    os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"),
-    os.path.join(sysconfig.get_path('platlib'), "torch", "include"),
-    os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"),
-    os.path.join(sysconfig.get_path('platlib'), "numpy", "_core", "include")
-]
-library_dirs = [
-    os.path.join(dirname, "lib"),
-    os.path.realpath(_get_tx8_path("lib")),
-    os.path.join(sysconfig.get_path('platlib'), "torch", "lib")
-]
-libraries = ['tx8_runtime', 'torch', 'torch_cpu', 'torch_python', 'c10']
-
-
 def _dump_ir_if_needed(files):
     path = os.getenv("TRITON_DUMP_PATH", "")
     if not path:
@@ -98,7 +104,6 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     cc_cmd += [f'-l{lib}' for lib in libraries]
     cc_cmd += [f"-L{dir}" for dir in library_dirs]
     cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
-    cc_cmd += [f"-Wl,-rpath,{dir}" for dir in library_dirs]
     subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
     return so

@@ -114,6 +119,7 @@ def compile_native(src, name):
         src_path = os.path.join(tmpdir, f"{name}.cpp")
         with open(src_path, "w") as f:
             f.write(src)
+            f.flush()
         _dump_ir_if_needed([src_path])
         so = _build(name, src_path, tmpdir, library_dirs, include_dirs, libraries)
         with open(so, "rb") as f:
@@ -162,11 +168,11 @@ def _extracted_type(ty):

 def _format_of(ty):
     if isinstance(ty, tuple):
-        val = ''.join(map(format_of, ty))
+        val = ''.join(map(_format_of, ty))
         return f"({val})"
     if ty[0] == '*':
         return "O"
-    if ty in ("constexpr", "nvTmaDesc"):
+    if ty == "constexpr":
         return "O"
     return {
         "float": "f",
@@ -184,19 +190,189 @@ def _format_of(ty):


 def make_launcher(constants, signature, kernel_name):
-    # Basic declarations
+    # Basic declarations: the arguments of the triton kernel.
     arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
     args_format = ''.join([_format_of(ty) for ty in signature.values()])
-    format = "iiiOOOOOO" + args_format
+    format = "iiiOKOOOO" + args_format
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

-    # Parameters to pass to the kernel function
+    # Parameters to pass to the kernel function: all triton kernel arguments except constants.
+ kernel_arg_decls = ', '.join( + f"{_ty_to_cpp(ty)} arg{i}" if ty[0] != "*" else f"uint64_t tx81_ptr{i}, {_ty_to_cpp(ty)} ptr_arg{i}" + for i, ty in signature.items() + if ty != "constexpr") + kernel_arg_decls += ', ' if kernel_arg_decls else '' + kernel_parameters = ', '.join( - f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"tx81_ptr{i}, &ptr_arg{i}" + f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"tx81_ptr{i}, ptr_arg{i}" for i, ty in signature.items() if ty != "constexpr") kernel_parameters += ', ' if kernel_parameters else '' + # Simulation or hardware + if (os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")): + # generate glue code for tile-sim + return f""" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common_base.h" +#include "instr_def.h" +#include "common_tensor.h" +#include "cmodel.h" + + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +using kernel_ptr_t = void(*)({kernel_arg_decls}int, int, int, int, int, int); + +static void _launch(int gridX, int gridY, int gridZ, {kernel_arg_decls}kernel_ptr_t kernel_ptr) {{ + if (gridX*gridY*gridZ <= 0) + return; // No work to do + + // Cast "function" to the real function type. + for (uint32_t z = 0; z < gridZ; ++z) {{ + for (uint32_t y = 0; y < gridY; ++y) {{ + for (uint32_t x = 0; x < gridX; ++x) {{ + __set_pid(x, y, z); + (*kernel_ptr)({kernel_parameters}gridX, gridY, gridZ, x, y, z); + }} + }} + }} +}} + + +typedef struct _DevicePtrInfo {{ + void* dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + Py_DECREF(ret); // Thanks ChatGPT! 
+ return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + TileSimHandle *sim_handle = q_tilesim_create(RCESIM_LOG_DEBUG); + set_sim_handle(sim_handle, NULL); + q_tilesim_set_logFile(sim_handle, "aa.log"); + + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + + PyObject * py_obj_stream = NULL; + void * pKrnl = NULL; + + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook + {args_list})) {{ + return NULL; + }} + + // FIXME: Steam is PyNone + // void *pStream = PyLong_AsVoidPtr(py_obj_stream); + kernel_ptr_t kernel_ptr = reinterpret_cast((PyObject*)pKrnl); + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items() if ty != "constexpr"])}; + + _launch(gridX, gridY, gridZ, {', '.join(f"0, ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items() if ty != "constexpr")} {',' if len(kernel_parameters) > 0 else ''} kernel_ptr); + + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + + # generate glue code for tx8 board return f""" #include #include @@ -215,22 +391,6 @@ def make_launcher(constants, signature, kernel_name): #include "hrt_interface.h" #include "hrt_common.h" -// The design of kernel argument buffer: -// The offset starts from the whole kernel buffer -// +------------------------------------------------------------------------+ -// | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 8 bytes | 8 bytes | -// | No. kargs | gridX | gridY | gridZ | karg1 offset | karg2 offset | -// +------------------------------------------------------------------------+ -// .......................... Metadata buffer................................ -// -// +------------------------------------------------------------------------+ -// | ... | 8 bytes | n bytes | n bytes | | n bytes | -// | ... | kargn offset | karg1 data | karg2 data | ... 
| kargn data | -// +------------------------------------------------------------------------+ -// ^ ^ ^ -// karg1 offset karg2 offset kargn offset -// ... Metadata buffer... | ............ kernel arg buffer .................. - enum DATA_TYPE {{ SCALAR, POINT, @@ -258,211 +418,21 @@ def make_launcher(constants, signature, kernel_name): }}; -extern "C" {{ - // The kernel arguments includes: - // 1. The actual kernel argument in arg_decls - // 2. The group size: gridX, gridY, gridZ - // 3 The thread id in each direction: x, y, z - void {kernel_name}({arg_decls}, int, int, int, int, int, int); -}} - -// Global device vector -static std::vector g_tx81_devices; -static bool g_runtime_initialized = false; - // FIXME: Hardcoded path std::string chip_out = "/tmp/chip_out/node0/"; std::string kernel_file = "/tmp/kernel.so"; std::string kernel_fun_name = "{kernel_name}"; uint32_t sharedMemBytes = 0; +TsmDevice* device; typedef void* Stream_t; -static uint64_t get_phy_addr(uint64_t logic_addr) {{ - uint32_t card_id; - uint64_t addr; - uint64_t size; - TsmMemGetInfo(logic_addr, card_id, addr, size); - return addr; -}} - - -// Initialize Tx81 runtime -bool init_tx81_runtime() {{ - if (g_runtime_initialized) {{ - return true; // Already initialized - }} - - // Initialize the Tx81 runtime - if (TsmInitRuntime() != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 runtime"); - return false; - }} - - // Get device count - uint32_t device_num = 0; - if (TsmGetDeviceNum(device_num) != RET_SUCCESS || device_num == 0) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to get Tx81 device count or no devices found"); - TsmDeInitRuntime(); - return false; - }} - - // FIXME: Hardcoded - // Set up devices - for simplicity, we're using a 1x1 configuration - uint32_t first_phy_id = 0; - uint32_t card_x = 1; - uint32_t card_y = 1; - - TsmDevice *dev = new TsmDevice(); - if (TsmSetDevice(&dev, 0, first_phy_id) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices"); - TsmDeInitRuntime(); - return false; - }} - g_tx81_devices.push_back(dev); - - // FIXME: Hardcoded - TsmModel *new_model = new TsmModel(); - // Create a vector of models - std::vector kmodel_vec = {{new_model}}; - std::string option = "-O2"; - CompileOption compl_option = {{}}; - compl_option.comp_enable = 0; // Use prebuilt binary - compl_option.chip_x = 1; //单卡 - compl_option.chip_y = 1; - compl_option.check_enable = true; - compl_option.enable_kcore_bin = 1; - compl_option.enable_kcore_so = 1; - new_model->case_dir = chip_out; - - for (TsmDevice * dev : g_tx81_devices) {{ - if (TsmCompileMultiGraph(dev, *new_model, option, compl_option) != RET_SUCCESS) {{ - for (uint32_t dev_index = 0; dev_index < g_tx81_devices.size(); ++dev_index) {{ - if (TsmResetDevice(g_tx81_devices[dev_index]) != RET_SUCCESS) {{ - return false; - }} - }} - return false; - }} - }} - - // Initialize all devices - for (auto* dev : g_tx81_devices) {{ - if (TsmLaunch(dev, *new_model) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "[Chip id] TsmLaunch failed."); - TsmReleaseDevice(dev); - TsmResetDevice(dev); - return false; - }} - - if (TsmSetMonitorInfo(dev) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "[Chip id] TsmLaunch failed."); - TsmReleaseDevice(dev); - TsmResetDevice(dev); - return false; - }} - }} - - delete new_model; - g_runtime_initialized = true; - - return true; -}} - -// Clean up Tx81 runtime resources -static PyObject* cleanup_tx81_runtime(PyObject* self, PyObject* args) {{ - if 
(!g_runtime_initialized) {{ - Py_RETURN_NONE; - }} - - for (auto* dev : g_tx81_devices) {{ - if (TsmSetTerminate(dev) != RET_SUCCESS) {{ - Py_RETURN_NONE; - }} - // Reset and release each device - TsmReleaseDevice(dev); - TsmResetDevice(dev); - delete dev; - }} - g_tx81_devices.clear(); - TsmDeInitRuntime(); - g_runtime_initialized = false; - Py_RETURN_NONE; -}} - -TSM_RETCODE argsToDevMemArray(TsmDevice *dev, std::vector &kargs, - std::vector &rtKargs, std::vector &devAddrs) {{ - int count = 0; - for (KernelArg& karg : kargs) {{ - if (karg.data_type == POINT) {{ - TsmDevicePtr dev_buffer; - if (TsmDeviceMalloc(dev, dev_buffer, karg.size) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to TsmDeviceMalloc"); - return RET_ERROR; - }} - - if (TsmMemcpyH2D(dev_buffer, karg.data.ptr, karg.size) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyH2D"); - return RET_ERROR; - }} - devAddrs.push_back(dev_buffer); - // FIXME: rank - rtKargs.push_back(1); - rtKargs.push_back(get_phy_addr(dev_buffer)); - - count++; - }} - else {{ - rtKargs.push_back(karg.data.scalar); - }} - }} - return RET_SUCCESS; -}} - -TSM_RETCODE devMemArrayToArgs(TsmDevice *dev, std::vector &kargs, - std::vector &devAddrs) {{ - - int count = 0; - for (KernelArg& karg : kargs) {{ - if (karg.data_type == POINT) {{ - uint64_t dev_buffer = devAddrs[count++]; - if (TsmMemcpyD2H(karg.data.ptr, dev_buffer, karg.size) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyH2D"); - return RET_ERROR; - }} - }} - }} - return RET_SUCCESS; -}} - -TSM_RETCODE devMemFree(TsmDevice *dev, std::vector &devAddrs) {{ - for (uint64_t dev_buffer : devAddrs) {{ - if (TsmDeviceFree(dev_buffer) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to TsmDeviceFree"); - return RET_ERROR; - }} - }} - return RET_SUCCESS; -}} - -TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) {{ - if (bootpm_dev != 0) {{ - if (TsmDeviceFree(bootpm_dev) != RET_SUCCESS) {{ - return RET_ERROR; - }} - bootpm_dev = 0; - }} - return RET_SUCCESS; -}} - static void _launch(int gridX, int gridY, int gridZ, std::vector kargs) {{ - std::vector &devices = g_tx81_devices; - if (gridX*gridY*gridZ <= 0) {{ return; // No work to do }} - // TODO::mv + // TODO::mv uint64_t kernel_len = 0; uint8_t* kernel_ptr = read_file_data(kernel_file, kernel_len); if (kernel_ptr == nullptr) {{ @@ -471,46 +441,31 @@ def make_launcher(constants, signature, kernel_name): return; }} - // prepare data/ load kernel/run/unload kernel/get out data/release memory - for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {{ - // Allocate the device memory for all kernel arguments - std::vector devAddrs; - std::vector rtKargs; - - if (argsToDevMemArray(devices[dev_index], kargs, rtKargs, devAddrs) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to argsToDevMemArray"); - TsmDeInitRuntime(); - return; - }} - - rtKargs.push_back(gridX); - rtKargs.push_back(gridY); - rtKargs.push_back(gridZ); - rtKargs.push_back(0); - rtKargs.push_back(0); - rtKargs.push_back(0); - - // TSM_RETCODE TsmKernelLaunch(TsmDevice *dev, const char *func_name, void *kernel_ptr, uint32_t kernel_len, - // uint32_t grid_dim, uint32_t block_dim, void *args, uint32_t args_len); - if (TsmKernelLaunch(devices[dev_index], kernel_fun_name.c_str(), (void*)kernel_ptr, kernel_len, - gridX, 1, (void*)(&rtKargs[0]), rtKargs.size()*sizeof(uint64_t)) != RET_SUCCESS){{ - PyErr_SetString(PyExc_RuntimeError, "Failed to 
TsmKernelLaunch"); - TsmDeInitRuntime(); - }} - if (devMemArrayToArgs(devices[dev_index], kargs, devAddrs) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to devMemArrayToArgs"); - TsmDeInitRuntime(); - return; - }} - - // getchar(); - - // TsmUnloadKernel(devices[dev_index], kmodel_vec); - - if (devMemFree(devices[dev_index], devAddrs) != RET_SUCCESS) {{ - return; + // Allocate the device memory for all kernel arguments + std::vector rtKargs; + for (KernelArg& karg : kargs) {{ + if (karg.data_type == POINT) {{ + rtKargs.push_back(1); + rtKargs.push_back((uint64_t)(karg.data.ptr)); + }} else {{ + rtKargs.push_back((uint64_t)(karg.data.ptr)); }} }} + rtKargs.push_back(gridX); + rtKargs.push_back(gridY); + rtKargs.push_back(gridZ); + rtKargs.push_back(0); + rtKargs.push_back(0); + rtKargs.push_back(0); + + // TSM_RETCODE TsmKernelLaunch(TsmDevice *dev, const char *func_name, uint64_t kernel_host, uint64_t kernel_len, + // Dim3 grid_dim, Dim3 block_dim, void *args, uint32_t args_len); + if (TsmKernelLaunch(device, kernel_fun_name.c_str(), (uint64_t)kernel_ptr, kernel_len, + Dim3({{(uint32_t)gridX, (uint32_t)gridY, (uint32_t)gridZ}}), Dim3({{1u, 1u, 1u}}), + (void*)(&rtKargs[0]), rtKargs.size()*sizeof(uint64_t)) != RET_SUCCESS){{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmKernelLaunch"); + TsmDeInitRuntime(); + }} }} // Structure to represent a device pointer @@ -531,18 +486,12 @@ def make_launcher(constants, signature, kernel_name): return contiguous_tensor.data_ptr(); }} -static PyObject* init_runtime(PyObject* self, PyObject* args) {{ - const char* _chip_out; - if (!PyArg_ParseTuple(args, "s", &_chip_out)) {{ - return NULL; - }} - chip_out = _chip_out; - - // Initialize Tx81 runtime during module import - if (!init_tx81_runtime()) {{ +static PyObject* get_device_ptr(PyObject* self, PyObject* args) {{ + uint64_t dev_ptr; + if (!PyArg_ParseTuple(args, "K", &dev_ptr)) {{ return NULL; }} - + device = (TsmDevice *)dev_ptr; return Py_None; }} @@ -553,39 +502,20 @@ def make_launcher(constants, signature, kernel_name): PyObject *launch_exit_hook = NULL; PyObject *kernel_metadata = NULL; PyObject *launch_metadata = NULL; - // FIXME: Extra 2 args: - PyObject *dummy1 = NULL; - PyObject *dummy2 = NULL; + PyObject * py_obj_stream = NULL; + void * pKrnl = NULL; + // Define the actual kernel arguments {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} // Init kernel arguments from python side - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook, - &dummy1, &dummy2{args_list})) {{ + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook + {args_list})) {{ return NULL; }} -#if 0 // FIXME: kernel_metadata is not correctly inited - // Extract metadata for consistency with other drivers - int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - if (!PyArg_ParseTuple(kernel_metadata, "iiiiii", &num_warps, &num_ctas, - &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); - return NULL; - }} - - // Call the enter hook if provided - if (launch_enter_hook != Py_None) {{ - PyObject* hook_args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, hook_args); - Py_DECREF(hook_args); - if (!ret) - return NULL; - }} 
-#endif - // Construct a data kernel arguments list data structure std::vector kargs; //{' '.join([f"kargs.emplace_back(_arg{i}, PyObject_Size(_arg{i})*4);" if ty[0]=="*" else f"kargs.emplace_back(_arg{i}, sizeof(_arg{i}));" for i, ty in signature.items() if ty != "constexpr"])} @@ -616,7 +546,7 @@ def make_launcher(constants, signature, kernel_name): // Python module method definitions static PyMethodDef ModuleMethods[] = {{ {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{"init_runtime", init_runtime, METH_VARARGS, "Init runtime with chip_out dir"}}, + {{"get_device_ptr", get_device_ptr, METH_VARARGS, "Get txda current device"}}, {{NULL, NULL, 0, NULL}} // sentinel }}; @@ -629,13 +559,6 @@ def make_launcher(constants, signature, kernel_name): ModuleMethods }}; -static PyMethodDef cleanup_method = {{ - "cleanup_tx81_runtime", - (PyCFunction)cleanup_tx81_runtime, - METH_NOARGS, - "Cleanup Tx81 runtime resources" -}}; - // Python module initialization function PyMODINIT_FUNC PyInit___triton_launcher(void) {{ PyObject *m = PyModule_Create(&ModuleDef); @@ -644,19 +567,6 @@ def make_launcher(constants, signature, kernel_name): }} PyModule_AddFunctions(m, ModuleMethods); - - // Register an atexit handler to cleanup Tx81 runtime - PyObject* atexit_module = PyImport_ImportModule("atexit"); - if (atexit_module) {{ - PyObject* cleanup_func = PyCFunction_New(&cleanup_method, NULL); - if (cleanup_func) {{ - PyObject* result = PyObject_CallMethod(atexit_module, "register", "O", cleanup_func); - Py_XDECREF(result); - Py_DECREF(cleanup_func); - }} - Py_DECREF(atexit_module); - }} - return m; }} """ @@ -677,6 +587,32 @@ def __init__(self): self.get_device_properties = mod.get_device_properties +class SimulatorUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(SimulatorUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + pass + + def load_binary(self, name, kernel, shared_mem, device): + with tempfile.NamedTemporaryFile(mode="wb", suffix=".so", delete=False) as f: + f.write(kernel) + f.flush() + import ctypes + + # Load the kernel ptr + lib = ctypes.cdll.LoadLibrary(f.name) + fn_ptr = getattr(lib, name) + fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value + return (lib, fn_ptr_as_void_p, 0, 0) + + def get_device_properties(self, *args): + return {"max_shared_mem": 1024 * 1024 * 3} + + # Launch cross compiled runtime program on controller class TXDALauncher(object): @@ -689,10 +625,8 @@ def __init__(self, src, metadata): # Compiler runtime kernel launcher source code launcher_src = make_launcher(constants, signature, src.fn.__name__) mod = compile_native(launcher_src, "__triton_launcher") + self.get_device_ptr = mod.get_device_ptr self.launch = mod.launch - chip_out = os.path.join(_get_tx8_path("chip_out"), "node0") - chip_out = chip_out + os.sep - mod.init_runtime(chip_out) def __call__(self, *args, **kwargs): # args: 0: gridX, 1: gridY, 2: gridZ, @@ -700,19 +634,24 @@ def __call__(self, *args, **kwargs): # 5: a tuple(0, 0, False, 1, 1, 1, 'add_kernel'), # this is probably kernel metadata # 6: None, 7: None, 8: None, # 9~N: Actual triton kernel args. 
+ import torch + device = torch.txda.get_device() + self.get_device_ptr(device) self.launch(*args, **kwargs) class TXDADriver(GPUDriver): def __init__(self): + import torch super().__init__() - extend_torch() - self.utils = TXDAUtils() + if (os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")): + self.utils = SimulatorUtils() + else: + self.utils = TXDAUtils() self.launcher_cls = TXDALauncher - import torch # Needs to overwrite GPUDriver base methods - self.get_current_stream = torch.txda.current_stream + self.get_current_stream = self.get_txda_stream self.get_current_device = torch.txda.current_device self.set_current_device = torch.txda.set_device atexit.register(torch.txda.cleanup_device) @@ -720,12 +659,15 @@ def __init__(self): @staticmethod def is_active(): try: - #import torch - #return torch.txda.is_available() - return True + import torch + extend_torch() + return torch.txda.is_available() except ImportError: return False + def get_txda_stream(self, device): + return None + def get_current_target(self): capability = 1 warp_size = 16 @@ -733,7 +675,9 @@ def get_current_target(self): def get_active_torch_device(self): import torch - # torch.txda.init_device() + chip_out = _get_tx8_deps_path("chip_out") + chip_out = chip_out + os.sep + torch.txda.init_device(chip_out) return torch.device("txda", self.get_current_device()) def get_benchmarker(self): diff --git a/third_party/tsingmicro/backend/txda_device.cpp b/third_party/tsingmicro/backend/txda_device.cpp index ac46d67ac..b2def54d5 100644 --- a/third_party/tsingmicro/backend/txda_device.cpp +++ b/third_party/tsingmicro/backend/txda_device.cpp @@ -11,9 +11,13 @@ #include #include +#include #include #include +#include "hrt_common.h" +#include "hrt_interface.h" + namespace at { namespace detail { @@ -23,12 +27,124 @@ C10_REGISTER_GUARD_IMPL( } } // namespace at +// Global device vector +std::vector g_txda_devices; +static bool g_runtime_initialized = false; +static int device_id = 0; +static int stream_id = 0; +// std::string chip_out = "/tmp/chip_out/node0/"; + +bool init_device(std::string chip_out) { + if (g_runtime_initialized) { + return true; + } + + bool is_succ_init = false; + auto guard = std::shared_ptr(nullptr, [&is_succ_init](void *) { + if (!is_succ_init) { + for (auto *dev : g_txda_devices) { + TsmNpuPowerOff(dev); + TsmResetDevice(dev); + delete dev; + } + g_txda_devices.clear(); + TsmDeInitRuntime(); + g_runtime_initialized = false; + } + }); + + if (TsmInitRuntime(true) != RET_SUCCESS) { + PyErr_SetString(PyExc_RuntimeError, "Failed to initialize txda runtime"); + return false; + } + // Get device count + uint32_t device_num = 0; + if (TsmGetDeviceNum(device_num) != RET_SUCCESS || device_num == 0) { + PyErr_SetString(PyExc_RuntimeError, + "Failed to get txda device count or no devices found"); + TsmDeInitRuntime(); + return false; + } + + device_id = 0; + + TsmDevice *dev = new TsmDevice(); + if (TsmSetDevice(&dev, 0, UINT32_MAX) != RET_SUCCESS) { + PyErr_SetString(PyExc_RuntimeError, "Failed to set txda devices"); + TsmDeInitRuntime(); + return false; + } + g_txda_devices.push_back(dev); + + // TSM_RETCODE TsmNpuPowerOn(TsmDevice *dev, std::vector + // kcore_file_list); + std::vector kcore_file_list; + std::string kcore_bin = chip_out + "kcore_fw.bin"; + for (int i = 0; i < 16; i++) { + kcore_file_list.push_back(kcore_bin); + } + + if (TsmNpuPowerOn(dev, kcore_file_list) != RET_SUCCESS || + TsmSetMonitorInfo(dev) != RET_SUCCESS) { + return false; + } + + // delete new_model; + g_runtime_initialized = true; + 
printf("====init_txda_device====success=======\n"); + is_succ_init = true; + return true; +} + +void set_device(int id) { device_id = id; } + +bool cleanup_device() { + if (!g_runtime_initialized) { + return true; + } + for (auto *dev : g_txda_devices) { + // Reset and release each device + TsmNpuPowerOff(dev); + TsmResetDevice(dev); + delete dev; + } + g_txda_devices.clear(); + TsmDeInitRuntime(); + g_runtime_initialized = false; + printf("====cleanup_txda_runtime==== release success=======\n"); + return true; +} + +int current_device() { return device_id; } + +int current_stream(int id) { return stream_id; } + +uint64_t get_device() { return (uint64_t)g_txda_devices[device_id]; } + +// TODO: +bool is_available() { return true; } + +void synchronize() {} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_device", &get_device, "get txda device"); + m.def("is_available", &is_available, "is tx device available"); + m.def("init_device", &init_device, "initialize tx device"); + m.def("current_device", ¤t_device, "get current tx device"); + m.def("current_stream", ¤t_stream, "get current tx stream"); + m.def("set_device", &set_device, "set tx device"); + m.def("cleanup_device", &cleanup_device, "cleanup tx device"); + m.def("synchronize", &synchronize, "synchronize all threads in block"); +} + struct TXDADeviceAllocator final : at::Allocator { - TXDADeviceAllocator() {} + TXDADeviceAllocator() = default; at::DataPtr allocate(size_t nbytes) override { - void *data = c10::alloc_cpu(nbytes); - return {data, nullptr, &ReportAndDelete, + TsmDevicePtr data; + int dev_id = current_device(); + TsmDeviceMalloc(g_txda_devices[dev_id], data, (uint64_t)nbytes); + return {(void *)data, nullptr, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; } @@ -36,8 +152,7 @@ struct TXDADeviceAllocator final : at::Allocator { if (!ptr) { return; } - // TsmDeviceFree((uint64_t)ptr) - c10::free_cpu(ptr); + TsmDeviceFree((uint64_t)ptr); } at::DeleterFnPtr raw_deleter() const override { return &ReportAndDelete; } @@ -50,12 +165,9 @@ struct TXDADeviceAllocator final : at::Allocator { static TXDADeviceAllocator global_txda_alloc; REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_txda_alloc); -// to.Device at::Tensor txda_to_device(const at::Tensor &self, at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional memory_format) { - // TsmMemcpyH2D(); - TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "only support data transfer between cpu and txda"); @@ -69,18 +181,15 @@ at::Tensor txda_to_device(const at::Tensor &self, at::Device device, if (device != at::DeviceType::CPU) { return at::empty(self.sizes(), self.options()); } - auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, false, memory_format); memcpy(out.mutable_data_ptr(), self.mutable_data_ptr(), self.nbytes()); return out; } -// _copy_from +// TsmMemcpyH2D() or TsmMemcpyD2H() at::Tensor txda__copy_from(const at::Tensor &self, const at::Tensor &dst, bool non_blocking) { - // TsmMemcpyD2H(); - TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "only support data transfer between cpu and txda"); @@ -88,14 +197,37 @@ at::Tensor txda__copy_from(const at::Tensor &self, const at::Tensor &dst, dst.device().type() == c10::DeviceType::PrivateUse1, "only support data transfer between cpu and txda"); - // Some dummy asserts for the basic use case: inputs are the same size / - // dtype, all contiguous. 
TORCH_CHECK(self.sizes() == dst.sizes()); TORCH_CHECK(self.scalar_type() == dst.scalar_type()); TORCH_CHECK(self.is_contiguous() && dst.is_contiguous()); - std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), - self.storage().nbytes()); + if (self.is_cpu()) { + // printf("H2D self: 0x%lx, dst: 0x%lx, size: 0x%lx\n", + // (uint64_t)self.storage().data_ptr().get(), + // (uint64_t)dst.storage().data_ptr().get(), + // (uint64_t)self.storage().nbytes()); + + auto ret = TsmMemcpyH2D((uint64_t)dst.storage().data_ptr().get(), + (const void *)self.storage().data_ptr().get(), + self.storage().nbytes()); + if (ret != RET_SUCCESS) { + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyH2D"); + } + TORCH_CHECK(ret == RET_SUCCESS, "==H2DMemArray Error==="); + } else { + // printf("D2H self: 0x%lx, dst: 0x%lx, size: 0x%lx\n", + // (uint64_t)self.storage().data_ptr().get(), + // (uint64_t)dst.storage().data_ptr().get(), + // (uint64_t)self.storage().nbytes()); + + auto ret = TsmMemcpyD2H((const void *)dst.storage().data_ptr().get(), + (uint64_t)self.storage().data_ptr().get(), + self.storage().nbytes()); + if (ret != RET_SUCCESS) { + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyD2H"); + } + TORCH_CHECK(ret == RET_SUCCESS, "==D2HMemArray Error==="); + } return dst; } @@ -150,31 +282,12 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("as_strided", &txda_as_strided); } -bool init_device() { - // return init_txda_runtime(); - return true; +void custom_cpu_fallback(const c10::OperatorHandle &op, + torch::jit::Stack *stack) { + printf("custom_cpu_fallback \n"); + at::native::cpu_fallback(op, stack); } -bool cleanup_device() { - // cleanup_txda_runtime(); - return true; -} - -int current_device() { return 0; } - -int current_stream(int id) { return 0; } - -void set_device(int id) {} - -c10::Device get_txda_device() { - return c10::Device(c10::DeviceType::PrivateUse1, 0); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("current_device", ¤t_device, "get current tx device"); - m.def("current_stream", ¤t_stream, "get current tx stream"); - m.def("set_device", &set_device, "set tx device"); - m.def("get_txda_device", &get_txda_device, "get tx device"); - m.def("init_device", &init_device, "initialize tx device"); - m.def("cleanup_device", &cleanup_device, "cleanup tx device"); +TORCH_LIBRARY_IMPL(_, PrivateUse1, m) { + m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); } diff --git a/third_party/tsingmicro/bin/CMakeLists.txt b/third_party/tsingmicro/bin/CMakeLists.txt index 5dca7362a..095063c42 100644 --- a/third_party/tsingmicro/bin/CMakeLists.txt +++ b/third_party/tsingmicro/bin/CMakeLists.txt @@ -3,13 +3,11 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) -add_llvm_executable(tsingmicro-opt tsingmicro-opt.cpp PARTIAL_SOURCES_INTENDED) +# TODO: Workaround include path +# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/tsingmicro/include) +# include_directories(${CMAKE_CURRENT_BINARY_DIR}/../third_party/tsingmicro/include) -if (DEFINED TSM_BACKEND_DIR) - set(TSM_BIN_OUT ${TSM_BACKEND_DIR}/bin) -else () - set(TSM_BIN_OUT ${CMAKE_BINARY_DIR}/bin) -endif () +add_llvm_executable(tsingmicro-opt tsingmicro-opt.cpp PARTIAL_SOURCES_INTENDED) # TODO: what's this? 
llvm_update_compile_flags(tsingmicro-opt) @@ -27,81 +25,87 @@ target_link_libraries(tsingmicro-opt PRIVATE MLIROptLib MLIRPass MLIRTransforms -) -set_target_properties(tsingmicro-opt PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${TSM_BIN_OUT} + MLIRMathTestPasses ) mlir_check_all_link_libraries(tsingmicro-opt) -add_llvm_executable(tsingmicro-reduce tsingmicro-reduce.cpp PARTIAL_SOURCES_INTENDED) -mlir_check_all_link_libraries(tsingmicro-reduce) +# add_llvm_executable(tsingmicro-reduce tsingmicro-reduce.cpp PARTIAL_SOURCES_INTENDED) +# mlir_check_all_link_libraries(tsingmicro-reduce) -llvm_update_compile_flags(tsingmicro-reduce) -target_link_libraries(tsingmicro-reduce PRIVATE - ${dialect_libs} - ${conversion_libs} - ${extension_libs} - ${triton_libs} - # tests - TritonTestAnalysis - TritonTestDialectTritonGPU - TritonAMDGPUTestAnalysis - # MLIR core - MLIRReduceLib - MLIRPass - MLIRTransforms -) -set_target_properties(tsingmicro-reduce PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${TSM_BIN_OUT} -) +# llvm_update_compile_flags(tsingmicro-reduce) +# target_link_libraries(tsingmicro-reduce PRIVATE +# ${dialect_libs} +# ${conversion_libs} +# ${extension_libs} +# ${triton_libs} +# # tests +# TritonTestAnalysis +# TritonTestDialectTritonGPU +# TritonAMDGPUTestAnalysis +# # MLIR core +# MLIRReduceLib +# MLIRPass +# MLIRTransforms +# MLIRMathTestPasses +# ) -mlir_check_all_link_libraries(tsingmicro-reduce) +# mlir_check_all_link_libraries(tsingmicro-reduce) -add_llvm_executable(tsingmicro-lsp tsingmicro-lsp.cpp PARTIAL_SOURCES_INTENDED) -mlir_check_all_link_libraries(tsingmicro-lsp) +# add_llvm_executable(tsingmicro-lsp tsingmicro-lsp.cpp PARTIAL_SOURCES_INTENDED) + +# llvm_update_compile_flags(tsingmicro-lsp) +# target_link_libraries(tsingmicro-lsp PRIVATE +# ${dialect_libs} +# ${conversion_libs} +# ${extension_libs} +# ${triton_libs} +# # tests +# TritonTestAnalysis +# TritonTestDialectTritonGPU +# TritonAMDGPUTestAnalysis +# # MLIR core +# MLIRLspServerLib +# MLIRPass +# MLIRTransforms +# MLIRMathTestPasses +# ) + +# mlir_check_all_link_libraries(tsingmicro-lsp) -llvm_update_compile_flags(tsingmicro-lsp) -target_link_libraries(tsingmicro-lsp PRIVATE - ${dialect_libs} - ${conversion_libs} - ${extension_libs} - ${triton_libs} - # tests - TritonTestAnalysis - TritonTestDialectTritonGPU - TritonAMDGPUTestAnalysis - # MLIR core - MLIRLspServerLib - MLIRPass - MLIRTransforms -) -set_target_properties(tsingmicro-lsp PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${TSM_BIN_OUT} -) -mlir_check_all_link_libraries(tsingmicro-lsp) +# add_llvm_executable(tsingmicro-llvm-opt +# tsingmicro-llvm-opt.cpp +# PARTIAL_SOURCES_INTENDED +# DEPENDS +# intrinsics_gen +# SUPPORT_PLUGINS +# ) +# target_link_libraries(tsingmicro-llvm-opt PRIVATE +# TritonLLVMIR -add_llvm_executable(tsingmicro-llvm-opt - tsingmicro-llvm-opt.cpp +# LLVMAnalysis +# LLVMCore +# LLVMSupport +# LLVMOption +# LLVMCodeGen +# ) +# export_executable_symbols_for_plugins(tsingmicro-llvm-opt) - PARTIAL_SOURCES_INTENDED - DEPENDS - intrinsics_gen - SUPPORT_PLUGINS - ) -target_link_libraries(tsingmicro-llvm-opt PRIVATE - TritonLLVMIR - LLVMAnalysis - LLVMCore - LLVMSupport - LLVMOption - LLVMCodeGen - ) -export_executable_symbols_for_plugins(tsingmicro-llvm-opt) +# add_llvm_executable(tsingmicro-tensor-layout tsingmicro-tensor-layout.cpp PARTIAL_SOURCES_INTENDED) +# target_link_libraries(tsingmicro-tensor-layout PRIVATE +# ${triton_libs} +# ${conversion_libs} +# ${extension_libs} +# ${dialect_libs} +# TritonTestAnalysis +# TritonTestDialectTritonGPU +# 
TritonAMDGPUTestAnalysis +# MLIRMathTestPasses +# ) -set_target_properties(tsingmicro-llvm-opt PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${TSM_BIN_OUT} +install(TARGETS tsingmicro-opt + RUNTIME DESTINATION ${INSTALL_TSINGMICRO_DIR}/bin ) diff --git a/third_party/tsingmicro/bin/RegisterTritonDialects.h b/third_party/tsingmicro/bin/RegisterTritonDialects.h index 95cb7fbd5..0b6775404 100644 --- a/third_party/tsingmicro/bin/RegisterTritonDialects.h +++ b/third_party/tsingmicro/bin/RegisterTritonDialects.h @@ -22,6 +22,7 @@ #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Target/LLVMIR/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -31,7 +32,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "magic-kernel/Dialect/IR/MagicKernelDialect.h" -#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" #include "triton-shared/Conversion/StructuredToMemref/Passes.h" #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" #include "triton-shared/Conversion/TritonToCoreDialects/Passes.h" @@ -44,7 +44,6 @@ #include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h" #include "magic-kernel/Conversion/LinalgToMK/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "triton-shared/Conversion/StructuredToMemref/Passes.h" #include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h" #include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h" #include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" @@ -58,6 +57,7 @@ namespace mlir { namespace test { +void registerTestMathPolynomialApproximationPass(); void registerTestAliasPass(); void registerTestAlignmentPass(); void registerTestAllocationPass(); @@ -120,10 +123,13 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + // Math dialect passes + mlir::test::registerTestMathPolynomialApproximationPass(); + // FIXME: May not need all of these // mlir::registerAllDialects(registry); // Register all external models.
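+  // Background (hedged): "external models" attach interface implementations such as BufferizableOpInterface to ops owned by other dialects without modifying those dialects; the one-shot bufferization pipeline used by this backend relies on these registrations being done up front.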
- // affine::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::affine::registerValueBoundsOpInterfaceExternalModels(registry); mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(registry); mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); mlir::arith::registerBufferViewFlowOpInterfaceExternalModels(registry); @@ -145,10 +148,12 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); mlir::memref::registerValueBoundsOpInterfaceExternalModels(registry); mlir::memref::registerMemorySlotExternalModels(registry); + mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(registry); mlir::scf::registerBufferizableOpInterfaceExternalModels(registry); mlir::scf::registerValueBoundsOpInterfaceExternalModels(registry); mlir::shape::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry); mlir::tensor::registerFindPayloadReplacementOpInterfaceExternalModels( registry); @@ -156,26 +161,28 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::tensor::registerSubsetOpInterfaceExternalModels(registry); mlir::tensor::registerTilingInterfaceExternalModels(registry); mlir::tensor::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::vector::registerBufferizableOpInterfaceExternalModels(registry); mlir::vector::registerSubsetOpInterfaceExternalModels(registry); mlir::vector::registerValueBoundsOpInterfaceExternalModels(registry); mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); + // This is needed by the bufferization pass (one-shot bufferize) mlir::registerAllExtensions(registry); mlir::mk::registerBufferizableOpInterfaceExternalModels(registry); - registry.insert(); + registry.insert< + mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, + mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, + mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, + mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::gpu::GPUDialect, + mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect, + mlir::triton::nvgpu::NVGPUDialect, + mlir::triton::amdgpu::TritonAMDGPUDialect, + mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect, + mlir::ttx::TritonTilingExtDialect, mlir::tts::TritonStructuredDialect, + mlir::linalg::LinalgDialect, mlir::func::FuncDialect, + mlir::tensor::TensorDialect, mlir::memref::MemRefDialect, + mlir::affine::AffineDialect, mlir::bufferization::BufferizationDialect, + mlir::mk::MagicKernelDialect, mlir::tx::Tx81Dialect>(); } diff --git a/third_party/tsingmicro/crt/CMakeLists.txt b/third_party/tsingmicro/crt/CMakeLists.txt index a373753a2..4e927b5e4 100644 --- a/third_party/tsingmicro/crt/CMakeLists.txt +++ b/third_party/tsingmicro/crt/CMakeLists.txt @@ -27,29 +27,31 @@ if(NOT DEFINED LLVM_SYSPATH) endif() endif() -if(NOT DEFINED TX8_HOME) - if(DEFINED ENV{TX8_HOME}) - set(TX8_HOME $ENV{TX8_HOME}) +if(NOT DEFINED TX8_DEPS_ROOT) + if(DEFINED ENV{TX8_DEPS_ROOT}) + set(TX8_DEPS_ROOT $ENV{TX8_DEPS_ROOT}) else() - message(FATAL_ERROR "TX8_HOME environment variable is not defined") + message(FATAL_ERROR "TX8_DEPS_ROOT environment variable is not defined") + endif() +endif() + +# Build for simulator or hardware +if(NOT DEFINED USE_SIM_MODE) + if(DEFINED ENV{USE_SIM_MODE}) + set(USE_SIM_MODE $ENV{USE_SIM_MODE}) + else() + set(USE_SIM_MODE OFF) + message(STATUS "Building for hardware (USE_SIM_MODE not set)") endif() endif() # Project
name and version project(VendorRuntime LANGUAGES CXX C) -# Define RISC-V target triple -set(RISCV_TRIPLE "riscv64-unknown-elf") -set(CMAKE_SYSTEM_NAME Generic) -set(CMAKE_SYSTEM_PROCESSOR riscv) -set(CMAKE_C_COMPILER ${LLVM_SYSPATH}/bin/clang) -set(CMAKE_CXX_COMPILER ${LLVM_SYSPATH}/bin/clang++) - # Define standard include directories +include_directories(${TX8_DEPS_ROOT}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include/${TARGET}) -include_directories(${TX8_HOME}/include) -include_directories(${TX8_HOME}/${XUANTIE_NAME}/riscv64-unknown-elf/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}) # Set build type default @@ -63,67 +65,61 @@ set(VENDOR_RUNTIME_LIB vr) # Collect all source files from the vendor directory file(GLOB_RECURSE VENDOR_SOURCES lib/${TARGET}/*.c) -# Define RISC-V specific compile options -set(RISCV_COMPILE_OPTIONS - --target=${RISCV_TRIPLE} - -march=rv64gc - -mabi=lp64d - -mcmodel=medany -) +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_C_COMPILER ${LLVM_SYSPATH}/bin/clang) +set(CMAKE_CXX_COMPILER ${LLVM_SYSPATH}/bin/clang++) -# Add the library target -add_library(${VENDOR_RUNTIME_LIB} STATIC ${VENDOR_SOURCES}) +if (USE_SIM_MODE) + # Define simulator specific compile options + set(SIMULATOR_COMPILE_OPTIONS + -fPIC + -DUSE_SIM_MODE + ) + + add_library(${VENDOR_RUNTIME_LIB} SHARED ${VENDOR_SOURCES}) -# Apply RISC-V specific settings to our target -target_compile_options(${VENDOR_RUNTIME_LIB} PRIVATE ${RISCV_COMPILE_OPTIONS}) -target_link_options(${VENDOR_RUNTIME_LIB} PRIVATE --target=${RISCV_TRIPLE}) + # Apply simulator specific settings to our target + target_compile_options(${VENDOR_RUNTIME_LIB} PRIVATE ${SIMULATOR_COMPILE_OPTIONS}) -if (DEFINED TSM_BACKEND_DIR) + # Set properties for the library set_target_properties(${VENDOR_RUNTIME_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON - ARCHIVE_OUTPUT_DIRECTORY ${TSM_BACKEND_DIR}/lib + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib + SUFFIX ".so" ) -else () +else() + # Define RISC-V target triple + set(RISCV_TRIPLE "riscv64-unknown-elf") + set(CMAKE_SYSTEM_PROCESSOR riscv) + + include_directories(${TX8_DEPS_ROOT}/include) + include_directories(${TX8_DEPS_ROOT}/${XUANTIE_NAME}/riscv64-unknown-elf/include) + + # Define RISC-V specific compile options + set(RISCV_COMPILE_OPTIONS + --target=${RISCV_TRIPLE} + -march=rv64gc + -mabi=lp64d + -mcmodel=medany + -DCONFIG_TX8_KERNEL_PRINTF_SUPPORT=1 + ) + + # Add the library target + add_library(${VENDOR_RUNTIME_LIB} STATIC ${VENDOR_SOURCES}) + + # Apply RISC-V specific settings to our target + target_compile_options(${VENDOR_RUNTIME_LIB} PRIVATE ${RISCV_COMPILE_OPTIONS}) + target_link_options(${VENDOR_RUNTIME_LIB} PRIVATE --target=${RISCV_TRIPLE}) + # Set properties for the library set_target_properties(${VENDOR_RUNTIME_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON - ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib + SUFFIX ".a" ) -endif () -# Setup compiler and environment for RISC-V compilation -if(CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") - # Use the existing Clang installation with target triple - message(STATUS "Using Clang with RISC-V target triple") -else() - # Override compiler paths if using explicit RISC-V toolchain - message(STATUS "Setting explicit RISC-V compiler from LLVM_SYSPATH") - - foreach(source ${VENDOR_SOURCES}) - if(source MATCHES "\\.(c)$") - set_source_files_properties(${source} PROPERTIES - 
COMPILE_FLAGS "-xc --target=${RISCV_TRIPLE}" - LANGUAGE C) - elseif(source MATCHES "\\.(cpp)$") - set_source_files_properties(${source} PROPERTIES - COMPILE_FLAGS "-xc++ --target=${RISCV_TRIPLE}" - LANGUAGE CXX) - endif() - endforeach() - - # Set compiler launch commands for the target - add_custom_command(TARGET ${VENDOR_RUNTIME_LIB} PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E echo "Building ${VENDOR_RUNTIME_LIB} for RISC-V target" - ) endif() -# Install targets install(TARGETS ${VENDOR_RUNTIME_LIB} - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib - RUNTIME DESTINATION bin + ARCHIVE DESTINATION ${INSTALL_TSINGMICRO_DIR}/lib ) - -# Install headers (optional) -file(GLOB_RECURSE VENDOR_HEADERS Target/lib/${TARGET}/*.h) -install(FILES ${VENDOR_HEADERS} DESTINATION include/${TARGET}) diff --git a/third_party/tsingmicro/crt/gcc_flash_smartl.ld b/third_party/tsingmicro/crt/gcc_flash_smartl.ld deleted file mode 100644 index 6786fb002..000000000 --- a/third_party/tsingmicro/crt/gcc_flash_smartl.ld +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Copyright (C) 2017-2024 Alibaba Group Holding Limited - */ - -/****************************************************************************** - * @file gcc_csky.ld - * @brief csky linker file - * @version V1.0 - * @date 02. June 2017 - ******************************************************************************/ -MEMORY -{ - ISRAM : ORIGIN = 0x00000000 , LENGTH = 0x20000 /* ISRAM 128KB*/ - DSRAM : ORIGIN = 0x20000000 , LENGTH = 0x80000 /* DSRAM 512KB*/ -} - -__min_heap_size = 0x200; -PROVIDE (__ram_end = 0x20020000); -PROVIDE (__heap_end = __ram_end); - -REGION_ALIAS("REGION_TEXT", ISRAM); -REGION_ALIAS("REGION_RODATA", ISRAM); -REGION_ALIAS("REGION_DATA", DSRAM); -REGION_ALIAS("REGION_BSS", DSRAM); - -ENTRY(Reset_Handler) -SECTIONS -{ - .text : { - . = ALIGN(0x8) ; - __stext = . ; - KEEP(*startup.o(*.text)) - KEEP(*startup.o(*.vectors)) - KEEP(*vectors.o(*.text)) - KEEP(*(.text.entry)) - *(.text*) - *(.gnu.warning) - *(.stub) - *(.gnu.linkonce.t*) - *(.glue_7t) - *(.glue_7) - *(.jcr) - KEEP (*(.init)) - KEEP (*(.fini)) - . = ALIGN (0x4) ; - PROVIDE(__ctbp = .); - *(.call_table_data) - *(.call_table_text) - . = ALIGN(0x10) ; - __etext = . ; - } > REGION_TEXT - .rodata : { - . = ALIGN(0x8) ; - __srodata = .; - *(.rdata) - *(.rdata*) - *(.rdata1) - *(.rdata.*) - *(.rodata*) - *(.srodata*) - . = ALIGN(0x8) ; - __init_array_start = .; - __ctors_start__ = .; - KEEP (*(SORT(.init_array.*))) - KEEP (*(.init_array)) - __init_array_end = .; - __ctors_end__ = .; - - __fini_array_start = .; - __dtors_start__ = .; - KEEP (*(SORT(.fini_array.*))) - KEEP (*(.fini_array)) - __fini_array_end = .; - __dtors_end__ = .; - . = ALIGN(0x8) ; - - __ctor_start__ = .; - KEEP (*(SORT(.ctors.*))) - KEEP (*(.ctors)) - __ctor_end__ = .; - KEEP (*(SORT(.dtors.*))) - KEEP (*(.dtors)) - __dtor_end__ = .; - . = ALIGN(0x8) ; -/*****************************************/ - /* section information for finsh shell */ - . = ALIGN(0x8); - __fsymtab_start = .; - KEEP(*(FSymTab)) - __fsymtab_end = .; - . = ALIGN(0x8); - __vsymtab_start = .; - KEEP(*(VSymTab)) - __vsymtab_end = .; - . = ALIGN(0x8); - - /* section information for initial. */ - __rt_init_start = .; - KEEP(*(SORT(.rti_fn*))) - __rt_init_end = .; - . = ALIGN(0x8) ; - - /* section information for at utest */ - __rt_utest_tc_tab_start = .; - KEEP(*(UtestTcTab)) - __rt_utest_tc_tab_end = .; - . = ALIGN(0x8); - - /* section information for at server */ - . = ALIGN(0x8); - __rtatcmdtab_start = .; - KEEP(*(RtAtCmdTab)) - __rtatcmdtab_end = .; - . 
= ALIGN(0x8); - - /* section information for modules */ - . = ALIGN(0x8); - __rtmsymtab_start = .; - KEEP(*(RTMSymTab)) - __rtmsymtab_end = .; - - /* section information for uPRC */ - . = ALIGN(0x8); - __uRPCSvcTab_start = .; - KEEP(*(uRPCSvcTab)) - __uRPCSvcTab_end = .; - - /* section information for var export */ - . = ALIGN(0x8); - __ve_table_start = .; - KEEP(*(SORT(*.VarExpTab.*))) - __ve_table_end = .; -/*****************************************/ -/************** added drivers **************/ - _cli_region_begin = .; - KEEP(*(CliRegion)) - . = ALIGN(0x8); - _cli_region_end = .; - - __core_driver_start__ = .; - KEEP(*(.core_driver_entry)) - . = ALIGN(0x8); - __core_driver_end__ = .; - - __bus_driver_start__ = .; - KEEP(*(*.bus_driver_entry)) - __bus_driver_end__ = .; - - __early_driver_start__ = .; - KEEP(*(*.early_driver_entry)) - __early_driver_end__ = .; - - __vfs_driver_start__ = .; - KEEP(*(*.vfs_driver_entry)) - __vfs_driver_end__ = .; - - __level0_driver_start__ = .; - KEEP(*(*.level0_driver_entry)) - __level0_driver_end__ = .; - - __level1_driver_start__ = .; - KEEP(*(*.level1_driver_entry)) - __level1_driver_end__ = .; - - __level2_driver_start__ = .; - KEEP(*(*.level2_driver_entry)) - __level2_driver_end__ = .; - - __level3_driver_start__ = .; - KEEP(*(*.level3_driver_entry)) - __level3_driver_end__ = .; - - __post_driver_start__ = .; - KEEP(*(*.post_driver_entry)) - __post_driver_end__ = .; -/************** end of drivers *********/ - . = ALIGN(0x8) ; - __erodata = .; - __rodata_end__ = .; - } > REGION_RODATA - .data : { - . = ALIGN(0x8) ; - __sdata = . ; - __data_start__ = . ; - data_start = . ; - *(.got.plt) - *(.got) - *(.gnu.linkonce.r*) - *(.data*) - *(.gnu.linkonce.d*) - *(.gcc_except_table*) - __start_init_call = .; - *(.initcall.init) - __stop_init_call = .; - __start_cmd = .; - *(.bootloaddata.cmd) - . = ALIGN(0x8) ; - __stop_cmd = .; - __global_pointer$ = .; - *(.sdata) - *(.sdata.*) - *(.sdata2.*) - *(.gnu.linkonce.s.*) - *(__libc_atexit) - *(__libc_subinit) - *(__libc_subfreeres) - *(.note.ABI-tag) - __edata = .; - __data_end__ = .; - . = ALIGN(0x8) ; - } > REGION_DATA AT > REGION_RODATA - ._ram_code : { - . = ALIGN(0x8) ; - __ram_code_start__ = .; - *(.ram.code*) - . = ALIGN(0x8) ; - __ram_code_end__ = .; - } > REGION_DATA AT > REGION_RODATA - .bss : { - . = ALIGN(0x8) ; - __sbss = ALIGN(0x8) ; - __bss_start__ = . ; - *(.dynsbss) - *(.sbss) - *(.sbss.*) - *(.scommon) - *(.dynbss) - *(.bss*) - *(COMMON) - . = ALIGN(0x8) ; - __ebss = . ; - __bss_end__ = .; - __end = . ; - end = . ; - } > REGION_BSS AT > REGION_BSS - ._user_heap (NOLOAD): { - . = ALIGN(0x8) ; - *(.stack*) - . = ALIGN(0x8) ; - __heap_start = .; - . += __min_heap_size; - . = ALIGN(0x8) ; - } > REGION_BSS AT > REGION_BSS -} diff --git a/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld b/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld deleted file mode 100644 index db5c48fc0..000000000 --- a/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Copyright (C) 2017-2024 Alibaba Group Holding Limited - */ - -MEMORY -{ - DRAM : ORIGIN = 0x50000000, LENGTH = 0x100000 /* on-chip DRAM 1*1MB */ -} - -__min_heap_size = 0x200; -PROVIDE (__ram_end = 0x50100000 - 0x8); -PROVIDE (__heap_end = __ram_end); - -REGION_ALIAS("REGION_TEXT", DRAM); -REGION_ALIAS("REGION_RODATA", DRAM); -REGION_ALIAS("REGION_DATA", DRAM); -REGION_ALIAS("REGION_BSS", DRAM); - -ENTRY(Reset_Handler) -SECTIONS -{ - .text : { - . = ALIGN(0x8) ; - __stext = . 
; - KEEP(*startup.o(*.text)) - KEEP(*startup.o(*.vectors)) - KEEP(*vectors.o(*.text)) - KEEP(*(.text.entry)) - *(.text) - *(.text*) - *(.text.*) - *(.gnu.warning) - *(.stub) - *(.gnu.linkonce.t*) - *(.glue_7t) - *(.glue_7) - *(.jcr) - KEEP (*(.init)) - KEEP (*(.fini)) - . = ALIGN(0x8) ; - PROVIDE(__ctbp = .); - *(.call_table_data) - *(.call_table_text) - . = ALIGN(0x8) ; - __etext = . ; - } > REGION_TEXT - .gcc_except_table : ONLY_IF_RO { - *(.gcc_except_table .gcc_except_table.*) - } > REGION_TEXT - .rodata : { - . = ALIGN(0x8) ; - __srodata = .; - *(.rdata) - *(.rdata*) - *(.rdata1) - *(.rdata.*) - *(.rodata) - *(.rodata1) - *(.rodata*) - *(.rodata.*) - *(.rodata.str1.4) - *(.srodata*) - . = ALIGN(0x8) ; - - __init_array_start = .; - __ctors_start__ = .; - KEEP (*(SORT(.init_array.*))) - KEEP (*(.init_array)) - __init_array_end = .; - __ctors_end__ = .; - - __fini_array_start = .; - __dtors_start__ = .; - KEEP (*(SORT(.fini_array.*))) - KEEP (*(.fini_array)) - __fini_array_end = .; - __dtors_end__ = .; - - __ctor_start__ = .; - KEEP (*(SORT(.ctors.*))) - KEEP (*(.ctors)) - __ctor_end__ = .; - KEEP (*(SORT(.dtors.*))) - KEEP (*(.dtors)) - __dtor_end__ = .; - . = ALIGN(0x8) ; -/*****************************************/ - /* section information for finsh shell */ - . = ALIGN(0x8); - __fsymtab_start = .; - KEEP(*(FSymTab)) - __fsymtab_end = .; - . = ALIGN(0x8); - __vsymtab_start = .; - KEEP(*(VSymTab)) - __vsymtab_end = .; - . = ALIGN(0x8); - - /* section information for initial. */ - __rt_init_start = .; - KEEP(*(SORT(.rti_fn*))) - __rt_init_end = .; - . = ALIGN(0x8) ; - - /* section information for at utest */ - __rt_utest_tc_tab_start = .; - KEEP(*(UtestTcTab)) - __rt_utest_tc_tab_end = .; - . = ALIGN(0x8); - - /* section information for at server */ - . = ALIGN(0x8); - __rtatcmdtab_start = .; - KEEP(*(RtAtCmdTab)) - __rtatcmdtab_end = .; - . = ALIGN(0x8); - - /* section information for modules */ - . = ALIGN(0x8); - __rtmsymtab_start = .; - KEEP(*(RTMSymTab)) - __rtmsymtab_end = .; - - /* section information for uPRC */ - . = ALIGN(0x8); - __uRPCSvcTab_start = .; - KEEP(*(uRPCSvcTab)) - __uRPCSvcTab_end = .; - - /* section information for var export */ - . = ALIGN(0x8); - __ve_table_start = .; - KEEP(*(SORT(*.VarExpTab.*))) - __ve_table_end = .; -/*****************************************/ -/************** added drivers **************/ - _cli_region_begin = .; - KEEP(*(CliRegion)) - . = ALIGN(0x8) ; - _cli_region_end = .; - - __core_driver_start__ = .; - KEEP(*(.core_driver_entry)) - . = ALIGN(0x8) ; - __core_driver_end__ = .; - - __bus_driver_start__ = .; - KEEP(*(*.bus_driver_entry)) - __bus_driver_end__ = .; - - __early_driver_start__ = .; - KEEP(*(*.early_driver_entry)) - __early_driver_end__ = .; - - __vfs_driver_start__ = .; - KEEP(*(*.vfs_driver_entry)) - __vfs_driver_end__ = .; - - __level0_driver_start__ = .; - KEEP(*(*.level0_driver_entry)) - __level0_driver_end__ = .; - - __level1_driver_start__ = .; - KEEP(*(*.level1_driver_entry)) - __level1_driver_end__ = .; - - __level2_driver_start__ = .; - KEEP(*(*.level2_driver_entry)) - __level2_driver_end__ = .; - - __level3_driver_start__ = .; - KEEP(*(*.level3_driver_entry)) - __level3_driver_end__ = .; - - __post_driver_start__ = .; - KEEP(*(*.post_driver_entry)) - __post_driver_end__ = .; -/************** end of drivers *********/ - . = ALIGN(0x8) ; - __erodata = .; - __rodata_end__ = .; - } > REGION_RODATA - .data : { - . = ALIGN(0x8) ; - __sdata = . ; - __data_start__ = . ; - data_start = . 
; - *(.got.plt) - *(.got) - *(.gnu.linkonce.r*) - *(.data) - *(.data*) - *(.data1) - *(.data.*) - *(.gnu.linkonce.d*) - *(.data1) - *(.gcc_except_table) - *(.gcc_except_table*) - __start_init_call = .; - *(.initcall.init) - __stop_init_call = .; - __start_cmd = .; - *(.bootloaddata.cmd) - . = ALIGN(0x8) ; - __stop_cmd = .; - __global_pointer$ = .; - *(.sdata) - *(.sdata.*) - *(.sdata2.*) - *(.gnu.linkonce.s.*) - *(__libc_atexit) - *(__libc_subinit) - *(__libc_subfreeres) - *(.note.ABI-tag) - __edata = .; - __data_end__ = .; - . = ALIGN(0x8) ; - } > REGION_DATA - .gcc_except_table : ONLY_IF_RW { - *(.gcc_except_table .gcc_except_table.*) - __edata = .; - __data_end__ = .; - } > REGION_DATA - .bss : { - . = ALIGN(0x8) ; - __sbss = ALIGN(0x8) ; - __bss_start__ = . ; - *(.dynsbss) - *(.sbss) - *(.sbss.*) - *(.scommon) - *(.dynbss) - *(.bss) - *(.bss.*) - *(COMMON) - . = ALIGN(0x8) ; - __ebss = . ; - __bss_end__ = .; - __end = . ; - end = . ; - } > REGION_BSS - ._user_heap (NOLOAD): { - . = ALIGN(0x8) ; - *(.stack*) - . = ALIGN(0x8) ; - __heap_start = .; - . += __min_heap_size; - . = ALIGN(0x8) ; - } > REGION_BSS -} diff --git a/third_party/tsingmicro/crt/gcc_tx8_smarth.ld b/third_party/tsingmicro/crt/gcc_tx8_smarth.ld deleted file mode 100644 index eb2aacb2a..000000000 --- a/third_party/tsingmicro/crt/gcc_tx8_smarth.ld +++ /dev/null @@ -1,279 +0,0 @@ -/* - * Copyright (C) 2017-2024 Alibaba Group Holding Limited - */ - -/****************************************************************************** - * @file gcc_csky.ld - * @brief csky linker file - * @version V1.0 - * @date 02. June 2017 - ******************************************************************************/ -MEMORY -{ - mem0 (rwx) : ORIGIN = 0x00000000, LENGTH = (20*1024*1024) -} - -REGION_ALIAS("r", mem0); -REGION_ALIAS("w", mem0); -REGION_ALIAS("x", mem0); - -ENTRY(Reset_Handler) -SECTIONS -{ - -.text.startup 0x0:{ - . = ALIGN(0x8) ; - KEEP(*startup.o(*.text.startup)) - *(.text.startup) - } > x - - .text : { - . = ALIGN(0x8) ; - __ram_code_start__ = .; - __stext = . ; - KEEP(*startup.o(*.text)) - KEEP(*startup.o(*.vectors)) - KEEP(*vectors.o(*.text)) - KEEP(*(.text.entry)) - *(.text) - *(.vectors) - *(.text*) - *(.text.*) - *(.gnu.warning) - *(.stub) - *(.gnu.linkonce.t*) - *(.glue_7t) - *(.glue_7) - *(.jcr) - KEEP (*(.init)) - KEEP (*(.fini)) - . = ALIGN(0x8) ; - PROVIDE(__ctbp = .); - *(.call_table_data) - *(.call_table_text) - . = ALIGN(0x8) ; - __etext = . ; - __ram_code_end__ = .; - } > x - - .gcc_except_table : ONLY_IF_RO { - *(.gcc_except_table .gcc_except_table.*) - } > x - - .rodata : { - . = ALIGN(0x8) ; - __srodata = .; - *(.rdata) - *(.rdata*) - *(.rdata1) - *(.rdata.*) - *(.rodata) - *(.rodata1) - *(.rodata*) - *(.rodata.*) - *(.rodata.str1.4) - *(.srodata*) - . = ALIGN(0x8) ; - - __init_array_start = .; - __ctors_start__ = .; - KEEP (*(SORT(.init_array.*))) - KEEP (*(.init_array)) - __init_array_end = .; - __ctors_end__ = .; - - __fini_array_start = .; - __dtors_start__ = .; - KEEP (*(SORT(.fini_array.*))) - KEEP (*(.fini_array)) - __fini_array_end = .; - __dtors_end__ = .; - - __ctor_start__ = .; - KEEP (*(SORT(.ctors.*))) - KEEP (*(.ctors)) - __ctor_end__ = .; - KEEP (*(SORT(.dtors.*))) - KEEP (*(.dtors)) - __dtor_end__ = .; - . = ALIGN(0x8) ; -/*****************************************/ - /* section information for finsh shell */ - . = ALIGN(0x8); - __fsymtab_start = .; - KEEP(*(FSymTab)) - __fsymtab_end = .; - . = ALIGN(0x8); - __vsymtab_start = .; - KEEP(*(VSymTab)) - __vsymtab_end = .; - . 
= ALIGN(0x8); - - /* section information for initial. */ - __rt_init_start = .; - KEEP(*(SORT(.rti_fn*))) - __rt_init_end = .; - . = ALIGN(0x8) ; - - /* section information for at utest */ - __rt_utest_tc_tab_start = .; - KEEP(*(UtestTcTab)) - __rt_utest_tc_tab_end = .; - . = ALIGN(0x8); - - /* section information for at server */ - . = ALIGN(0x8); - __rtatcmdtab_start = .; - KEEP(*(RtAtCmdTab)) - __rtatcmdtab_end = .; - . = ALIGN(0x8); - - /* section information for modules */ - . = ALIGN(0x8); - __rtmsymtab_start = .; - KEEP(*(RTMSymTab)) - __rtmsymtab_end = .; - - /* section information for uPRC */ - . = ALIGN(0x8); - __uRPCSvcTab_start = .; - KEEP(*(uRPCSvcTab)) - __uRPCSvcTab_end = .; - - /* section information for var export */ - . = ALIGN(0x8); - __ve_table_start = .; - KEEP(*(SORT(*.VarExpTab.*))) - __ve_table_end = .; -/*****************************************/ -/************** added drivers **************/ - _cli_region_begin = .; - KEEP(*(CliRegion)) - . = ALIGN(0x8) ; - _cli_region_end = .; - - __core_driver_start__ = .; - KEEP(*(.core_driver_entry)) - . = ALIGN(0x8) ; - __core_driver_end__ = .; - - __bus_driver_start__ = .; - KEEP(*(*.bus_driver_entry)) - __bus_driver_end__ = .; - - __early_driver_start__ = .; - KEEP(*(*.early_driver_entry)) - __early_driver_end__ = .; - - __vfs_driver_start__ = .; - KEEP(*(*.vfs_driver_entry)) - __vfs_driver_end__ = .; - - __level0_driver_start__ = .; - KEEP(*(*.level0_driver_entry)) - __level0_driver_end__ = .; - - __level1_driver_start__ = .; - KEEP(*(*.level1_driver_entry)) - __level1_driver_end__ = .; - - __level2_driver_start__ = .; - KEEP(*(*.level2_driver_entry)) - __level2_driver_end__ = .; - - __level3_driver_start__ = .; - KEEP(*(*.level3_driver_entry)) - __level3_driver_end__ = .; - - __post_driver_start__ = .; - KEEP(*(*.post_driver_entry)) - __post_driver_end__ = .; -/************** end of drivers *********/ - . = ALIGN(0x8) ; - __erodata = .; - __rodata_end__ = .; - } > r - - .data : { - . = ALIGN(0x8) ; - __sdata = . ; - __data_start__ = . ; - data_start = . ; - *(.got.plt) - *(.got) - *(.gnu.linkonce.r*) - *(.data) - *(.data*) - *(.data1) - *(.data.*) - *(.gnu.linkonce.d*) - *(.data1) - *(.gcc_except_table) - *(.gcc_except_table*) - __start_init_call = .; - *(.initcall.init) - __stop_init_call = .; - __start_cmd = .; - *(.bootloaddata.cmd) - . = ALIGN(0x8) ; - __stop_cmd = .; - __global_pointer$ = .; - *(.sdata) - *(.sdata.*) - *(.sdata2.*) - *(.gnu.linkonce.s.*) - *(__libc_atexit) - *(__libc_subinit) - *(__libc_subfreeres) - *(.note.ABI-tag) - __edata = .; - __data_end__ = .; - . = ALIGN(0x8) ; - } > w AT> r - - .gcc_except_table : ONLY_IF_RW { - *(.gcc_except_table .gcc_except_table.*) - __edata = .; - __data_end__ = .; - } > w AT> r - - .rela.dyn : { - . = ALIGN(0x8) ; - __rel_dyn_start = .; - *(.rela*) - __rel_dyn_end = .; - } - .dynsym : { - . = ALIGN(0x8) ; - __dyn_sym_start = .; - *(.dynsym) - __dyn_sym_end = .; - } - .bss : { - . = ALIGN(0x8) ; - __sbss = ALIGN(0x8) ; - __bss_start__ = . ; - *(.dynsbss) - *(.sbss) - *(.sbss.*) - *(.scommon) - *(.dynbss) - *(.bss) - *(.bss.*) - *(COMMON) - . = ALIGN(0x8) ; - __ebss = . ; - __bss_end__ = .; - __end = . ; - end = . ; - } > w - ._user_heap (NOLOAD): { - . = ALIGN(0x8) ; - *(.stack*) - . = ALIGN(0x8) ; - __heap_start = ABSOLUTE(.); - . 
= ORIGIN(w) + LENGTH(w); - __heap_end = ABSOLUTE(.); - - } > w -} diff --git a/third_party/tsingmicro/crt/include/Tx81/tx81.h b/third_party/tsingmicro/crt/include/Tx81/tx81.h index b0af0b73d..c88d3c784 100644 --- a/third_party/tsingmicro/crt/include/Tx81/tx81.h +++ b/third_party/tsingmicro/crt/include/Tx81/tx81.h @@ -9,7 +9,6 @@ #include "instr_adapter.h" #include "instr_def.h" -#include "lib_log.h" #include #include #include @@ -27,7 +26,22 @@ enum ActFuncMode : int32_t { ENLeakRelu = 2, }; -inline uint64_t spm_print_offset(uint64_t addr) { - return (uint64_t)addr + 0x030400000; +#ifdef __cplusplus +extern "C" { +#endif + +float set_value2float32(Data_Format fmt, int8_t *value); + +bool is_contiguous(int *shape, int *strides, int elem_bytes); + +// Used in simulation mode: returns the SPM address mapping +int8_t *get_spm_memory_mapping(uint64_t offset); +// Hardware mode adds spmMappingOffset to get the real SPM address; +// simulation mode calls get_spm_memory_mapping +int8_t *get_spm_memory_mapping_wrapper(uint64_t offset); + +#ifdef __cplusplus } +#endif + #endif // CRT_TARGET_TX81_H diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmax.c b/third_party/tsingmicro/crt/lib/Tx81/argmax.c index a982f8ff9..353c08466 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/argmax.c +++ b/third_party/tsingmicro/crt/lib/Tx81/argmax.c @@ -11,7 +11,8 @@ #include "tx81.h" -void __ArgMax(uint64_t *src, uint32_t elem_count, uint16_t fmt) { +void __ArgMax(uint64_t *src, uint64_t *dst0, uint64_t *dst1, + uint32_t elem_count, uint16_t fmt) { // Create command buffer. TsmPeripheral *cmd = TsmNewPeripheral(); TsmPeripheralInstr inst = {I_CGRA, @@ -28,6 +29,11 @@ // Dispatch the command to accelerator TsmExecute(&inst); + TsmWaitfinish(); + + *(float *)dst0 = *(float *)inst.param.wb_data0; + *(int32_t *)dst1 = *(int32_t *)inst.param.wb_data1; + // Destroy the command buffer. TsmDeletePeripheral(cmd); } diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmin.c b/third_party/tsingmicro/crt/lib/Tx81/argmin.c index 856854d3b..05dd675ff 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/argmin.c +++ b/third_party/tsingmicro/crt/lib/Tx81/argmin.c @@ -11,7 +11,8 @@ #include "tx81.h" -void __ArgMin(uint64_t *src, uint32_t elem_count, uint16_t fmt) { +void __ArgMin(uint64_t *src, uint64_t *dst0, uint64_t *dst1, + uint32_t elem_count, uint16_t fmt) { // Create command buffer. TsmPeripheral *cmd = TsmNewPeripheral(); TsmPeripheralInstr inst = {I_CGRA, @@ -28,6 +29,11 @@ // Dispatch the command to accelerator TsmExecute(&inst); + TsmWaitfinish(); + + *(float *)dst0 = *(float *)inst.param.wb_data0; + *(int32_t *)dst1 = *(int32_t *)inst.param.wb_data1; + // Destroy the command buffer. TsmDeletePeripheral(cmd); } diff --git a/third_party/tsingmicro/crt/lib/Tx81/barrier.c b/third_party/tsingmicro/crt/lib/Tx81/barrier.c new file mode 100644 index 000000000..26c621776 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/barrier.c @@ -0,0 +1,14 @@ +//===------------------------ Barrier.c -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Barrier; see Tx81Ops.td for details.
+// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Barrier() { TsmWaitfinish(); } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c index 0a9344213..e05ff412c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c @@ -11,7 +11,8 @@ #include "tx81.h" -void __Bit2Fp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { +void __Bit2Fp(uint64_t *src, uint64_t *target, uint32_t elem_count, + uint16_t fmt) { // Create command buffer. TsmPeripheral *cmd = TsmNewPeripheral(); TsmPeripheralInstr inst = {I_CGRA, @@ -21,9 +22,10 @@ void __Bit2Fp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { { 0, }}; - ; - cmd->Bit2Fp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + assert(elem_count % 8 == 0); + + cmd->Bit2Fp(&inst, (uint64_t)src, (uint64_t)target, elem_count, (Data_Format)fmt); // Dispatch the command to accelerator diff --git a/third_party/tsingmicro/crt/lib/Tx81/channelnorm.c b/third_party/tsingmicro/crt/lib/Tx81/channelnorm.c new file mode 100644 index 000000000..176b724a3 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/channelnorm.c @@ -0,0 +1,154 @@ +//===------------------------ channelnorm.c -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::channelnorm/dechannelnorm. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" +#include + +void __ChannelNorm(uint64_t *src, uint64_t *dst, uint16_t n, uint16_t h, + uint16_t w, uint16_t c, uint16_t c0, uint16_t bit_width) { + int calign_base = bit_width == 8 ? 
128 : 64; + int dtype_size = bit_width / 8; + int cx = c / calign_base; + + TsmDataMove *dm = TsmNewDataMove(); + TsmDataMoveInstr dm_param = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + uint32_t inner_dim_size = c; + St_StrideIteration src_it = {0}, dst_it = {0}; + + // align cx + if (cx > 0) { + uint32_t elem_size = calign_base * dtype_size; + // byte number + src_it.stride0 = inner_dim_size * dtype_size; + src_it.iteration0 = h * w; + src_it.stride1 = elem_size; + src_it.iteration1 = cx; + src_it.stride2 = h * w * c * dtype_size; + src_it.iteration2 = n; + + dst_it.stride0 = elem_size; + dst_it.iteration0 = h * w; + dst_it.stride1 = dst_it.iteration0 * dst_it.stride0; + dst_it.iteration1 = cx; + dst_it.stride2 = n * h * w * elem_size; + dst_it.iteration2 = n; + + dm->GatherScatter(&dm_param, (uint64_t)src, (uint64_t)dst, elem_size, + &src_it, &dst_it); + TsmExecute(&dm_param); + } + + // align c0 + if (c0 > 0) { + uint32_t src_offset = cx * calign_base * dtype_size; + uint32_t dst_offset = cx * h * w * calign_base * dtype_size; + int32_t c0_valid = inner_dim_size - cx * calign_base; + int32_t elem_size = c0_valid * dtype_size; + + src_it.stride0 = inner_dim_size * dtype_size; + src_it.iteration0 = n * h * w; + src_it.stride1 = n * h * w * inner_dim_size * dtype_size; + src_it.iteration1 = 1; + src_it.stride2 = n * h * w * inner_dim_size * dtype_size; + src_it.iteration2 = 1; + + dst_it.stride0 = c0 * dtype_size; + dst_it.iteration0 = n * h * w; + dst_it.stride1 = n * h * w * c0 * dtype_size; + dst_it.iteration1 = 1; + dst_it.stride2 = n * h * w * c0 * dtype_size; + dst_it.iteration2 = 1; + + dm->GatherScatter(&dm_param, (uint64_t)src + src_offset, + (uint64_t)dst + dst_offset, elem_size, &src_it, &dst_it); + TsmExecute(&dm_param); + } + + TsmDeleteDataMove(dm); +} + +void __DechannelNorm(uint64_t *src, uint64_t *dst, uint16_t n, uint16_t h, + uint16_t w, uint16_t c, uint16_t c0, uint16_t bit_width) { + int calign_base = bit_width == 8 ? 
128 : 64; + int dtype_size = bit_width / 8; + int cx = c / calign_base; + + TsmDataMove *dm = TsmNewDataMove(); + TsmDataMoveInstr dm_param = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + uint32_t inner_dim_size = c; + St_StrideIteration src_it = {0}, dst_it = {0}; + + // align cx + if (cx > 0) { + uint32_t elem_size = calign_base * dtype_size; + // byte number + + src_it.stride0 = h * w * elem_size; + src_it.iteration0 = cx; + src_it.stride1 = elem_size; + src_it.iteration1 = h * w; + src_it.stride2 = h * w * elem_size; + src_it.iteration2 = n; + + dst_it.stride0 = elem_size; + dst_it.iteration0 = cx; + dst_it.stride1 = inner_dim_size * dtype_size; + dst_it.iteration1 = h * w; + dst_it.stride2 = h * w * inner_dim_size * dtype_size; + dst_it.iteration2 = n; + + dm->GatherScatter(&dm_param, (uint64_t)src, (uint64_t)dst, elem_size, + &src_it, &dst_it); + TsmExecute(&dm_param); + } + + // align c0 + if (c0 > 0) { + uint32_t src_offset = cx * calign_base * dtype_size; + uint32_t dst_offset = cx * h * w * calign_base * dtype_size; + int32_t c0_valid = inner_dim_size - cx * calign_base; + int32_t elem_size = c0_valid * dtype_size; + + src_it.stride0 = c0 * dtype_size; + src_it.iteration0 = n * h * w; + src_it.stride1 = n * h * w * c0 * dtype_size; + src_it.iteration1 = 1; + src_it.stride2 = n * h * w * c0 * dtype_size; + src_it.iteration2 = 1; + + dst_it.stride0 = inner_dim_size * dtype_size; + dst_it.iteration0 = n * h * w; + dst_it.stride1 = n * h * w * inner_dim_size * dtype_size; + dst_it.iteration1 = 1; + dst_it.stride2 = n * h * w * inner_dim_size * dtype_size; + dst_it.iteration2 = 1; + + dm->GatherScatter(&dm_param, (uint64_t)src + src_offset, + (uint64_t)dst + dst_offset, elem_size, &src_it, &dst_it); + TsmExecute(&dm_param); + } + + TsmDeleteDataMove(dm); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c index 032290c1f..65f6a3683 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c +++ b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c @@ -11,11 +11,13 @@ #include "tx81.h" -void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t size, - uint32_t src_s0, uint32_t src_i0, uint32_t src_s1, - uint32_t src_i1, uint32_t src_s2, uint32_t src_i2, - uint32_t dst_s0, uint32_t dst_i0, uint32_t dst_s1, - uint32_t dst_i1, uint32_t dst_s2, uint32_t dst_i2) { +void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t bytes, + uint32_t src_strideN, uint32_t src_strideH, + uint32_t src_strideW, uint32_t src_iterN, + uint32_t src_iterH, uint32_t src_iterW, + uint32_t dst_strideN, uint32_t dst_strideH, + uint32_t dst_strideW, uint32_t dst_iterN, + uint32_t dst_iterH, uint32_t dst_iterW) { // Create command buffer. 
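+  // Hedged note (values below are illustrative, not from vendor docs): the renamed parameters read outermost-first (N, H, W) while St_StrideIteration packs innermost-first, i.e. {strideW, iterW, strideH, iterH, strideN, iterN}, and strides are byte counts (see the callers in channelnorm.c). + // Example: gathering two 12-byte rows from a row-major 3x8 int32 source into a dense buffer would use bytes=12, src_strideW=32, src_iterW=2, dst_strideW=12, dst_iterW=2, and {stride=0, iter=1} for the H and N levels.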
TsmDataMove *cmd = TsmNewDataMove(); TsmDataMoveInstr inst = {I_CGRA, @@ -26,10 +28,12 @@ void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t size, 0, }}; - St_StrideIteration src_si = {src_s0, src_i0, src_s1, src_i1, src_s2, src_i2}; - St_StrideIteration dst_si = {dst_s0, dst_i0, dst_s1, dst_i1, dst_s2, dst_i2}; + St_StrideIteration src_si = {src_strideW, src_iterW, src_strideH, + src_iterH, src_strideN, src_iterN}; + St_StrideIteration dst_si = {dst_strideW, dst_iterW, dst_strideH, + dst_iterH, dst_strideN, dst_iterN}; - cmd->GatherScatter(&inst, (uint64_t)src, (uint64_t)dst, size, &src_si, + cmd->GatherScatter(&inst, (uint64_t)src, (uint64_t)dst, bytes, &src_si, &dst_si); // Dispatch the command to accelerator diff --git a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c index ed916cb60..5738cbedf 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c +++ b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c @@ -11,9 +11,8 @@ #include "tx81.h" -void __Nchw2nhwc(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { +void __Nchw2nhwc(uint64_t *src, uint64_t *dst, int32_t *src_shape, + int32_t *dst_shape, uint16_t fmt) { // Create command buffer. TsmDataMove *cmd = TsmNewDataMove(); TsmDataMoveInstr inst = {I_CGRA, @@ -24,8 +23,8 @@ void __Nchw2nhwc(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, 0, }}; - Data_Shape shape1 = {src_n, src_h, src_w, src_c}; - Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + Data_Shape shape1 = {src_shape[0], src_shape[1], src_shape[2], src_shape[3]}; + Data_Shape shape2 = {dst_shape[0], dst_shape[1], dst_shape[2], dst_shape[3]}; cmd->Nchw2nhwc(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, (Data_Format)fmt); diff --git a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c index 932b71599..e871bb930 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c +++ b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c @@ -11,9 +11,8 @@ #include "tx81.h" -void __Nhwc2nchw(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { +void __Nhwc2nchw(uint64_t *src, uint64_t *dst, int32_t *src_shape, + int32_t *dst_shape, uint16_t fmt) { // Create command buffer. TsmDataMove *cmd = TsmNewDataMove(); TsmDataMoveInstr inst = {I_CGRA, @@ -24,8 +23,8 @@ void __Nhwc2nchw(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, 0, }}; - Data_Shape shape1 = {src_n, src_h, src_w, src_c}; - Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + Data_Shape shape1 = {src_shape[0], src_shape[1], src_shape[2], src_shape[3]}; + Data_Shape shape2 = {dst_shape[0], dst_shape[1], dst_shape[2], dst_shape[3]}; cmd->Nhwc2nchw(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, (Data_Format)fmt); diff --git a/third_party/tsingmicro/crt/lib/Tx81/print.c b/third_party/tsingmicro/crt/lib/Tx81/print.c new file mode 100644 index 000000000..24be787ce --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/print.c @@ -0,0 +1,27 @@ +// ===------------------------ print.c ------------------------------------===// + +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+ +// ===---------------------------------------------------------------------===// + +// Enable tx8 kernel printf support + +#include "lib_log.h" +#include "tx81.h" +#include <stdarg.h> +#include <stdio.h> +#include + +void __Print(const char *__restrict fmt, ...) { + va_list args; + va_start(args, fmt); + + // FIXME: va_list memory layout is specific to the platform. +#ifndef USE_SIM_MODE + monitor_write_log(__FILE__, __func__, __LINE__, (char *)fmt, args); +#else + vprintf(fmt, args); +#endif + va_end(args); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rdma.c b/third_party/tsingmicro/crt/lib/Tx81/rdma.c index f000df052..36077543a 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/rdma.c +++ b/third_party/tsingmicro/crt/lib/Tx81/rdma.c @@ -10,15 +10,21 @@ //===----------------------------------------------------------------------===// #include "tx81.h" +#include <assert.h> // The arguments list is aligned with TsmConv in Tx81Ops.td -void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w, - int shape_c, int stride_n, int stride_h, int stride_w, +void __Rdma(uint64_t *src, uint64_t *dst, int *src_shape, int *src_stride, + int *dst_shape, int *dst_stride, uint32_t elem_bytes, uint32_t fmt) { // Dynamic shape, kernel implementation will cause shape equal to 0 - if (shape_n == 0 || shape_h == 0 || shape_w == 0 || shape_c == 0) + if (src_shape[0] == 0 || src_shape[1] == 0 || src_shape[2] == 0 || + src_shape[3] == 0) return; + // The inner dim must be contiguous; the last stride is always 1. + assert(src_stride[3] == 1); + assert(dst_stride[3] == 1); + // Create gemm command buffer. TsmRdma *rdma = TsmNewRdma(); TsmRdmaInstr inst = {I_RDMA, @@ -29,12 +35,35 @@ void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w, 0, }}; - rdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt); - rdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h, - shape_h, stride_n, shape_n); + if (is_contiguous(dst_shape, dst_stride, elem_bytes)) { + rdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt); + + rdma->ConfigStrideIteration(&inst, src_shape[3], src_stride[2], + src_shape[2], src_stride[1], src_shape[1], + src_stride[0], src_shape[0]); + TsmExecute(&inst); + TsmDeleteRdma(rdma); + return; + } + + for (int64_t i = 0; i < src_shape[0]; ++i) { + uint64_t src_ptr0 = (uint64_t)src + i * src_stride[0] * elem_bytes; + uint64_t dst_ptr0 = (uint64_t)dst + i * dst_stride[0] * elem_bytes; + + for (int64_t j = 0; j < src_shape[1]; ++j) { + uint64_t src_ptr1 = src_ptr0 + j * src_stride[1] * elem_bytes; + uint64_t dst_ptr1 = dst_ptr0 + j * dst_stride[1] * elem_bytes; - // Dispatch the command to accelerator - TsmExecute(&inst); + for (int64_t k = 0; k < src_shape[2]; ++k) { + uint64_t src_ptr2 = src_ptr1 + k * src_stride[2] * elem_bytes; + uint64_t dst_ptr2 = dst_ptr1 + k * dst_stride[2] * elem_bytes; + rdma->Rdma1d(&inst, (uint64_t)src_ptr2, (uint64_t)dst_ptr2, + src_shape[3], (Data_Format)fmt); + TsmExecute(&inst); + TsmWaitfinish(); + } + } + } // Destroy the command buffer. TsmDeleteRdma(rdma); diff --git a/third_party/tsingmicro/crt/lib/Tx81/relation.c b/third_party/tsingmicro/crt/lib/Tx81/relation.c index 10069deb6..7d1aab25d 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/relation.c +++ b/third_party/tsingmicro/crt/lib/Tx81/relation.c @@ -142,3 +142,399 @@ void __BoolLessThenVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Destroy the command buffer.
TsmDeleteRelation(cmd); } + +void __EqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->EqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __UnEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->UnEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __GreaterEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->GreaterEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __GreaterVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->GreaterVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __LessEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->LessEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __LessThenVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->LessThenVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolEqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolUnEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. 
+ TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolUnEqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolGreaterEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolGreaterEqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolGreaterVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolGreaterVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolLessEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolLessEqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolLessThenVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolLessThenVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __EqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->EqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __UnEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->UnEqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __GreaterEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. 
+ TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->GreaterEqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __GreaterVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->GreaterVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __LessEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->LessEqualVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __LessThenVS(uint64_t *src0, uint32_t src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->LessThenVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/transpose.c b/third_party/tsingmicro/crt/lib/Tx81/transpose.c index 54e2ee584..0b284d1bb 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/transpose.c +++ b/third_party/tsingmicro/crt/lib/Tx81/transpose.c @@ -11,9 +11,8 @@ #include "tx81.h" -void __Transpose(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { +void __Transpose(uint64_t *src, uint64_t *dst, int32_t *src_shape, + int32_t *dst_shape, uint16_t fmt) { // Create command buffer. TsmDataMove *cmd = TsmNewDataMove(); TsmDataMoveInstr inst = {I_CGRA, @@ -24,8 +23,8 @@ void __Transpose(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, 0, }}; - Data_Shape shape1 = {src_n, src_h, src_w, src_c}; - Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + Data_Shape shape1 = {src_shape[0], src_shape[1], src_shape[2], src_shape[3]}; + Data_Shape shape2 = {dst_shape[0], dst_shape[1], dst_shape[2], dst_shape[3]}; cmd->Transpose(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, (Data_Format)fmt); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tx81.c b/third_party/tsingmicro/crt/lib/Tx81/tx81.c new file mode 100644 index 000000000..107282eac --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tx81.c @@ -0,0 +1,38 @@ +//===------------------------- tx81.c--------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+bool is_contiguous(int *shape, int *strides, int elem_bytes) {
+  int expected_stride = elem_bytes;
+  for (int i = 0; i < 4; i++) {
+    if (strides[i] != expected_stride) {
+      return false;
+    }
+    expected_stride *= shape[i];
+  }
+  return true;
+}
+
+// Used by kcore to load/store data from/to SPM
+const int64_t spmMappingOffset = 0x30400000;
+
+int8_t *get_spm_memory_mapping_wrapper(uint64_t ptr) {
+#ifdef USE_SIM_MODE
+  return get_spm_memory_mapping(ptr);
+#else
+  return (int8_t *)(ptr + spmMappingOffset);
+#endif
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/third_party/tsingmicro/crt/lib/Tx81/wdma.c b/third_party/tsingmicro/crt/lib/Tx81/wdma.c
index 93bfe6e89..68491ec8d 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/wdma.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/wdma.c
@@ -10,16 +10,21 @@
 //===----------------------------------------------------------------------===//
 
 #include "tx81.h"
+#include <assert.h>
 
 // The arguments list is aligned with TsmConv in Tx81Ops.td
-void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w,
-            int shape_c, int stride_n, int stride_h, int stride_w,
+void __Wdma(uint64_t *src, uint64_t *dst, int *src_shape, int *src_stride,
+            int *dst_shape, int *dst_stride, uint32_t elem_bytes,
             uint32_t fmt) {
-  // Dynamic shape, kernel implementation will cause shape equal to 0
-  if (shape_n == 0 || shape_h == 0 || shape_w == 0 || shape_c == 0)
+  if (src_shape[0] == 0 || src_shape[1] == 0 || src_shape[2] == 0 ||
+      src_shape[3] == 0)
     return;
 
+  // The inner dim must be contiguous; the last stride is always 1.
+  assert(src_stride[3] == 1);
+  assert(dst_stride[3] == 1);
+
   // Create gemm command buffer.
   TsmWdma *wdma = TsmNewWdma();
   TsmWdmaInstr inst = {I_WDMA,
                       {
                           0,
                       },
                       {
                           0,
                       }};
 
-  wdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt);
+  if (is_contiguous(src_shape, src_stride, elem_bytes)) {
+    wdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt);
+
+    wdma->ConfigStrideIteration(&inst, dst_shape[3], dst_stride[2],
+                                dst_shape[2], dst_stride[1], dst_shape[1],
+                                dst_stride[0], dst_shape[0]);
+    TsmExecute(&inst);
+    TsmDeleteWdma(wdma);
+    return;
+  }
+
+  for (int64_t i = 0; i < src_shape[0]; ++i) {
+    uint64_t src_ptr0 = (uint64_t)src + i * src_stride[0] * elem_bytes;
+    uint64_t dst_ptr0 = (uint64_t)dst + i * dst_stride[0] * elem_bytes;
 
-  wdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h,
-                              shape_h, stride_n, shape_n);
+    for (int64_t j = 0; j < src_shape[1]; ++j) {
+      uint64_t src_ptr1 = src_ptr0 + j * src_stride[1] * elem_bytes;
+      uint64_t dst_ptr1 = dst_ptr0 + j * dst_stride[1] * elem_bytes;
 
-  // Dispatch the command to accelerator
-  TsmExecute(&inst);
+      for (int64_t k = 0; k < src_shape[2]; ++k) {
+        uint64_t src_ptr2 = src_ptr1 + k * src_stride[2] * elem_bytes;
+        uint64_t dst_ptr2 = dst_ptr1 + k * dst_stride[2] * elem_bytes;
+        wdma->Wdma1d(&inst, (uint64_t)src_ptr2, (uint64_t)dst_ptr2,
+                     src_shape[3], (Data_Format)fmt);
+        TsmExecute(&inst);
+        TsmWaitfinish();
+      }
+    }
+  }
 
   // Destroy the command buffer.
TsmDeleteWdma(wdma); diff --git a/third_party/tsingmicro/examples/bare_matmul.py b/third_party/tsingmicro/examples/bare_matmul.py deleted file mode 100644 index 84b9c9a87..000000000 --- a/third_party/tsingmicro/examples/bare_matmul.py +++ /dev/null @@ -1,52 +0,0 @@ -# this is a benchmark which multiplies square matrices with maximum block size -# to check the performance of tl.dot operation - -import torch -import triton -import triton.language as tl -import benchmark - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - - -@triton.jit -def bare_matmul(X, Y, Z, M, N, K, BLOCK_SIZE: tl.constexpr): - pid_x = tl.program_id(0) # block row id - pid_y = tl.program_id(1) # block column id - - offs_x = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_y = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - - x = tl.load(X + offs_x[:, None] * K + offs_y[None, :]) - y = tl.load(Y + offs_x[:, None] * N + offs_y[None, :]) - - z = tl.dot(x, y) - - tl.store(Z + offs_x[:, None] * N + offs_y[None, :], z) - - -# @benchmark.measure() -def bench_matmul(N, provider): - device = 'cpu' - dtype = torch.float32 - a = torch.randint(0, 10, (N, N), dtype=torch.int32).to(dtype) - b = torch.randint(0, 10, (N, N), dtype=torch.int32).to(dtype) - # a = torch.randn((N, N), device=device, dtype=dtype) - # b = torch.randn((N, N), device=device, dtype=dtype) - c = torch.empty((N, N), device=device, dtype=dtype) - if provider == 'torch' or provider == 'test': - c_ref = torch.matmul(a, b) - # print("====cref:",c_ref) - if provider == 'triton' or provider == 'test': - bare_matmul[(1, )](a, b, c, N, N, N, N) - if provider == 'test': - torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=0) - print("expected", c_ref) - print("actual", c) - print("======test====") - - -if __name__ == "__main__": - # benchmark.select_cpu_backend() - for provider in ['test']: - bench_matmul(16, provider) diff --git a/third_party/tsingmicro/examples/test_embedding.py b/third_party/tsingmicro/examples/test_embedding.py new file mode 100644 index 000000000..22b8d1db8 --- /dev/null +++ b/third_party/tsingmicro/examples/test_embedding.py @@ -0,0 +1,94 @@ +import torch +import math + +import triton +import triton.language as tl + +import pytest +import benchmark + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def embedding_kernel( + out_ptr, # pointer to the output + in_ptr, # pointer to the input + weight_ptr, # pointer to the weights + N: tl.constexpr, # number of columns in X + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + out_ptr += pid * N + in_ptr += pid + + mask = tl.arange(0, BLOCK_SIZE) < N + cols = tl.arange(0, BLOCK_SIZE) + + row_idx = tl.load(in_ptr) + weight_ptr += row_idx * N + embedding_weight = tl.load(weight_ptr + cols, mask, other=0.0) + tl.store(out_ptr + cols, embedding_weight, mask) + + +class Embedding(torch.autograd.Function): + + @staticmethod + def forward(ctx, weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + + assert not sparse, "Currently do not support sparse format" + + M = math.prod(indices.shape) + N = weight.shape[-1] + + BLOCK_SIZE = triton.next_power_of_2(N) + indices = indices.contiguous() + weight = weight.contiguous() + output = torch.empty((*indices.shape, N), device=indices.device, dtype=weight.dtype) + + output = output.to(DEVICE) + indices = indices.to(DEVICE) + weight = weight.to(DEVICE) + embedding_kernel[ + M, + ](output, indices, weight, N, BLOCK_SIZE) + output = output.to("cpu") + ctx.M = M + ctx.N = N + ctx.num_weights = 
weight.shape[0] + ctx.padding_idx = padding_idx + ctx.scale_grad_by_freq = scale_grad_by_freq + ctx.sparse = sparse + ctx.indices = indices + + return output + + +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + return Embedding.apply(weight, indices, padding_idx, scale_grad_by_freq, sparse) + + +@pytest.mark.parametrize("M, N, dtype", [ # + (M, N, dtype) for M in [1152] for N in [2048] for dtype in [torch.float32] +]) +def test_embedding(M, N, dtype, device='cpu'): + torch.manual_seed(0) + + weight = torch.rand((M, N), dtype=dtype, device=device) + indices = torch.randint(0, M, [M], dtype=torch.int32, device=device) + + triton_output = embedding(weight, indices) + + # pytorch + torch_embedding = torch.nn.Embedding(M, N, _weight=weight) + torch_output = torch_embedding(indices) + + # compare + print(f"The maximum difference between torch and triton is " + f"{torch.max(torch.abs(torch_output - triton_output))}") + assert torch.allclose(triton_output, torch_output, atol=1e-5, rtol=0) + + +if __name__ == "__main__": + # benchmark.select_cpu_backend() + test_embedding(1151, 8192, torch.float32) diff --git a/third_party/tsingmicro/examples/test_layernorm.py b/third_party/tsingmicro/examples/test_layernorm.py new file mode 100644 index 000000000..f2d6b58f8 --- /dev/null +++ b/third_party/tsingmicro/examples/test_layernorm.py @@ -0,0 +1,181 @@ +# This is the Layer Norm forward pass from the Triton tutorial found here: +# https://github.com/triton-lang/triton/blob/main/python/tutorials/05-layer-norm.py + +# %% +# Motivations +# ----------- +# +# The *LayerNorm* operator was first introduced in [BA2016]_ as a way to improve the performance +# of sequential models (e.g., Transformers) or neural networks with small batch size. +# It takes a vector :math:`x` as input and produces a vector :math:`y` of the same shape as output. +# The normalization is performed by subtracting the mean and dividing by the standard deviation of :math:`x`. +# After the normalization, a learnable linear transformation with weights :math:`w` and biases :math:`b` is applied. +# The forward pass can be expressed as follows: +# +# .. math:: +# y = \frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} } * w + b +# +# where :math:`\epsilon` is a small constant added to the denominator for numerical stability. +# Let’s first take a look at the forward pass implementation. + +import torch + +import triton +import triton.language as tl +import pytest +import benchmark + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
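+    # Each program instance normalizes a single row, so the forward() wrapper
+    # below launches a grid of M programs, one per row of the (M, N) input.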
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +class LayerNorm(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, normalized_shape, weight, bias, eps, device): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=device) + rstd = torch.empty((M, ), dtype=torch.float32, device=device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + + x_arg_dev = x_arg.to(DEVICE) + y_dev = y.to(DEVICE) + weight_dev = weight.to(DEVICE) + bias_dev = bias.to(DEVICE) + mean_dev = mean.to(DEVICE) + rstd_dev = rstd.to(DEVICE) + # enqueue kernel + # _layer_norm_fwd_fused[(M, )]( # + # x_arg, y, weight, bias, mean, rstd, # + # x_arg.stride(0), N, eps, # + # BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + + _layer_norm_fwd_fused[(M, )]( # + x_arg_dev, y_dev, weight_dev, bias_dev, mean_dev, rstd_dev, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + x = x_arg_dev.to("cpu") + y = y_dev.to("cpu") + + ctx.save_for_backward(x, weight, bias, mean, rstd) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.eps = eps + return y + + +@pytest.mark.parametrize("M, N, dtype, eps", [ # + (M, N, dtype, eps) for M in [1151] for N in [8192] for dtype in [torch.float16] for eps in [1e-5] +]) +def test_layer_norm(M, N, dtype, eps, device): + layer_norm = LayerNorm.apply + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=False) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(False) + + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps, device) + # TODO We can't compare against Torch layer_norm since it doesn't support float16 on CPU + #y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + print(y_tri) + #print(y_ref) + + # compare + #assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + + 
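+# Unlike the float16 test above, the float32 benchmark below can be checked
+# against torch.nn.functional.layer_norm on the CPU.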
+@benchmark.measure() +def bench_layernorm(size, provider): + layer_norm = LayerNorm.apply + device = 'cpu' + eps = 1e-5 + # dtype = torch.float16 + dtype = torch.float32 + x_shape = (size, size) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=False) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(False) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps, device) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + print(y_tri) + print(y_ref) + + # compare + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + # benchmark.select_cpu_backend() + for X in [2**i for i in range(10, 13, 1)]: + for provider in ['triton']: + bench_layernorm(X, provider) diff --git a/third_party/tsingmicro/examples/test_matmul.py b/third_party/tsingmicro/examples/test_matmul.py new file mode 100644 index 000000000..078ef4a69 --- /dev/null +++ b/third_party/tsingmicro/examples/test_matmul.py @@ -0,0 +1,176 @@ +import torch + +import triton +import triton.language as tl +import benchmark + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, +# num_warps=8), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, +# num_warps=2), +# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, +# num_warps=2), +# ], +# key=['M', 'N', 'K'], +# ) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr # +): + """Kernel for computing the matmul C = A x B. 
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float32) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # 1D launch kernel where each block gets its own program. 
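+    # For example, with M = N = 256 and BLOCK_SIZE_M = BLOCK_SIZE_N = 64, the
+    # grid below evaluates to cdiv(256, 64) * cdiv(256, 64) = 16 programs, one
+    # per 64x64 tile of C.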
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation, # + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=128, GROUP_SIZE_M=8) + return c + + +def test_matmul(device): + torch.manual_seed(0) + rows1 = 179 + cols1 = 167 + rows2 = 167 + cols2 = 321 + a = torch.randn((rows1, cols1), device=device, dtype=torch.float32) + b = torch.randn((rows2, cols2), device=device, dtype=torch.float32) + # a = torch.full((rows1, cols1), 1, device='cpu', dtype=torch.float32) + # b = torch.full((rows2, cols2), 1, device='cpu', dtype=torch.float32) + a = a.to(DEVICE) + b = b.to(DEVICE) + triton_output = matmul(a, b) + triton_output = triton_output.to("cpu") + + torch_output = torch.matmul(a, b) + torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0) + + +@benchmark.measure() +def bench_matmul(M, N, K, provider): + a = torch.randn((M, K), device='cpu', dtype=torch.float32) + b = torch.randn((K, N), device='cpu', dtype=torch.float32) + if provider == 'torch': + torch.matmul(a, b) + if provider == 'triton': + matmul(a, b) + + +if __name__ == "__main__": + # benchmark.select_cpu_backend() + for X in [128 * i for i in range(2, 7)]: + for provider in ['torch', 'triton']: + bench_matmul(X, X, X, provider) diff --git a/third_party/tsingmicro/examples/test_vec_add.py b/third_party/tsingmicro/examples/test_vec_add.py index b75f5aa42..a41a3b690 100644 --- a/third_party/tsingmicro/examples/test_vec_add.py +++ b/third_party/tsingmicro/examples/test_vec_add.py @@ -36,6 +36,9 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. def add(x: torch.Tensor, y: torch.Tensor): + output_torch = x + y + x = x.to(DEVICE) + y = y.to(DEVICE) # We need to preallocate the output. output = torch.empty_like(x) # assert x.is_cuda and y.is_cuda and output.is_cuda @@ -51,21 +54,27 @@ def add(x: torch.Tensor, y: torch.Tensor): add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. 
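+    # Copy the result back to the host before comparing it against the eager
+    # torch sum computed at the top of this function.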
+    output = output.to("cpu")
+    print(f"The maximum difference between torch and triton is "
+          f"{torch.max(torch.abs(output_torch - output))}")
     return output
 
 
 def test(device):
-    # torch.manual_seed(0)
+    torch.manual_seed(0)
     size = 1024
     x = torch.rand(size, device="cpu")
     y = torch.rand(size, device="cpu")
+    print("x: ", x)
+    print("y: ", y)
     output_torch = x + y
     x = x.to(device)
     y = y.to(device)
     output_triton = add(x, y)
+    print("output_triton device: ", output_triton.device)
     # TODO: need to check some conditions otherwise the code below does not make any difference for the test
-    print("expected", output_torch)
     output_triton = output_triton.to("cpu")
+    print("expected", output_torch)
     print("actual", output_triton)
     print(f"The maximum difference between torch and triton is "
           f"{torch.max(torch.abs(output_torch - output_triton))}")
@@ -78,8 +87,6 @@ def bench_vecadd(size, provider):
     if provider == 'torch':
         a + b
     if provider == 'triton':
-        a = a.to(DEVICE)
-        b = b.to(DEVICE)
         add(a, b)
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
index 031593c26..765ee3628 100644
--- a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
@@ -12,6 +12,7 @@
 #ifndef ZTC_CONVERSION_LINALG_TO_MK_H
 #define ZTC_CONVERSION_LINALG_TO_MK_H
 
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -32,4 +33,132 @@ std::unique_ptr<OperationPass<ModuleOp>> createLinalgToMKPass();
 } // namespace triton
 } // namespace mlir
 
+namespace {
+
+using namespace mlir;
+using namespace triton;
+
+// Check whether the body of a linalg op region consists of exactly one
+// operation of type SrcOp (ignoring the terminator).
+template <typename SrcOp> static bool checkGenericOp(linalg::GenericOp op) {
+  auto regionBlock = op.getBody();
+  auto regionOps = llvm::map_to_vector(regionBlock->without_terminator(),
+                                       [](Operation &op) { return &op; });
+
+  return regionOps.size() == 1 && isa<SrcOp>(regionOps[0]);
+}
+
+static bool isConstantTensor(Value &v, double targetValue) {
+  auto fillOp = dyn_cast<linalg::FillOp>(v.getDefiningOp());
+  if (!fillOp) {
+    return false;
+  }
+
+  auto fillValue = fillOp.getInputs()[0];
+  auto constOp = fillValue.getDefiningOp<arith::ConstantOp>();
+  if (!constOp) {
+    return false;
+  }
+
+  if (auto val = dyn_cast<FloatAttr>(constOp.getValue())) {
+    return val.getValueAsDouble() == targetValue;
+  }
+  if (auto val = dyn_cast<IntegerAttr>(constOp.getValue())) {
+    return val.getValue() == static_cast<int64_t>(targetValue);
+  }
+
+  return false;
+}
+
+// Check if the given value is a tensor filled with 0.
+static bool isZeroTensor(Value &v) { return isConstantTensor(v, 0.0); }
+
+// Check if the given value is a tensor filled with 1.
+static bool isOneTensor(Value &v) { return isConstantTensor(v, 1.0); }
+
+static bool matchSigmoid(linalg::GenericOp op, Value &input) {
+  // 1. sub (0 - x = -x)
+  // 2. exp (e^(-x))
+  // 3. add (1 + e^(-x))
+  // 4. div (1 / (1 + e^(-x)))
+  // We match the sigmoid pattern from the bottom up.
+
+  // 1. Match div first.
+  if (!checkGenericOp<arith::DivFOp>(op)) {
+    return false;
+  }
+
+  auto divLhs = op.getInputs()[0];
+  if (!isOneTensor(divLhs)) {
+    return false;
+  }
+
+  // 2. Match add.
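+  // The add is commutative, so the constant-1 tensor may appear as either
+  // operand; both orders are checked below.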
+  auto addResult = op.getInputs()[1];
+  auto addGenericOp = addResult.getDefiningOp<linalg::GenericOp>();
+  if (!addGenericOp || !checkGenericOp<arith::AddFOp>(addGenericOp)) {
+    return false;
+  }
+
+  auto addLhs = addGenericOp.getInputs()[0];
+  auto addRhs = addGenericOp.getInputs()[1];
+  bool isAddLhsOne = isOneTensor(addLhs);
+  bool isAddRhsOne = isOneTensor(addRhs);
+  if (!isAddLhsOne && !isAddRhsOne) {
+    return false;
+  }
+
+  // 3. Match exp.
+  auto expResult = isAddLhsOne ? addRhs : addLhs;
+  auto expGenericOp = expResult.getDefiningOp<linalg::GenericOp>();
+  if (!expGenericOp || !checkGenericOp<math::ExpOp>(expGenericOp)) {
+    return false;
+  }
+
+  // 4. Match sub.
+  auto subResult = expGenericOp.getInputs()[0];
+  auto subGenericOp = subResult.getDefiningOp<linalg::GenericOp>();
+  if (!subGenericOp || !checkGenericOp<arith::SubFOp>(subGenericOp)) {
+    return false;
+  }
+
+  auto subLhs = subGenericOp.getInputs()[0];
+  if (!isZeroTensor(subLhs)) {
+    return false;
+  }
+
+  // Set the input of the sub operation as the input of the sigmoid op.
+  input = subGenericOp.getInputs()[1];
+
+  // The sigmoid pattern matched successfully.
+  return true;
+}
+
+struct SigmoidFusionPattern : OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Match the sigmoid pattern.
+    Location loc = op.getLoc();
+    Value input;
+    if (!matchSigmoid(op, input)) {
+      return rewriter.notifyMatchFailure(op, "sigmoid pattern not matched");
+    }
+
+    auto dstType = cast<RankedTensorType>(op.getType(0));
+    auto elementType = dstType.getElementType();
+    auto init =
+        rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), elementType);
+
+    // Replace the div GenericOp with mk::SigmoidOp.
+    // We can use CSE to erase the other, now unused, generic ops.
+    rewriter.replaceOpWithNewOp<mk::SigmoidOp>(
+        op, dstType, input, init, rewriter.getBoolAttr(false));
+
+    return success();
+  }
+};
+
+} // namespace
+
 #endif // ZTC_CONVERSION_MEMREF_TO_MAGICKERNEL_H
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
index f49785fa5..497a3b7c3 100644
--- a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
@@ -49,17 +49,26 @@ class MKOp<string mnemonic, list<Trait> traits = []> :
     !listconcat(traits, [/*TensorSizeTrait, VerifyTensorLayoutsTrait*/])> {
 }
 
-class MKUnElemWiseOp : MKOp {
+class MKUnElemWiseOp : MKOp {
   let summary = "Element wise unary operation: $mnemonic";
   let arguments = (
     ins
-    AnyTensor:$src,
+    TensorOrMemref:$src,
+    // buffer to store the result
+    Arg:$zeroes,
     BoolAttr:$is_atomic
   );
-  let results = (outs AnyTensor:$dst);
+  let results = (outs Variadic:$dst);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() {
+      return getZeroesMutable();
+    }
+  }];
+
 }
 
 class MKBinElemWiseOp : MKOp {
@@ -246,7 +255,29 @@ def XorSumOp : MKOp<"xor_sum", [Pure]> {}
 
 // =============================================================================
 def SortOp : MKOp<"sort", [Pure]> {}
-def GatherOp : MKOp<"gather", [Pure]> {}
+def GatherOp : MKOp<"gather", [DestinationStyleOpInterface]> {
+  let summary = "Gather from a tensor along a given dimension.";
+
+  let description = [{
+    TODO: It is currently a one-to-one mapping from the upper-level dialect op tt.gather.
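+    For a 2-D input with axis = 0: result[i][j] = src[indices[i][j]][j].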
+ }]; + + let arguments = ( + ins + TensorOrMemref:$src, // input + TensorOrMemref:$indices, // indices + Arg:$dst, // output + I32Attr:$axis + ); + + let results = (outs Variadic:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { + return getDstMutable(); + } + }]; +} // ============================================================================= @@ -280,5 +311,37 @@ def SqrtRnOp : MKUnElemWiseOp<"sqrt_rn">; def XorOp : MKBinElemWiseOp<"xor">; // def UmulhiOp : MKOp<"umulhi", [Pure]> {} +def BarrierOp : MKOp<"barrier"> { + let summary = "Synchronizes all work items"; + let description = [{ + The "barrier" op synchronizes all work items. + }]; + let assemblyFormat = "attr-dict"; +} + +def PrintOp : MKOp<"print", [DestinationStyleOpInterface]> { + let summary = "Print at most a single scalar or 1D TensorOrMemref on each line"; + + let description = [{ + It only takes a single scalar or 1D TensorOrMemref element. + }]; + + let arguments = (ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$val, + DenseI32ArrayAttr:$isSigned + ); + + let results = (outs Variadic:$dst); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { + return getValMutable(); + } + }]; + + let hasVerifier = 1; +} #endif // MAGIC_KERNEL_OPS diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h index 82214faf0..86442d563 100644 --- a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h @@ -13,12 +13,14 @@ #include "triton-shared/Analysis/OpFoldResultUtils.h" #include "triton-shared/Analysis/PtrAnalysis.h" #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "utils/FusionHelper.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" @@ -29,7 +31,6 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" -#include #include #include #include @@ -770,12 +771,28 @@ struct TransposeConverter : public OpConversionPattern { LogicalResult matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto src = adaptor.getSrc(); - auto srcRank = cast(src.getType()).getRank(); - assert(srcRank == 2 && "only expect transposing 2D data"); + auto source = adaptor.getSrc(); - auto res = getTransposedValue(src, op.getLoc(), rewriter); - rewriter.replaceOp(op, res); + auto sourceType = cast(source.getType()); + auto sourceShape = sourceType.getShape(); + auto sourceRank = sourceType.getRank(); + + auto order = op.getOrder(); + SmallVector perm(order.begin(), order.end()); + + SmallVector transposedShape(sourceType.getShape()); + for (int i = 0; i < sourceRank; i++) + transposedShape[i] = sourceShape[perm[i]]; + + Value transposeInit = rewriter.create( + op->getLoc(), transposedShape, sourceType.getElementType()); + + Value transpose = rewriter + .create(op->getLoc(), source, + transposeInit, perm) + .getResults()[0]; + + rewriter.replaceOp(op, 
transpose); return success(); } }; @@ -840,7 +857,8 @@ struct AssertConverter : public OpConversionPattern { condVal = newCond.getResult(); } - auto assertMessage = llvm::formatv("FIXME: assertion!"); + auto assertMessage = + llvm::formatv("Assertion `{0}` failed", op.getMessage()); rewriter.create(op.getLoc(), condVal, assertMessage.str()); @@ -849,6 +867,73 @@ struct AssertConverter : public OpConversionPattern { } }; +// FIXME: There is no triton::BarrierOp currently. +struct BarrierConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::gpu::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + rewriter.create(loc); + rewriter.eraseOp(op); + return success(); + } +}; + +// Similar with triton-cpu. +struct PrintOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + // If the op has no operands, we can just print the prefix. + if (op.getNumOperands() == 0) { + rewriter.create(loc, TypeRange{}, op.getPrefix(), + op.getHex(), ValueRange{}, + llvm::SmallVector{}); + rewriter.eraseOp(op); + return success(); + } + + for (size_t i = 0; i < op.getNumOperands(); i++) { + Value operand = op.getOperands()[i]; + auto isSigned = {op.getIsSigned()[i]}; + // If the operand is not a ranked tensor, we should create a new tensor. + // See mlir/lib/Interfaces/DestinationStyleOpInterface.cpp#L39 + if (!isa(operand.getType())) { + auto operandTensor = rewriter.create( + loc, RankedTensorType::get({}, operand.getType()), operand); + rewriter.create(loc, operandTensor.getType(), + op.getPrefix(), op.getHex(), + operandTensor.getResult(), isSigned); + continue; + } + + auto operandType = cast(operand.getType()); + auto flattenTensor = operand; + if (operandType.getRank() != 1) { + SmallVector flatten_shape = {operandType.getNumElements()}; + auto targetType = + RankedTensorType::get(flatten_shape, operandType.getElementType()); + auto shapeAttr = rewriter.getI64TensorAttr(flatten_shape); + auto shapeConst = rewriter.create(loc, shapeAttr); + flattenTensor = rewriter.create( + loc, targetType, flattenTensor, shapeConst); + } + + rewriter.create(loc, operandType, op.getPrefix(), + op.getHex(), flattenTensor, isSigned); + } + + rewriter.eraseOp(op); + return success(); + } +}; + struct BitcastConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -954,15 +1039,24 @@ struct ClampConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { bool propagateNan = op.getPropagateNan() == triton::PropagateNan::ALL; - assert(!propagateNan && "PropagateNan is not supported"); - Location loc = op.getLoc(); Value x = adaptor.getOperands()[0]; Value min = adaptor.getOperands()[1]; Value max = adaptor.getOperands()[2]; - Value maxMin = rewriter.create(loc, x, min); - Value clamp = rewriter.create(loc, maxMin, max); + Value clamp = x; + auto maxMin = min; + + if (propagateNan) { + // Handle NaN propagation + maxMin = rewriter.create(loc, x, min); + clamp = rewriter.create(loc, maxMin, max); + } else { + // No NaN propagation. 
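+      // Presumably the non-propagating min/max flavor, which prefers the
+      // non-NaN operand, so a NaN input does not poison the clamped result.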
+ maxMin = rewriter.create(loc, x, min); + clamp = rewriter.create(loc, maxMin, max); + } + rewriter.replaceOp(op, clamp); return success(); @@ -1152,7 +1246,7 @@ struct MatmulConverter : public OpConversionPattern { auto dstType = cast(op.getType()); auto elementType = dstType.getElementType(); bool integers = elementType.isInteger(); - + bool skipC = isZeroTensor(opc, integers); auto init = rewriter.create(loc, dstType.getShape(), elementType); TypedAttr constantAttr = @@ -1167,11 +1261,20 @@ struct MatmulConverter : public OpConversionPattern { rewriter.create(loc, ValueRange{zero}, ValueRange{init}) .result(); - auto dotOp = rewriter.create(loc, dstType, - ValueRange{opa, opb, opc, zeroes}); + auto res = rewriter + .create(loc, ValueRange{opa, opb}, + ValueRange{zeroes}) + .getResult(0); - rewriter.replaceOp(op, dotOp); + if (!skipC) { + if (integers) { + res = rewriter.create(loc, opc, res); + } else { + res = rewriter.create(loc, opc, res); + } + } + rewriter.replaceOp(op, res); return success(); } }; @@ -1189,8 +1292,8 @@ struct ReduceConverter : public OpConversionPattern { bool isReductionOpSupported(Operation *redOp) const { return isa( - redOp); + arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp, arith::MaxUIOp, + arith::XOrIOp>(redOp); } arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, @@ -1229,6 +1332,9 @@ struct ReduceConverter : public OpConversionPattern { .Case([&](arith::MaxUIOp) { return rewriter.getIntegerAttr(constantType, 0); }) + .Case([&](arith::XOrIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) .Default([](Operation *op) { op->dump(); llvm_unreachable("Reduction op not yet supported"); @@ -1242,7 +1348,7 @@ struct ReduceConverter : public OpConversionPattern { bool requiresF32Conversion(const Type elemType, Operation *redOp) const { return isa(elemType) && elemType.getIntOrFloatBitWidth() < - llvm::cast(Float32Type::get(elemType.getContext())) + cast(Float32Type::get(elemType.getContext())) .getWidth() && isa(redOp); } @@ -1260,9 +1366,10 @@ struct ReduceConverter : public OpConversionPattern { }) .Case([&](auto redOp) { - return b.create(loc, lhs, rhs); - }) + arith::MinUIOp, arith::MaxSIOp, arith::MaxUIOp, arith::XOrIOp>( + [&](auto redOp) { + return b.create(loc, lhs, rhs); + }) .Default([](Operation *op) { op->dump(); llvm_unreachable("Reduction op not yet supported"); @@ -1380,129 +1487,6 @@ template class ArgMinMaxBaseConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - // We're looking for an op that looks like this: - // - // %9:2 = "tt.reduce"(%8, %3) <{axis = 0 : i32}> ({ - // ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): - // ------------------------------------------------- - // `matchTieBreakValue` | - // %11 = arith.cmpf oeq, %arg9, %arg11 : f32 | - // %12 = arith.cmpi slt, %arg10, %arg12 : i32 | 1. - // %13 = arith.andi %11, %12 : i1 | - // ------------------------------------------------- |-> `matchShouldUpdate` - // `matchUpdateCondition` | - // %14 = arith.cmpf ogt, %arg9, %arg11 : f32 | 2. 
- // ------------------------------------------------- | - // %15 = arith.ori %14, %13 : i1 | - // ------------------------------------------------- - // %16 = arith.select %15, %arg9, %arg11 : f32 - // %17 = arith.select %15, %arg10, %arg12 : i32 - // tt.reduce.return %16, %17 : f32, i32 - // }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) - // - // The above mlir code is lowered from this combinator in triton's - // standard.py: - // - // def _argmax_combine(value1, index1, value2, index2, tie_break_left): - // if tie_break_left: - // tie = value1 == value2 and index1 < index2 - // else: - // tie = False - // gt = value1 > value2 or tie - // v_ret = core.where(gt, value1, value2) - // i_ret = core.where(gt, index1, index2) - // return v_ret, i_ret - - LogicalResult matchTieBreakResult(Value currValue, Value currIndex, - Value reduceValue, Value reduceIndex, - mlir::Block::iterator &it, - Value &tileBreakValue) const { - // Match the following (section 1. of the above) - // - // %11 = arith.cmpf oeq, %arg9, %arg11 : f32 - // %12 = arith.cmpi slt, %arg10, %arg12 : i32 - // %13 = arith.andi %11, %12 : i1 - // - // which is equivalent to the following python code - // - // tie = value1 == value2 and index1 < index2 - - // matching: %11 = arith.cmpf oeq, %arg9, %arg11 : f32 - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto eqCmpOp = dyn_cast(*it++); - if (eqCmpOp) { - if (eqCmpOp.getPredicate() != arith::CmpFPredicate::OEQ) { - return failure(); - } - if (currValue != eqCmpOp.getLhs() || reduceValue != eqCmpOp.getRhs()) { - return failure(); - } - } else { - return failure(); - } - - // matching: %12 = arith.cmpi slt, %arg10, %arg12 : i32 - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto sltCmpOp = dyn_cast(*it++); - if (sltCmpOp) { - if (sltCmpOp.getPredicate() != arith::CmpIPredicate::slt) { - return failure(); - } - if (currIndex != sltCmpOp.getLhs() || reduceIndex != sltCmpOp.getRhs()) { - return failure(); - } - } else { - return failure(); - } - - // matching: %13 = arith.andi %11, %12 : i1 - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto andOp = dyn_cast(*it++); - if (andOp) { - if (andOp.getLhs() != eqCmpOp || andOp.getRhs() != sltCmpOp) { - return failure(); - } - } else { - return failure(); - } - - tileBreakValue = andOp; - return success(); - } - - LogicalResult matchShouldUpdateValue(Value currValue, Value currIndex, - Value reduceValue, Value reduceIndex, - mlir::Block::iterator &it, - Value &shouldUpdate) const { - Value tieResult; - if (failed(matchTieBreakResult(currValue, currIndex, reduceValue, - reduceIndex, it, tieResult))) { - LLVM_DEBUG(llvm::dbgs() << "Tie break result match failed\n"); - return failure(); - } - - Value comparisonResult; - if (failed(T::matchComparisonResult(currValue, currIndex, reduceValue, - reduceIndex, it, comparisonResult))) { - LLVM_DEBUG(llvm::dbgs() << "Comparison result match failed\n"); - return failure(); - } - - // matching: %15 = arith.ori %14, %13 : i1 - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto orOp = dyn_cast(*it++); - if (orOp) { - if (orOp.getLhs() != comparisonResult || orOp.getRhs() != tieResult) { - return failure(); - } - } else { - return failure(); - } - - shouldUpdate = orOp; - return success(); - } - Value getInitTensor(ConversionPatternRewriter &rewriter, ArrayRef shape, Value fillValue, Location loc) const { @@ -1516,6 +1500,7 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { public: ArgMinMaxBaseConverter(MLIRContext *context) : 
OpConversionPattern(context) {} + bool isArgMin; LogicalResult match(ReduceOp op) const override final { if (op.getBody()->getNumArguments() != 4) { @@ -1531,35 +1516,9 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { Value reduceIndex = block->getArgument(3); auto opsIt = ops.begin(); - Value shouldUpdate; - if (failed(matchShouldUpdateValue(currValue, currIndex, reduceValue, - reduceIndex, opsIt, shouldUpdate))) { - return failure(); - } - - // matching: %16 = arith.select %15, %arg9, %arg11 : f32 - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); - auto valueSelectOp = dyn_cast(*opsIt++); - if (valueSelectOp) { - if (valueSelectOp.getCondition() != shouldUpdate || - currValue != valueSelectOp.getTrueValue() || - reduceValue != valueSelectOp.getFalseValue()) { - return failure(); - } - } else { - return failure(); - } - - // matching:%17 = arith.select %15, %arg10, %arg12 : i32 - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); - auto indexSelectOp = dyn_cast(*opsIt++); - if (indexSelectOp) { - if (indexSelectOp.getCondition() != shouldUpdate || - currIndex != indexSelectOp.getTrueValue() || - reduceIndex != indexSelectOp.getFalseValue()) { - return failure(); - } - } else { + Value indexSelectOp, valueSelectOp; + if (failed(matchArgMinMax(currValue, currIndex, reduceValue, reduceIndex, + opsIt, indexSelectOp, valueSelectOp, isArgMin))) { return failure(); } @@ -1647,64 +1606,23 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { }; struct ArgMaxConverter : public ArgMinMaxBaseConverter { - static LogicalResult matchComparisonResult(Value currValue, Value currIndex, - Value reduceValue, - Value reduceIndex, - mlir::Block::iterator &it, - Value &comparisonResult) { - // %14 = arith.cmpf ogt, %arg9, %arg11 : f32 - // This corresponds to section 2. of the sample snippet in - // ArgMinMaxBaseConverter - auto cmpOp = dyn_cast(*it++); - if (cmpOp) { - if (cmpOp.getPredicate() != arith::CmpFPredicate::OGT || - currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { - return failure(); - } - } else { - return failure(); - } - - comparisonResult = cmpOp; - return success(); - } - static float getBaseReductionValue() { return -std::numeric_limits::infinity(); } - ArgMaxConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} + ArgMaxConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) { + isArgMin = false; + } }; struct ArgMinConverter : public ArgMinMaxBaseConverter { - static LogicalResult matchComparisonResult(Value currValue, Value currIndex, - Value reduceValue, - Value reduceIndex, - mlir::Block::iterator &it, - Value &comparisonResult) { - // %14 = arith.cmpf olt, %arg9, %arg11 : f32 - // This corresponds to section 2. 
of the sample snippet in - // ArgMinMaxBaseConverter - LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); - auto cmpOp = dyn_cast(*it++); - if (cmpOp) { - if (cmpOp.getPredicate() != arith::CmpFPredicate::OLT || - currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { - return failure(); - } - } else { - return failure(); - } - - comparisonResult = cmpOp; - return success(); - } - static float getBaseReductionValue() { return std::numeric_limits::infinity(); } - ArgMinConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} + ArgMinConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) { + isArgMin = true; + } }; // get_program_id and get_num_programs: @@ -2018,6 +1936,26 @@ class ReshapeConverter : public OpConversionPattern { } }; +struct GatherConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto resultType = cast(op.getResult().getType()); + Value dstInit = rewriter.create( + op.getLoc(), resultType.getShape(), resultType.getElementType()); + + auto gatherOp = + rewriter.create(op.getLoc(), op.getType(), op.getSrc(), + op.getIndices(), dstInit, op.getAxis()); + + rewriter.replaceOp(op, gatherOp.getResult()); + return success(); + } +}; + class ExternElementwiseBinaryOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -2121,6 +2059,449 @@ static void populateExternElementwiseOpToMLIROps(RewritePatternSet &patterns) { ExternElementwiseUnaryOpConverter>(patterns.getContext()); } +struct HistogramOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + + auto srcTy = dyn_cast(src.getType()); + auto resTy = dyn_cast(op.getType()); + + auto flattenTensor = src; + // NOTE: Triton only support rank1. But flatten the tensor to 1D can always + // implement the histogram operation. 
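+    // e.g. a 2x4 input is reshaped to a flat 8-element tensor before the
+    // per-element bin counting below.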
+ if (srcTy.getRank() != 1) { + SmallVector flatten_shape = {srcTy.getNumElements()}; + auto targetType = + RankedTensorType::get(flatten_shape, srcTy.getElementType()); + + auto shapeAttr = rewriter.getI64TensorAttr(flatten_shape); + auto shapeConst = rewriter.create(loc, shapeAttr); + flattenTensor = rewriter.create( + loc, targetType, flattenTensor, shapeConst); + } + + Value zero = rewriter.create( + loc, resTy, rewriter.getZeroAttr(resTy)); + Value one = rewriter.create(loc, resTy, + rewriter.getOneAttr(resTy)); + RankedTensorType cmpVecTy = + RankedTensorType::get(resTy.getShape(), srcTy.getElementType()); + + // This will be a global constant, copy to allocated memory + Value rangeVec = rewriter.create( + loc, resTy, makeRangeAttr(cmpVecTy, rewriter)); + Value empty = rewriter.create(loc, resTy.getShape(), + cmpVecTy.getElementType()); + Value allocatedRangeVec = rewriter.create( + loc, cmpVecTy, rangeVec, empty, ValueRange(), ValueRange(), + ValueRange(), SmallVector({0}), + SmallVector({resTy.getNumElements()}), + SmallVector({0})); + + Value res = zero; + + // Create loop bounds + Value lowerBound = rewriter.create(loc, 0); + Value upperBound = + rewriter.create(loc, srcTy.getNumElements()); + Value step = rewriter.create(loc, 1); + + auto forOp = rewriter.create( + loc, lowerBound, upperBound, step, ValueRange{res}, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) { + Value currentRes = iterArgs[0]; + // Extract element at current index + Value elem = + builder.create(loc, flattenTensor, iv); + SmallVector elems(resTy.getNumElements(), elem); + // Create a splat of the element + Value elemVec = builder.create(loc, elems); + + // Compare with range vector + Value mask = builder.create( + loc, arith::CmpIPredicate::eq, elemVec, allocatedRangeVec); + // Select based on mask + Value delta = + builder.create(loc, resTy, mask, one, zero); + // Add to running result + Value newRes = builder.create(loc, currentRes, delta); + // Yield the updated result + builder.create(loc, newRes); + }); + + // Replace the original op with the final histogram result + rewriter.replaceOp(op, forOp.getResults()[0]); + return success(); + } + + TypedAttr makeRangeAttr(RankedTensorType resTy, + ConversionPatternRewriter &rewriter) const { + Type elemTy = resTy.getElementType(); + if (elemTy.isInteger(32)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI32TensorAttr(range); + } else if (elemTy.isInteger(64)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI64TensorAttr(range); + } else { + llvm_unreachable( + "unsupported src elem type for histogram (expected i32 or i64)"); + } + } +}; + +// It provides accumulation function that clones operations from the +// original combine region and applies them on provided tensor. +// Also, it handles multi-dimensional cases reducing them to two +// possible options: lowering for a 1-D tensor inputs and lowering +// the operation over the leading dimension. +// +// Specialized pattern should implement lower1DInput to handle +// trailing dimension case and lowerLeadingDimension to handle the leading +// dimension case through accumulation of sub-tensors. 
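+// For example, a cumulative sum over axis = 1 of a 4x8 tensor runs
+// lower1DInput on each of the four rows, while axis = 0 accumulates 1x8
+// sub-tensors through lowerLeadingDimension.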
+struct ScanOpConverter : public OpConversionPattern { +private: + mutable IRMapping invariantsMap; + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using OpAdaptor = typename triton::ScanOp::Adaptor; + + using LoweringFuncType = SmallVector (ScanOpConverter::*)( + ValueRange inputs, triton::ScanOp op, + ConversionPatternRewriter &rewriter) const; + + // Though function ptr to call lower1DInput/lowerLeadingDimension to handle + // the tensor. + SmallVector lowering(ValueRange inputs, triton::ScanOp op, + ConversionPatternRewriter &rewriter, + uint32_t axis, LoweringFuncType handle) const { + auto loc = op.getLoc(); + auto inputType = cast(inputs[0].getType()); + auto shape = inputType.getShape(); + + SmallVector res(inputs.size()); + std::transform(inputs.begin(), inputs.end(), res.begin(), [&](auto val) { + return rewriter.create(loc, shape, + inputType.getElementType()); + }); + + auto strides = computeStrides(shape); + // Remove trailing elems to build indices of required rank. + strides.erase(strides.begin() + axis, strides.end()); + int64_t numElems = inputType.getNumElements(); + int64_t step = strides.back(); + + for (int64_t idx = 0; idx < numElems; idx += step) { + auto indices = delinearize(idx, strides); + + SmallVector static_offsets = indices; + static_offsets.insert(static_offsets.end(), shape.size() - indices.size(), + (int64_t)0); + SmallVector static_size(indices.size(), 1); + static_size.insert(static_size.end(), shape.begin() + axis, shape.end()); + SmallVector static_stride(shape.size(), 1); + + // {1,1,..shape[axis],shape[axis+1],..shape[rank]} + SmallVector extract_shape = static_size; + + // {shape[axis],shape[axis+1],..shape[rank]} + SmallVector reshape_shape(shape.begin() + axis, shape.end()); + + SmallVector subInputs(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + auto extract_tensor = rewriter.create( + loc, + RankedTensorType::get(extract_shape, + inputType.getElementType()), + val, ValueRange(), ValueRange(), ValueRange(), static_offsets, + static_size, static_stride); + + auto shapeAttr = rewriter.getI64TensorAttr(reshape_shape); + auto shapeConst = + rewriter.create(loc, shapeAttr); + + // {1,1,..shape[axis],shape[axis+1],..shape[rank]} -> + // {shape[axis],shape[axis+1],..shape[rank]} + return rewriter.create( + loc, + RankedTensorType::get(reshape_shape, + inputType.getElementType()), + extract_tensor, shapeConst); + }); + + auto resElems = (this->*handle)(subInputs, op, rewriter); + for (size_t i = 0; i < res.size(); ++i) { + + auto targetType = + RankedTensorType::get(extract_shape, inputType.getElementType()); + + auto shapeAttr = rewriter.getI64TensorAttr(extract_shape); + auto shapeConst = rewriter.create(loc, shapeAttr); + + // {shape[axis],shape[axis+1],..shape[rank]} -> + // {1,1,..shape[axis],shape[axis+1],..shape[rank]} + auto reshaped = rewriter.create( + loc, targetType, resElems[i], shapeConst); + + res[i] = rewriter.create( + loc, res[i].getType(), reshaped, res[i], ValueRange(), ValueRange(), + ValueRange(), static_offsets, static_size, static_stride); + } + } + return res; + } + + // To handle the trailing dimension case, we extract all input vectors + // and process them through lower1DInput, then build the resulting + // vector using inserts. 
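+  // Rank-1 inputs go straight to lower1DInput; higher-rank inputs are first
+  // sliced along the leading dimensions by `lowering` above.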
+ LogicalResult + lowerTrailingDimension(triton::ScanOp op, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + auto inputType = cast(inputs[0].getType()); + + // 1-D input case. + if (inputType.getRank() == 1) { + auto res = lower1DInput(inputs, op, rewriter); + rewriter.replaceOp(op, res); + return success(); + } + + uint32_t axis = op.getAxis(); + assert(axis == (inputType.getRank() - 1) && + "Expected reduction axis is the last one"); + SmallVector res = + lowering(inputs, op, rewriter, axis, &ScanOpConverter::lower1DInput); + + rewriter.replaceOp(op, res); + return success(); + } + + // In this case we either call lowerLeadingDimension to process the input + // or extract sub-vectors, call lowerLeadingDimension, and then reconstruct + // the result. + LogicalResult + lowerNonTrailingDimension(triton::ScanOp op, + ConversionPatternRewriter &rewriter) const { + + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + uint32_t axis = op.getAxis(); + if (axis == 0) { + rewriter.replaceOp(op, lowerLeadingDimension(inputs, op, rewriter)); + return success(); + } + + SmallVector res = lowering(inputs, op, rewriter, axis, + &ScanOpConverter::lowerLeadingDimension); + + rewriter.replaceOp(op, res); + return success(); + } + + // Accumulate inputs and existing accumulators into a new accumulators + // applying operations from the combine region. + SmallVector accumulate(ValueRange inputs, ValueRange acc, + Region &combineOp, + ConversionPatternRewriter &rewriter) const { + if (acc.empty()) + return inputs; + + auto type = inputs[0].getType(); + SmallVector shape; + if (isa(type)) { + auto temp = cast(type).getShape(); + shape.insert(shape.end(), temp.begin(), temp.end()); + } else { + shape.push_back(1); + } + auto &block = combineOp.getBlocks().front(); + IRMapping map; + // Map block arguments to the current inputs and accumulators. + for (unsigned i = 0; i < acc.size(); ++i) { + map.map(block.getArgument(i), acc[i]); + map.map(block.getArgument(acc.size() + i), inputs[i]); + } + for (auto &op : block.getOperations()) { + // Returned values are a new accumulator. + if (isa(op)) { + SmallVector res; + for (auto operand : op.getOperands()) { + res.push_back(map.lookup(operand)); + } + return res; + } + + // Clone operation mapping its inputs and building vector + // result types using the input shape. + OperationState newState(op.getLoc(), op.getName()); + for (auto operand : op.getOperands()) { + newState.operands.push_back( + lookupMappedValue(map, operand, shape, rewriter)); + } + for (auto ty : op.getResultTypes()) { + isa(type) + ? newState.types.push_back(RankedTensorType::get(shape, ty)) + : newState.types.push_back(ty); + } + newState.attributes = op.getAttrs(); + auto newOp = rewriter.create(newState); + + // Add new values to the map. + for (auto [oldVal, newVal] : + llvm::zip(op.getResults(), newOp->getResults())) { + map.map(oldVal, newVal); + } + } + llvm_unreachable("No return op found in scan/reduce region"); + } + + Value lookupMappedValue(IRMapping &localMap, Value val, + ArrayRef shape, + ConversionPatternRewriter &rewriter) const { + + Value res = localMap.lookupOrNull(val); + if (!res) { + // If value is not found then it's an invariant defined in the outer + // region. We check if it has been already translated and add a splat + // operation if it hasn't. 
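+    // The splat broadcasts the scalar invariant to the tensor shape used by
+    // the cloned combine-region operations.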
+      res = invariantsMap.lookupOrNull(val);
+      if (!res) {
+        auto ip = rewriter.saveInsertionPoint();
+        rewriter.setInsertionPointAfterValue(val);
+        res = rewriter.create<triton::SplatOp>(
+            val.getLoc(), RankedTensorType::get(shape, val.getType()), val);
+        invariantsMap.map(val, res);
+        rewriter.restoreInsertionPoint(ip);
+      }
+    }
+    return res;
+  }
+
+  SmallVector<Value> lower1DInput(ValueRange inputs, ScanOp op,
+                                  ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    Region &combineOp = op.getRegion();
+    bool reverse = op.getReverse();
+
+    auto inputType = cast<RankedTensorType>(inputs[0].getType());
+    auto shape = inputType.getShape();
+
+    SmallVector<Value> res(inputs.size());
+    std::transform(inputs.begin(), inputs.end(), res.begin(), [&](auto val) {
+      return rewriter.create<tensor::EmptyOp>(loc, shape,
+                                              inputType.getElementType());
+    });
+
+    SmallVector<Value> acc;
+    int64_t start = reverse ? shape[0] - 1 : 0;
+    int64_t end = reverse ? -1 : shape[0];
+    int64_t step = reverse ? -1 : 1;
+    for (int64_t idx = start; idx != end; idx += step) {
+      SmallVector<Value> inputsElem(inputs.size());
+
+      SmallVector<Value> idxIndex(
+          {rewriter.create<arith::ConstantIndexOp>(loc, idx)});
+
+      std::transform(
+          inputs.begin(), inputs.end(), inputsElem.begin(), [&](auto val) {
+            return rewriter.create<tensor::ExtractOp>(loc, val, idxIndex);
+          });
+
+      acc = accumulate(inputsElem, acc, combineOp, rewriter);
+      assert(acc.size() == inputs.size() &&
+             "accumulate should return the same number of results as inputs");
+      for (size_t i = 0; i < acc.size(); ++i) {
+        res[i] =
+            rewriter.create<tensor::InsertOp>(loc, acc[i], res[i], idxIndex);
+      }
+    }
+    return res;
+  }
+
+  SmallVector<Value>
+  lowerLeadingDimension(ValueRange inputs, ScanOp op,
+                        ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    Region &combineOp = op.getRegion();
+    bool reverse = op.getReverse();
+
+    auto inputType = cast<RankedTensorType>(inputs[0].getType());
+    auto shape = inputType.getShape();
+
+    SmallVector<Type> resTypes;
+    for (const auto &resTy : op.getResultTypes()) {
+      resTypes.push_back(RankedTensorType::get(
+          shape, cast<RankedTensorType>(resTy).getElementType()));
+    }
+
+    SmallVector<Value> res(inputs.size());
+    std::transform(inputs.begin(), inputs.end(), res.begin(), [&](auto val) {
+      return rewriter.create<tensor::EmptyOp>(loc, shape,
+                                              inputType.getElementType());
+    });
+
+    SmallVector<Value> acc;
+    int64_t start = reverse ? shape[0] - 1 : 0;
+    int64_t end = reverse ? -1 : shape[0];
+    int64_t step = reverse ? -1 : 1;
+    for (int64_t idx = start; idx != end; idx += step) {
+      SmallVector<Value> subInputs(inputs.size());
+
+      SmallVector<int64_t> idxVal(shape.size(), 0);
+      idxVal.front() = idx;
+      SmallVector<int64_t> sizeVal({1});
+      sizeVal.insert(sizeVal.end(), shape.begin() + 1, shape.end());
+      SmallVector<int64_t> strides(shape.size(), 1);
+
+      std::transform(
+          inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) {
+            return rewriter.create<tensor::ExtractSliceOp>(
+                loc,
+                RankedTensorType::get(
+                    sizeVal,
+                    cast<RankedTensorType>(resTypes[0]).getElementType()),
+                val, ValueRange(), ValueRange(), ValueRange(), idxVal, sizeVal,
+                strides);
+          });
+
+      acc = accumulate(subInputs, acc, combineOp, rewriter);
+
+      for (size_t i = 0; i < res.size(); ++i) {
+        res[i] = rewriter.create<tensor::InsertSliceOp>(
+            loc, resTypes[i], acc[i], res[i], ValueRange(), ValueRange(),
+            ValueRange(), idxVal, sizeVal, strides);
+      }
+    }
+    return res;
+  }
+
+public:
+  LogicalResult
+  matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto rank = cast<RankedTensorType>(op.getOperand(0).getType()).getRank();
+    if (op.getAxis() == (rank - 1))
+      return lowerTrailingDimension(op, rewriter);
+
+    return lowerNonTrailingDimension(op, rewriter);
+  }
+};
+
 } // namespace
 
 #endif
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
index 71899adf6..4fa35b6cd 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
@@ -48,11 +48,7 @@ def MemRefOrInt
 // =============================================================================
 def RdmaOp : Tx81Op<"rdma", [
-  AttrSizedOperandSegments,
-  PredOpTrait<"Constrain shape to 4d.",
-    CPred<"cast<RdmaOp>($_op).getShape().size() == 4">>,
-  PredOpTrait<"Constrain strides to 3d.",
-    CPred<"cast<RdmaOp>($_op).getStrides().size() == 3">>
+  AttrSizedOperandSegments
 ]> {
   let summary = "Copy data from global memory DDR(dram) to per thread local SPM(sram)";
 
@@ -63,31 +59,21 @@
   let arguments = (
     ins
-    MemRefOrInt:$source,      // The source address in DDR
-    MemRefOrInt:$target,      // The target address in SPM
-    Variadic:$shape,          // NHWC shape
-    Variadic:$strides,        // 3 dim strides
-    I32Attr:$fmt
+    MemRefOrInt:$source,      // The source address in DDR
+    MemRefOrInt:$target,      // The target address in SPM
+    Variadic:$src_shape,      // src shape
+    Variadic:$src_strides,    // src strides
+    Variadic:$dst_shape,      // dst shape
+    Variadic:$dst_strides,    // dst strides
+    I32Attr:$elem_bytes,      // elem bytes
+    I32Attr:$fmt              // elem fmt
   );
 
-  let builders = [
-    OpBuilder<(ins "MemRefType":$resultType,
-      "Value":$source, "Value":$target,
-      "ArrayRef<Value>":$shape,
-      "ArrayRef<Value>":$strides,
-      "IntegerAttr":$fmt
-    )>
-  ];
-
   let results = (outs I64:$dst); // The dest address in SPM
 }
 
 def WdmaOp : Tx81Op<"wdma", [
-  AttrSizedOperandSegments,
-  PredOpTrait<"Constrain shape to 4d.",
-    CPred<"cast<WdmaOp>($_op).getShape().size() == 4">>,
-  PredOpTrait<"Constrain strides to 3d.",
-    CPred<"cast<WdmaOp>($_op).getStrides().size() == 3">>
+  AttrSizedOperandSegments
 ]> {
   let summary = "Copy data from per thread local SPM(sram) to global memory DDR(dram)";
 
@@ -97,22 +83,16 @@ def WdmaOp : Tx81Op<"wdma", [
   let arguments = (
     ins
-    MemRefOrInt:$source,      // The source address in DDR
-    MemRefOrInt:$target,      // The target address in SPM
-    Variadic:$shape,          // NHWC shape
-    Variadic:$strides,        // 3 dim strides
-    I32Attr:$fmt
+    MemRefOrInt:$source,      // The source address in SPM
+    MemRefOrInt:$target,      // The target address in DDR
+    Variadic:$src_shape,      // src shape
+    Variadic:$src_strides,    // src strides
+    Variadic:$dst_shape,      // dst shape
+    Variadic:$dst_strides,    // dst strides
+    I32Attr:$elem_bytes,      // elem bytes
+    I32Attr:$fmt              // elem fmt
   );
 
-  let builders = [
-    OpBuilder<(ins "MemRefType":$resultType,
-      "Value":$source, "Value":$target,
-      "ArrayRef<Value>":$shape,
-      "ArrayRef<Value>":$strides,
-      "IntegerAttr":$fmt
-    )>
-  ];
-
   let results = (outs I64:$dst); // The dest address in DDR
 }
 
@@ -193,14 +173,14 @@ def GemmOp : Tx81Op<"gemm", []> {
     // FIXME: Whether need add side effect to source operands?
     Arg:$dst,
     I32ArrayAttr:$dims,      // The dimensions of M, K, N
-    BoolAttr:$en_psum,       // Enable psum? TODO: Production sum?
-    MemRefOrInt:$psum_addr,  // The address of psum in SPM, TODO: psum?
+    BoolAttr:$en_psum,       // Enable psum; psum is used as the accumulation buffer
+    MemRefOrInt:$psum_addr,  // The address of psum in SPM; always the same as the output
     BoolAttr:$trans_src_a,   // Should matrix A be transposed
     BoolAttr:$trans_src_b,   // Should matrix B be transposed
     I32Attr:$batch_src_a,    // The batch of matrix A
     I32Attr:$batch_src_b,    // The batch of matrix B
     I32Attr:$relu_mode,      // Enable LeakyRelu or normal Relu or none
-    BoolAttr:$en_bias,       // Enable bias add
+    BoolAttr:$en_bias,       // Enable bias add. Only supports per-channel (C dim) bias with int8 type
     BoolAttr:$en_neg_scale,  // Enable negative axis scale
     MemRefOrInt:$src_neg_scale, // The address of negative scale data in SPM
     BoolAttr:$en_pos_scale,  // Enable positive axis scale
@@ -214,6 +194,56 @@ def GemmOp : Tx81Op<"gemm", []> {
   let results = (outs Variadic:$output);
 }
 
+
+// =============================================================================
+// Tsm crt ChannelNorm/DechannelNorm
+// =============================================================================
+
+def ChannelNormOp : Tx81Op<"channel_norm", []> {
+  let summary = "Align channel dim.";
+
+  let description = [{
+    Align (N,H,W,C) to (N,cx,H,W,64) + (N,cx,H,W,c0), where
+      align_base = 64, cx = C / align_base,
+      c0 = C % align_base, c0_align = get_c0_align(c0).
+    For example, C = 100 gives cx = 1, c0 = 36, c0_align = 64.
+  }];
+
+  let arguments = (
+    ins
+    MemRefOrInt:$src,         // Input tensor address in SPM
+    Arg:$dst,                 // Output tensor address in SPM
+    DenseI64ArrayAttr:$shape, // The shape info of src
+    // I16Attr:$cx,
+    I16Attr:$c0_align,
+    I16Attr:$dtype_size
+  );
+
+  // Output tensor address in SPM
+  let results = (outs Variadic:$output);
+}
+
+def DechannelNormOp : Tx81Op<"dechannel_norm", []> {
+  let summary = "Inverse operation of channel_norm.";
+
+  let description = [{
+    Transform (N,cx,H,W,64) + (N,cx,H,W,c0) back to (N,H,W,C).
+  }];
+
+  let arguments = (
+    ins
+    MemRefOrInt:$src,         // Input tensor address in SPM
+    Arg:$dst,                 // Output tensor address in SPM
+    DenseI64ArrayAttr:$shape, // The shape info of src
+    // I16Attr:$cx,
+    I16Attr:$c0_align,
+    I16Attr:$dtype_size
+  );
+
+  // Output tensor address in SPM
+  let results = (outs Variadic:$output);
+}
+
 // =============================================================================
 // 4.10. TsmArith
 // =============================================================================
@@ -286,7 +316,7 @@ def DivVSOp : BinaryVSOp<"divvs">;
 // =============================================================================
 // 4.11. TsmRelation
 // =============================================================================
-class BoolRelationVVOp<string mnemonic, list<Trait> traits = []> :
+class RelationVVOp<string mnemonic, list<Trait> traits = []> :
     Tx81Op<mnemonic, traits> {
   let arguments = (ins
     MemRefOrInt:$input0,     // First input vector address
@@ -298,30 +328,115 @@ class BoolRelationVVOp<string mnemonic, list<Trait> traits = []> :
   let results = (outs Variadic:$dst);
 }
 
-def BoolEqualVV : BoolRelationVVOp<"boolequalvv"> {
+def BoolEqualVV : RelationVVOp<"boolequalvv"> {
   let summary = "compare two input value, if equal, return true";
 }
 
-def BoolUnEqualVV : BoolRelationVVOp<"boolunequalvv"> {
+def BoolUnEqualVV : RelationVVOp<"boolunequalvv"> {
   let summary = "compare two input value, if unequal, return true";
 }
 
-def BoolGreaterEqualVV : BoolRelationVVOp<"boolgreatrequalvv"> {
+def BoolGreaterEqualVV : RelationVVOp<"boolgreatrequalvv"> {
   let summary = "compare two input value, if src0 >= src1, return true";
 }
 
-def BoolGreaterVV : BoolRelationVVOp<"boolgreatervv"> {
+def BoolGreaterVV : RelationVVOp<"boolgreatervv"> {
   let summary = "compare two input value, if src0 > src1, return true";
 }
 
-def BoolLessEqualVV : BoolRelationVVOp<"boollessequalvv"> {
+def BoolLessEqualVV : RelationVVOp<"boollessequalvv"> {
  let summary = "compare two input value, if src0 <= src1, return true";
 }
 
-def BoolLessThenVV : BoolRelationVVOp<"boollessthenvv"> {
+def BoolLessThenVV : RelationVVOp<"boollessthenvv"> {
   let summary = "compare two input value, if src0 < src1, return true";
 }
 
+def EqualVV : RelationVVOp<"equalvv"> {
+  let summary = "compare two input values; if equal, return 1.0";
+}
+
+def UnEqualVV : RelationVVOp<"unequalvv"> {
+  let summary = "compare two input values; if unequal, return 1.0";
+}
+
+def GreaterEqualVV : RelationVVOp<"greatrequalvv"> {
+  let summary = "compare two input values; if src0 >= src1, return 1.0";
+}
+
+def GreaterVV : RelationVVOp<"greatervv"> {
+  let summary = "compare two input values; if src0 > src1, return 1.0";
+}
+
+def LessEqualVV : RelationVVOp<"lessequalvv"> {
+  let summary = "compare two input values; if src0 <= src1, return 1.0";
+}
+
+def LessThenVV : RelationVVOp<"lessthenvv"> {
+  let summary = "compare two input values; if src0 < src1, return 1.0";
+}
+
+class RelationVSOp<string mnemonic, list<Trait> traits = []> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$input0,     // First input vector address
+    I32:$value,              // Const value
+    Arg:$out,                // Out vector address
+    MemRefOrInt:$elem_count, // Number of input elements
+    I16Attr:$fmt             // The data format of src & dst
+  );
+  let results = (outs Variadic:$dst);
+}
+
+def BoolEqualVS : RelationVSOp<"boolequalvs"> {
+  let summary = "compare the input with a constant value; if equal, return true";
+}
+
+def BoolUnEqualVS : RelationVSOp<"boolunequalvs"> {
+  let summary = "compare the input with a constant value; if unequal, return true";
+}
+
+def BoolGreaterEqualVS : RelationVSOp<"boolgreatrequalvs"> {
+  let summary = "compare the input with a constant value; if src0 >= src1, return true";
+}
+
+def BoolGreaterVS : RelationVSOp<"boolgreatervs"> {
+  let summary = "compare the input with a constant value; if src0 > src1, return true";
+}
+
+def BoolLessEqualVS : RelationVSOp<"boollessequalvs"> {
+  let summary = "compare the input with a constant value; if src0 <= src1, return true";
+}
+
+def BoolLessThenVS : RelationVSOp<"boollessthenvs"> {
+  let summary = "compare the input with a constant value; if src0 < src1, return true";
+}
+
+def EqualVS : RelationVSOp<"equalvs"> {
+  let summary = "compare the input with a constant value; if equal, return 1.0";
+}
+
+def UnEqualVS : RelationVSOp<"unequalvs"> {
+  let summary = "compare the input with a constant value; if unequal, return 1.0";
+}
+
+def GreaterEqualVS : RelationVSOp<"greatrequalvs"> {
+  let summary = "compare the input with a constant value; if src0 >= src1, return 1.0";
+}
+
+def GreaterVS : RelationVSOp<"greatervs"> {
+  let summary = "compare the input with a constant value; if src0 > src1, return 1.0";
+}
+
+def LessEqualVS : RelationVSOp<"lessequalvs"> {
+  let summary = "compare the input with a constant value; if src0 <= src1, return 1.0";
+}
+
+def LessThenVS : RelationVSOp<"lessthenvs"> {
+  let summary = "compare the input with a constant value; if src0 < src1, return 1.0";
+}
+
+
 // ...
 // =============================================================================
 // 4.12. TsmLogic
@@ -384,11 +499,13 @@ def CosOp : UnaryOp<"cos", []> {
 
 class ActivationOp<string mnemonic, list<Trait> traits> : Tx81Op<mnemonic, traits> {
   let arguments = (ins
-    MemRefOrInt:$src,     // Input vector address
-    UI32Attr:$elem_count, // Number of input elements
-    UI16Attr:$fmt         // The data format of src & dst
+    MemRefOrInt:$input,                     // Input vector address
+    Arg:$out,                               // Out vector address
+    AnySignlessIntegerOrIndex:$elem_count,  // Number of input elements
+    I16Attr:$fmt                            // The data format of src & dst
   );
-  let results = (outs UI64:$dst);
+
+  let results = (outs Variadic:$dst);
 }
 
 def Tanh : ActivationOp<"tanh", []> {
@@ -428,7 +545,7 @@ class Reduce<string mnemonic> : Tx81Op<mnemonic> {
   let arguments = (
     ins
-    AnyType:$src,       // Input tensor address in SPM
+    MemRefOrInt:$src,   // Input tensor address in SPM
     Arg:$dst,           // Output tensor address in SPM
     UI32Attr:$dim,      // Which dimension to be reduced
     I64ArrayAttr:$shape, // The shape info of src
@@ -436,7 +553,7 @@ class Reduce<string mnemonic> : Tx81Op<mnemonic> {
   );
 
   // Output tensor address in SPM
-  let results = (outs Variadic);
+  let results = (outs Variadic);
 }
 
 def ReduceSumOp : Reduce<"reduce_sum">;
@@ -461,7 +578,7 @@ When mask=0, the corresponding elements of dst remain unchanged.
   // The target address in SPM
   Arg:$target,
   AnySignlessIntegerOrIndex:$elem_count, // Number of elements to be copied
-    I32ArrayAttr:$mask,      // 3 dim masks
+    MemRefOrInt:$mask,       // Per-element mask address in SPM
     I32Attr:$fmt
   );
 
@@ -670,41 +787,40 @@ def Bit2FpOp : Tx81Op<"bit2fp", []> {
   let summary = "Convert a vector of the bitwise into the fp vector";
 
   let arguments = (ins
-    UI64:$src,           // Input tensor
-    I32Attr:$elem_count, // Number of input elements
+    MemRefOrInt:$src,    // Input tensor
+    Arg:$target,
+    AnySignlessIntegerOrIndex:$elem_count, // Number of input elements
     I16Attr:$fmt         // The data format of src & dst
   );
 
-  let results = (outs UI64:$dst);
+  let results = (outs I64:$dst);
 }
 
 def ArgMaxOp : Tx81Op<"argmax", []> {
   let summary = "Return a max value inner a vector and its corresponding index";
   let arguments = (ins
-    UI64:$src,           // Input vector
+    MemRefOrInt:$src,    // First input vector address
+    Arg:$value,          // Output max value address in SPM
+    Arg:$index,          // Output index address in SPM
    I32Attr:$elem_count,  // Number of input elements
    I16Attr:$fmt          // The data format of src & dst
   );
-  let results = (outs
-    AnyType:$max,        // Max value inner a vector
-    UI64:$index          // Corresponding index
-  );
+  let results = (outs Variadic:$dst);
 }
 
 def ArgMinOp : Tx81Op<"argmin", []> {
   let summary = "Return a min value inner a vector and its corresponding index";
   let arguments = (ins
-    UI64:$src,           // Input vector
+    MemRefOrInt:$src,    // First input vector address
+    Arg:$value,          // Output min value address in SPM
+    Arg:$index,          // Output index address in SPM
    I32Attr:$elem_count,  // Number of input elements
    I16Attr:$fmt          // The data format of src & dst
   );
-  let results = (outs
-    AnyType:$min,        // Min value inner a vector
-    UI64:$index          // Corresponding index
-  );
+  let results = (outs Variadic:$dst);
 }
 
 def BilinearOp : Tx81Op<"bilinear", []> {
@@ -769,12 +885,13 @@ def RandGenOp : Tx81Op<"randgen", []> {
 class TransformOp<string mnemonic, list<Trait> traits = []> : Tx81Op<mnemonic, traits> {
   let arguments = (ins
-    UI64:$src,                    // Input matrix or tensor address
-    I32ArrayAttr:$src_shape,      // Input shape
-    I32ArrayAttr:$dst_shape,      // Output shape
+    MemRefOrInt:$source,          // Input tensor
+    MemRefOrInt:$target,          // Output tensor
+    DenseI32ArrayAttr:$src_shape, // Input shape
+    DenseI32ArrayAttr:$dst_shape, // Output shape
     I16Attr:$fmt                  // The data format of src & dst
   );
-  let results = (outs UI64:$dst);
+  let results = (outs I64:$dst);
 }
 
 def Mirror : TransformOp<"mirror", []> {
@@ -851,14 +968,31 @@ def GatherScatter : Tx81Op<"gatherscatter", []> {
   let summary = "Transfer data in strides and iterations";
 
   let arguments = (ins
-    UI64:$src,                    // Input tensor
-    I32Attr:$size,                // Transfer data size in bytes
-    I32ArrayAttr:$src_strides,    // 3 dim strides for input
-    I32ArrayAttr:$src_iterations, // 3 dim iterations for input
-    I32ArrayAttr:$dst_strides,    // 3 dim strides for output
-    I32ArrayAttr:$dst_iterations  // 3 dim iterations for output
+    MemRefOrInt:$source,          // The source
+    MemRefOrInt:$target,          // The target
+    I32Attr:$bytes,               // Inner loop data size in bytes
+    I32Attr:$src_strideN,
+    I32Attr:$src_strideH,
+    I32Attr:$src_strideW,
+    I32Attr:$src_iterN,
+    I32Attr:$src_iterH,
+    I32Attr:$src_iterW,
+    I32Attr:$dst_strideN,
+    I32Attr:$dst_strideH,
+    I32Attr:$dst_strideW,
+    I32Attr:$dst_iterN,
+    I32Attr:$dst_iterH,
+    I32Attr:$dst_iterW
   );
-  let results = (outs UI64:$dst);
+  let results = (outs I64:$dst);
+}
+
+def BarrierOp : Tx81Op<"barrier"> {
+  let summary = "Synchronizes all work items";
+  let description = [{
+    The "barrier" op synchronizes all work items.
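+    A work item reaching the barrier waits until every other work item has
+    also reached it before any of them proceeds. (Assuming the dialect's `tx`
+    prefix, the op prints simply as `tx.barrier`.)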
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
 
 #endif // TSINGMICRO_TX81_OPS
diff --git a/third_party/tsingmicro/include/utils/FusionHelper.h b/third_party/tsingmicro/include/utils/FusionHelper.h
new file mode 100644
index 000000000..23a68cd84
--- /dev/null
+++ b/third_party/tsingmicro/include/utils/FusionHelper.h
@@ -0,0 +1,182 @@
+#ifndef TRITON_FUSION_PATTERNS
+#define TRITON_FUSION_PATTERNS
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+
+#include
+#include
+#include
+#include
+#include
+
+using namespace mlir;
+
+namespace {
+
+//===--------------------------- Match ArgMinMax --------------------------===//
+
+// We're looking for an op that looks like this:
+//
+// %9:2 = "tt.reduce"(%8, %3) <{axis = 0 : i32}> ({
+// ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32):
+//  -------------------------------------------------
+//  `matchTieBreakValue`                             |
+//   %11 = arith.cmpf oeq, %arg9, %arg11 : f32       |
+//   %12 = arith.cmpi slt, %arg10, %arg12 : i32      |   1.
+//   %13 = arith.andi %11, %12 : i1                  |
+//  ------------------------------------------------- |-> `matchShouldUpdate`
+//  `matchUpdateCondition`                           |
+//   %14 = arith.cmpf ogt, %arg9, %arg11 : f32       |   2.
+//  -------------------------------------------------
+//   %15 = arith.ori %14, %13 : i1
+//  -------------------------------------------------
+//   %16 = arith.select %15, %arg9, %arg11 : f32
+//   %17 = arith.select %15, %arg10, %arg12 : i32
+
+static LogicalResult matchTieBreakResult(Value currValue, Value currIndex,
+                                         Value reduceValue, Value reduceIndex,
+                                         mlir::Block::iterator &it,
+                                         Value &tieBreakValue) {
+  // Match the following (section 1. of the above)
+  //
+  // %11 = arith.cmpf oeq, %arg9, %arg11 : f32
+  // %12 = arith.cmpi slt, %arg10, %arg12 : i32
+  // %13 = arith.andi %11, %12 : i1
+  //
+  // which is equivalent to the following python code
+  //
+  // tie = value1 == value2 and index1 < index2
+
+  // matching: %11 = arith.cmpf oeq, %arg9, %arg11 : f32
+  auto eqCmpOp = dyn_cast<arith::CmpFOp>(*it++);
+  if (!eqCmpOp || eqCmpOp.getPredicate() != arith::CmpFPredicate::OEQ ||
+      currValue != eqCmpOp.getLhs() || reduceValue != eqCmpOp.getRhs()) {
+    return failure();
+  }
+
+  // matching: %12 = arith.cmpi slt, %arg10, %arg12 : i32
+  auto sltCmpOp = dyn_cast<arith::CmpIOp>(*it++);
+  if (!sltCmpOp || sltCmpOp.getPredicate() != arith::CmpIPredicate::slt ||
+      currIndex != sltCmpOp.getLhs() || reduceIndex != sltCmpOp.getRhs()) {
+    return failure();
+  }
+
+  // matching: %13 = arith.andi %11, %12 : i1
+  auto andOp = dyn_cast<arith::AndIOp>(*it++);
+  if (!andOp || andOp.getLhs() != eqCmpOp || andOp.getRhs() != sltCmpOp) {
+    return failure();
+  }
+
+  tieBreakValue = andOp;
+  return success();
+}
+
+static LogicalResult matchComparisonResult(Value currValue, Value currIndex,
+                                           Value reduceValue,
+                                           Value reduceIndex,
+                                           mlir::Block::iterator &it,
+                                           Value &comparisonResult,
+                                           bool isArgMin) {
+  // %14 = arith.cmpf olt(ogt), %arg9, %arg11 : f32
+  auto cmpOp = dyn_cast<arith::CmpFOp>(*it++);
+  if (!cmpOp) {
+    return failure();
+  }
+
+  auto predicate =
+      isArgMin ? arith::CmpFPredicate::OLT : arith::CmpFPredicate::OGT;
+  if (cmpOp.getPredicate() != predicate || currValue != cmpOp.getLhs() ||
+      reduceValue != cmpOp.getRhs()) {
+    return failure();
+  }
+
+  comparisonResult = cmpOp;
+  return success();
+}
+
+static LogicalResult
+matchShouldUpdateValue(Value currValue, Value currIndex, Value reduceValue,
+                       Value reduceIndex, mlir::Block::iterator &it,
+                       Value &shouldUpdate, bool isArgMin) {
+  Value tieResult;
+  if (failed(matchTieBreakResult(currValue, currIndex, reduceValue, reduceIndex,
+                                 it, tieResult))) {
+    return failure();
+  }
+
+  Value comparisonResult;
+  if (failed(matchComparisonResult(currValue, currIndex, reduceValue,
+                                   reduceIndex, it, comparisonResult,
+                                   isArgMin))) {
+    return failure();
+  }
+
+  // matching: %15 = arith.ori %14, %13 : i1
+  auto orOp = dyn_cast<arith::OrIOp>(*it++);
+  if (!orOp || orOp.getLhs() != comparisonResult ||
+      orOp.getRhs() != tieResult) {
+    return failure();
+  }
+
+  shouldUpdate = orOp;
+  return success();
+}
+
+LogicalResult matchSelect(mlir::Block::iterator &opsIt, Value curr,
+                          Value reduce, Value shouldUpdate, Value &result) {
+  auto selectOp = dyn_cast<arith::SelectOp>(*opsIt++);
+  if (!selectOp) {
+    return failure();
+  }
+
+  if (selectOp.getCondition() != shouldUpdate ||
+      curr != selectOp.getTrueValue() || reduce != selectOp.getFalseValue()) {
+    return failure();
+  }
+
+  result = selectOp;
+
+  return success();
+}
+
+LogicalResult matchArgMinMax(Value currValue, Value currIndex,
+                             Value reduceValue, Value reduceIndex,
+                             mlir::Block::iterator &opsIt, Value &indexResult,
+                             Value &valueResult, bool isArgMin) {
+  Value shouldUpdate;
+  if (failed(matchShouldUpdateValue(currValue, currIndex, reduceValue,
+                                    reduceIndex, opsIt, shouldUpdate,
+                                    isArgMin))) {
+    return failure();
+  }
+
+  // matching: %16 = arith.select %15, %arg9, %arg11 : f32
+  Value valueSelectOp;
+  if (failed(matchSelect(opsIt, currValue, reduceValue, shouldUpdate,
+                         valueSelectOp))) {
+    return failure();
+  }
+
+  // matching: %17 = arith.select %15, %arg10, %arg12 : i32
+  Value indexSelectOp;
+  if (failed(matchSelect(opsIt, currIndex, reduceIndex, shouldUpdate,
+                         indexSelectOp))) {
+    return failure();
+  }
+
+  indexResult = indexSelectOp;
+  valueResult = valueSelectOp;
+
+  return success();
+}
+
+} // namespace
+
+#endif
diff --git a/third_party/tsingmicro/include/utils/utils.h b/third_party/tsingmicro/include/utils/utils.h
new file mode 100644
index 000000000..6829a2800
--- /dev/null
+++ b/third_party/tsingmicro/include/utils/utils.h
@@ -0,0 +1,27 @@
+//===------------------- utils.h ------------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Utility functions for ztc conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ZTC_CONVERSION_UTILS_H
+#define ZTC_CONVERSION_UTILS_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Types.h" // Include the header for Type
+
+using namespace mlir;
+
+namespace mlir::triton::utils {
+Value declareTx81Function(ModuleOp module, OpBuilder &builder, Location loc,
+                          StringRef name, Type resultType,
+                          ArrayRef<Type> argumentTypes);
+} // namespace mlir::triton::utils
+
+#endif // ZTC_CONVERSION_UTILS_H
diff --git a/third_party/tsingmicro/lib/CMakeLists.txt b/third_party/tsingmicro/lib/CMakeLists.txt
index eff85b208..b22aeb188 100644
--- a/third_party/tsingmicro/lib/CMakeLists.txt
+++ b/third_party/tsingmicro/lib/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(Analysis)
 add_subdirectory(AnalysisStructured)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
+add_subdirectory(Utils)
diff --git a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt
index c4d27c2b3..4f092c10b 100644
--- a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt
+++ b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt
@@ -20,4 +20,6 @@ add_triton_library(CoreDialectsToMK
   MLIRTensorDialect
   MLIRTransforms
   MLIRSupport
+
+  LinalgToMagicKernel
 )
diff --git a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp
index a9ffab977..bb55cddee 100644
--- a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp
@@ -24,6 +24,7 @@ using namespace triton;
 
 #define GEN_PASS_CLASSES
 #include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h.inc"
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 
 namespace {
 
@@ -35,7 +36,7 @@ class CoreDialectsToMKPass : public CoreDialectsToMKBase<CoreDialectsToMKPass> {
-        .insert<memref::MemRefDialect>();
+        .insert<memref::MemRefDialect, mk::MagicKernelDialect>();
   }
 
   void runOnOperation() override {
diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp
index a3970cf81..75eac5c06 100644
--- a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp
+++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp
@@ -53,4 +53,5 @@ void mlir::triton::populateLinalgToMKCanonicalizationPatterns(
 void mlir::triton::populateLinalgToMKConversionPatterns(
     RewritePatternSet &patterns) {
   // patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
 }
diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp
index eaba1f34f..1646c1f3d 100644
--- a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp
@@ -36,28 +36,12 @@ class LinalgToMKPass : public triton::impl::LinalgToMKBase<LinalgToMKPass> {
   using LinalgToMKBase::LinalgToMKBase;
 
 public:
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert();
-  }
-
   void runOnOperation() override {
     auto moduleOp = getOperation();
 
     RewritePatternSet patterns(&getContext());
-    ConversionTarget target(getContext());
-
-    // TODO: Enable this when all conversion pattern has been implemented.
-    // target.addIllegalDialect();
-
-    target.addLegalDialect();
-
-    target.addLegalOp();
 
     triton::populateLinalgToMKConversionPatterns(patterns);
-    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+    if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
       signalPassFailure();
     }
   }
diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt
index b7e951a04..f40b671ef 100644
--- a/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt
+++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt
@@ -8,6 +8,7 @@ add_triton_library(MKToTx81
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRLinalgDialect
   MLIRDialectUtils
   MLIRIR
   MLIRPass
diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp
index 3b51faf43..21ecb5b92 100644
--- a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp
+++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -25,10 +26,12 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h"
+#include "utils/FusionHelper.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #define DEBUG_TYPE "mk-to-tx81"
 
 using namespace mlir;
 using namespace tx;
@@ -86,6 +89,12 @@ Data_Format getFormatCode(MemRefType type) {
   return Fmt_FP32;
 }
 
+bool isSupportedType(MemRefType type) {
+  auto elemType = type.getElementType();
+  return elemType.isF32() || elemType.isF16() || elemType.isBF16() ||
+         elemType.isInteger(8);
+}
+
 // Helper function to extract shape from tensor type
 SmallVector<int64_t> getShapeFromTensorType(TensorType type) {
   SmallVector<int64_t> shape;
@@ -97,10 +106,10 @@
 // Helper function to extract dimensions from memref or tensor type
 SmallVector<int32_t> getDimsFromType(Type type) {
   SmallVector<int32_t> dims;
-  if (auto memrefType = mlir::dyn_cast<MemRefType>(type)) {
+  if (auto memrefType = dyn_cast<MemRefType>(type)) {
     for (auto dim : memrefType.getShape())
       dims.push_back(static_cast<int32_t>(dim));
-  } else if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+  } else if (auto tensorType = dyn_cast<TensorType>(type)) {
     for (auto dim : tensorType.getShape())
       dims.push_back(static_cast<int32_t>(dim));
   }
@@ -122,7 +131,7 @@ static Value createAddressFromMemref(ConversionPatternRewriter &rewriter,
       rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
   Value indexBasePtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
       loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer());
-  auto elemType = mlir::cast<MemRefType>(memref.getType()).getElementType();
+  auto elemType = cast<MemRefType>(memref.getType()).getElementType();
   Value elemByte = rewriter.create<arith::ConstantIndexOp>(loc, getElemByte(elemType));
 
   Value offset = stridedMetadata.getOffset();
@@ -135,14 +144,14 @@
   return i64SPMPtr;
 }
 
-static std::tuple
+static std::tuple<Value, SmallVector<Value>, SmallVector<Value>>
 createMetadata(ConversionPatternRewriter &rewriter, Location loc,
                Value operand) {
   auto stridedMetadata =
       rewriter.create<memref::ExtractStridedMetadataOp>(loc, operand);
 
   Value indexBasePtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer());
-  auto elemType = mlir::cast<MemRefType>(operand.getType()).getElementType();
+  auto elemType = cast<MemRefType>(operand.getType()).getElementType();
   Value elemByte = rewriter.create<arith::ConstantIndexOp>(loc, getElemByte(elemType));
 
   Value offset = stridedMetadata.getOffset();
@@ -184,7 +193,6 @@
 padStridesToNHWC(ConversionPatternRewriter &rewriter, Location loc,
   for (auto dim : strides) {
     nhwcStrides.push_back(dim);
   }
-  nhwcStrides.pop_back();
 
   return nhwcStrides;
 }
@@ -210,6 +218,22 @@ template <typename T> llvm::SmallVector<Operation *> getRegionOps(T linalgOp) {
                         [](Operation &op) { return &op; });
 }
 
+static Data_Format getFormatFromValueType(MemRefType valueType) {
+  auto elemType = valueType.getElementType();
+  auto bitWidth = elemType.getIntOrFloatBitWidth();
+  switch (bitWidth) {
+  case 8:
+    return Fmt_INT8;
+  case 16:
+    return Fmt_FP16;
+  case 32:
+    return Fmt_FP32;
+  default:
+    llvm_unreachable("Unsupported bit width\n");
+  }
+  return Fmt_FP32;
+}
+
 // Convert integer type to float type for CGRA instruction
 // Return the convert float type format code
 // TODO: Directly convert memref type?
@@ -278,6 +302,41 @@ Value insertRestoreTypeOp(Value valuePtr, MemRefType valueType, Value elemCount,
   return newValue;
 }
 
+SmallVector<int64_t> reshapeReduceShapeTo4d(ArrayRef<int64_t> inputShape,
+                                            int64_t dim) {
+
+  int64_t rank = inputShape.size();
+  SmallVector<int64_t> newShape;
+  int64_t leftDimsElement = 1;
+  int64_t rightDimsElement = 1;
+
+  for (int i = 0; i < dim; i++)
+    leftDimsElement *= inputShape[i];
+
+  if (dim == rank - 1)
+    return {1, 1, leftDimsElement, inputShape[dim]};
+
+  for (int i = dim + 1; i < rank; i++)
+    rightDimsElement *= inputShape[i];
+
+  newShape = {1, leftDimsElement, inputShape[dim], rightDimsElement}; // NHWC
+  return newShape;
+}
+
+uint64_t next_power_of_two_64(uint64_t x) {
+  if (x == 0) {
+    return 1;
+  }
+  x--;
+  x |= x >> 1;
+  x |= x >> 2;
+  x |= x >> 4;
+  x |= x >> 8;
+  x |= x >> 16;
+  x |= x >> 32;
+  return x + 1;
+}
+
 class MemoryCopyConvertPattern : public OpConversionPattern<memref::CopyOp> {
 public:
   using OpConversionPattern<memref::CopyOp>::OpConversionPattern;
@@ -293,13 +352,34 @@ class MemoryCopyConvertPattern : public OpConversionPattern<memref::CopyOp> {
     return false;
   }
 
+  // View int64 as int8
+  void reshapeInt64ToInt8(SmallVector<Value> &srcSizes,
+                          SmallVector<Value> &srcStrides,
+                          ConversionPatternRewriter &rewriter,
+                          Operation *op) const {
+    // Shape as int8: fold the element size into the last dimension.
+    auto lastDim = srcSizes.back();
+    srcSizes.back() = rewriter.create<arith::MulIOp>(
+        op->getLoc(), lastDim.getType(), lastDim,
+        rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 8));
+
+    // Stride as int8: multiply each stride by the element size.
+    std::transform(srcStrides.begin(), srcStrides.end(), srcStrides.begin(),
+                   [&](Value size) {
+                     return rewriter.create<arith::MulIOp>(
+                         op->getLoc(), size.getType(), size,
+                         rewriter.create<arith::ConstantIndexOp>(
+                             op->getLoc(), 8)); // 8 bytes for int64_t
+                   });
+  }
+
   LogicalResult
   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(op->hasAttr("srcSpm") && op->hasAttr("dstSpm") &&
            "Can't get memory space attribute\n");
-    bool isSrcSPM = mlir::cast<IntegerAttr>(op->getAttr("srcSpm")).getInt();
-    bool isDstSPM = mlir::cast<IntegerAttr>(op->getAttr("dstSpm")).getInt();
+    bool isSrcSPM = op->getAttrOfType<IntegerAttr>("srcSpm").getInt();
+    bool isDstSPM = op->getAttrOfType<IntegerAttr>("dstSpm").getInt();
 
     // DDR to DDR
     if (!isSrcSPM && !isDstSPM)
@@ -312,40 +392,73 @@
         createMetadata(rewriter, op->getLoc(), adaptor.getTarget());
 
     auto inputType = dyn_cast<MemRefType>(op.getSource().getType());
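+    // Worked example (assumed shapes, for illustration only): a copy of
+    // memref<2x4xi64> is driven through the DMA path as memref<2x32xi8>;
+    // the last size becomes 4 * 8 and every stride is scaled by 8 bytes,
+    // which the reshapeInt64ToInt8 helper above performs.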
+    // For memory operations, a 64-bit element type can work in int8 mode.
+    if (inputType.getElementTypeBitWidth() == 64) {
+      // Re-calculate sizes and strides, and convert the type to int8_t.
+      reshapeInt64ToInt8(srcSizes, srcStrides, rewriter, op);
+      reshapeInt64ToInt8(dstSizes, dstStrides, rewriter, op);
+
+      // Convert type to int8_t
+      SmallVector<int64_t> shape(inputType.getShape().begin(),
+                                 inputType.getShape().end());
+      shape.back() *= sizeof(int64_t);
+
+      inputType =
+          MemRefType::get(shape, rewriter.getI8Type(), inputType.getLayout(),
+                          inputType.getMemorySpace());
+    }
+
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), srcSizes);
+    auto srcFmt = getFormatFromValueType(inputType);
+
     // SPM to SPM
     if (isSrcSPM && isDstSPM) {
-      // FIXME: Only support 1d for now, take sizes[0] as elemCount.
-      auto elemCount = calculateElemCount(rewriter, op->getLoc(), srcSizes);
-      // WORKAROUND: Assume no mask.
+      srcFmt = insertConvertTypeOp(srcPtr, inputType, elemCount, rewriter,
                                   op->getLoc());
+      insertConvertTypeOp(dstPtr, inputType, elemCount, rewriter, op->getLoc());
       auto constValue = rewriter.create<arith::ConstantIntOp>(
           op.getLoc(), 0, rewriter.getI32Type());
-      rewriter.create<AddVSOp>(
-          op->getLoc(), rewriter.getI64Type(), srcPtr, constValue, dstPtr,
-          elemCount,                                            // Element count
-          rewriter.getI16IntegerAttr(0),                        // Round mode
-          rewriter.getI16IntegerAttr(getFormatCode(inputType))  // Format
+      rewriter.create<AddVSOp>(op->getLoc(), rewriter.getI64Type(), srcPtr,
+                               constValue, dstPtr,
+                               elemCount,                         // Element count
+                               rewriter.getI16IntegerAttr(0),     // Round mode
+                               rewriter.getI16IntegerAttr(srcFmt) // Format
       );
-    } else if (isDstSPM) {
-      auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), srcSizes);
-      auto nhwcStrides = padStridesToNHWC(rewriter, op->getLoc(), srcStrides);
+      insertRestoreTypeOp(srcPtr, inputType, elemCount, rewriter, op->getLoc());
+      insertRestoreTypeOp(dstPtr, inputType, elemCount, rewriter, op->getLoc());
+      rewriter.eraseOp(op);
+      return success();
+    }
 
+    auto srcShape4d = padSizesToNHWC(rewriter, op->getLoc(), srcSizes);
+    auto srcStrides4d = padStridesToNHWC(rewriter, op->getLoc(), srcStrides);
+    auto dstShape4d = padSizesToNHWC(rewriter, op->getLoc(), dstSizes);
+    auto dstStrides4d = padStridesToNHWC(rewriter, op->getLoc(), dstStrides);
+
+    int bitWidth = inputType.getElementType().getIntOrFloatBitWidth();
+    int elemBytes = bitWidth / 8;
+
+    if (isDstSPM) {
       auto rdmaOp = rewriter.create<RdmaOp>(
           op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr,
-          nhwcShape,                                            // NHWC shape
-          nhwcStrides,                                          // NHWC stride
-          rewriter.getI32IntegerAttr(getFormatCode(inputType))  // Format
+          srcShape4d,                            // src shape
+          srcStrides4d,                          // src stride
+          dstShape4d,                            // dst shape
+          dstStrides4d,                          // dst stride
+          rewriter.getI32IntegerAttr(elemBytes), // elem bytes
+          rewriter.getI32IntegerAttr(srcFmt)     // Format
      );
    } else {
-      auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), dstSizes);
-      auto nhwcStrides = padStridesToNHWC(rewriter, op->getLoc(), dstSizes);
-
       auto wdmaOp = rewriter.create<WdmaOp>(
           op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr,
-          nhwcShape,                                            // NHWC shape
-          nhwcStrides,                                          // NHWC stride
-          rewriter.getI32IntegerAttr(getFormatCode(inputType))  // Format
+          srcShape4d,                            // src shape
+          srcStrides4d,                          // src stride
+          dstShape4d,                            // dst shape
+          dstStrides4d,                          // dst stride
+          rewriter.getI32IntegerAttr(elemBytes), // elem bytes
+          rewriter.getI32IntegerAttr(srcFmt)     // Format
       );
     }
 
@@ -372,19 +485,17 @@ class LinalgFillOpConversion : public OpConversionPattern<linalg::FillOp> {
         createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
 
     auto inputType = op.getInputs()[0].getType();
     auto bitWidth = op.getInputs()[0].getType().getIntOrFloatBitWidth();
-    assert(bitWidth == 16 ||
-           bitWidth == 32 && "Only support 16/32 fill value\n");
+
+    if (bitWidth != 16 && bitWidth != 32) {
+      if (failed(linalg::linalgOpToLoops(rewriter, op)))
+        return rewriter.notifyMatchFailure(op, " op not yet supported");
+      rewriter.eraseOp(op);
+      return success();
+    }
 
     // AddVS value need has fmt with input fmt and only support float type
     Data_Format fmt = bitWidth == 16 ? Fmt_FP16 : Fmt_FP32;
 
-    if (inputType.isInteger()) {
-      auto floatType =
-          bitWidth == 16 ? rewriter.getF16Type() : rewriter.getF32Type();
-      fillValue =
-          rewriter.create<arith::SIToFPOp>(op.getLoc(), floatType, fillValue);
-    }
-
     auto bitcastType =
         bitWidth == 16 ? rewriter.getI16Type() : rewriter.getI32Type();
     fillValue =
@@ -419,12 +530,150 @@
   }
 };
 
+class TransposeOpConversion : public OpConversionPattern<linalg::TransposeOp> {
+public:
+  using OpConversionPattern<linalg::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  convertToGatherScatter(linalg::TransposeOp op, OpAdaptor adaptor,
+                         ConversionPatternRewriter &rewriter) const {
+    auto perm = op.getPermutation();
+    auto rank = perm.size();
+
+    auto src = op.getInput();
+    auto dst = op.getInit();
+    auto srcType = cast<MemRefType>(src.getType());
+    auto dstType = cast<MemRefType>(dst.getType());
+
+    // Get NHWC shape
+    SmallVector<int64_t> srcShape(srcType.getShape());
+    SmallVector<int64_t> dstShape(dstType.getShape());
+    SmallVector<int64_t> perm4d(perm.begin(), perm.end());
+    while (srcShape.size() < 4) {
+      srcShape.push_back(1);
+      dstShape.push_back(1);
+      perm4d.push_back(perm4d.size());
+    }
+
+    // Get inner bytes
+    int32_t elemCount = srcShape[3];
+    auto elemType = srcType.getElementType();
+    auto bitWidth = elemType.getIntOrFloatBitWidth();
+    auto byte = bitWidth / 8;
+    auto bytes = elemCount * byte;
+
+    // Get strides
+    SmallVector<int64_t> srcStride(srcShape.size());
+    srcStride[srcShape.size() - 1] = byte;
+    for (int i = srcShape.size() - 2; i >= 0; --i) {
+      srcStride[i] = srcStride[i + 1] * srcShape[i + 1];
+    }
+    SmallVector<int64_t> dstStride(dstShape.size());
+    for (int i = 0; i < dstShape.size(); i++) {
+      dstStride[i] = srcStride[perm4d[i]];
+    }
+
+    auto srcPtr = createAddressFromMemref(rewriter, op->getLoc(), src);
+    auto dstPtr = createAddressFromMemref(rewriter, op->getLoc(), dst);
+
+    auto newOp = rewriter.create<GatherScatter>(
+        op->getLoc(), rewriter.getI64Type(), srcPtr, dstPtr, bytes,
+        srcStride[0], srcStride[1], srcStride[2], srcShape[0], srcShape[1],
+        srcShape[2], dstStride[0], dstStride[1], dstStride[2], dstShape[0],
+        dstShape[1], dstShape[2]);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  LogicalResult convertToTranspose(linalg::TransposeOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+    auto src = op.getInput();
+    auto dst = op.getInit();
+    auto srcType = cast<MemRefType>(src.getType());
+    auto dstType = cast<MemRefType>(dst.getType());
+    int32_t dim0 = srcType.getShape()[0];
+    int32_t dim1 = srcType.getShape()[1];
+    SmallVector<int32_t> srcShape({1, dim0, dim1, 1});
+    SmallVector<int32_t> dstShape({1, dim1, dim0, 1});
+
+    auto srcPtr = createAddressFromMemref(rewriter, op->getLoc(), src);
+    auto dstPtr = createAddressFromMemref(rewriter, op->getLoc(), dst);
+    Data_Format fmt = getFormatCode(srcType);
+
+    auto newOp =
+        rewriter.create<Transpose>(op->getLoc(), rewriter.getI64Type(),
                                    srcPtr, dstPtr, srcShape, dstShape, fmt);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  template <typename OpTy>
+  LogicalResult transposeChannel(linalg::TransposeOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+    auto src = op.getInput();
+    auto dst = op.getInit();
+    auto srcType = cast<MemRefType>(src.getType());
+    auto dstType = cast<MemRefType>(dst.getType());
+    SmallVector<int32_t> srcShape(srcType.getShape().begin(),
+                                  srcType.getShape().end());
+    SmallVector<int32_t> dstShape(dstType.getShape().begin(),
+                                  dstType.getShape().end());
+
+    auto srcPtr = createAddressFromMemref(rewriter, op->getLoc(), src);
+    auto dstPtr = createAddressFromMemref(rewriter, op->getLoc(), dst);
+    Data_Format fmt = getFormatCode(srcType);
+
+    auto newOp =
+        rewriter.create<OpTy>(op->getLoc(), rewriter.getI64Type(), srcPtr,
+                              dstPtr, srcShape, dstShape, fmt);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(linalg::TransposeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto perm = op.getPermutation();
+    auto rank = perm.size();
+
+    // FIXME: tx.transpose/tx.gather_scatter need to cx align,
+    // convert 2d transpose to loops default now.
+    // if (rank == 2)
+    //   return convertToGatherScatter(op, adaptor, rewriter);
+
+    if (rank == 3)
+      return convertToGatherScatter(op, adaptor, rewriter);
+
+    if (rank == 4 && perm[3] == 3) {
+      return convertToGatherScatter(op, adaptor, rewriter);
+    }
+
+    if (rank == 4 && perm == ArrayRef<int64_t>({0, 2, 3, 1})) {
+      return transposeChannel(op, adaptor, rewriter);
+    }
+
+    if (rank == 4 && perm == ArrayRef<int64_t>({0, 3, 1, 2})) {
+      return transposeChannel(op, adaptor, rewriter);
+    }
+
+    // Default handling of remaining cases.
+    // TODO: Convert higher rank to tx.
+    if (failed(linalg::linalgOpToLoops(rewriter, op)))
+      return rewriter.notifyMatchFailure(op, " op not yet supported");
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // mk.dot to tx.gemm Conversion Pattern
 //===----------------------------------------------------------------------===//
 
 class MKDotToTx81GemmOpConversion
-    : public OpConversionPattern<mlir::mk::DotOp> {
+    : public OpConversionPattern<linalg::MatmulOp> {
 
   void fp32ToTF32(ConversionPatternRewriter &rewriter, Location loc,
                   ValueRange sizes, Value spmAddr) const {
@@ -440,24 +689,70 @@ class MKDotToTx81GemmOpConversion
     );
   }
 
+  Value createChannelNorm(Location loc, Value op,
+                          ConversionPatternRewriter &rewriter) const {
+    auto memType = cast<MemRefType>(op.getType());
+    auto shape = memType.getShape();
+    int bitWidth = memType.getElementType().getIntOrFloatBitWidth();
+
+    int alignBase = bitWidth == 8 ? 128 : 64;
+
+    int c = shape.back();
+    // Has been cx aligned
+    bool noNeedChannelNorm =
+        (c >= 4 && c <= alignBase && c == next_power_of_two_64(c));
+    if (noNeedChannelNorm) {
+      return op;
+    }
+
+    int cx = c / alignBase;
+    int c0 = c % alignBase;
+    int alignedC0 = c0 ? next_power_of_two_64(c0) : 0;
+    if (c0 < 4 && c0 > 0) {
+      // A nonzero c0 smaller than 4 is aligned up to 4.
+      alignedC0 = 4;
+    }
+    int alignedC = cx * alignBase + alignedC0;
+    SmallVector<int64_t> alignedShape(shape.begin(), shape.end());
+    alignedShape.back() = alignedC;
+
+    auto alignedMemType =
+        MemRefType::get(alignedShape, memType.getElementType());
+    auto alignedAlloc = rewriter.create<memref::AllocOp>(loc, alignedMemType);
+
+    auto srcPtr = createAddressFromMemref(rewriter, loc, op);
+    auto dstPtr = createAddressFromMemref(rewriter, loc, alignedAlloc);
+
+    SmallVector<int64_t> shape4D =
        reshapeReduceShapeTo4d(shape, shape.size() - 1);
+    auto channelNorm = rewriter.create<ChannelNormOp>(
+        loc, TypeRange({}), srcPtr, dstPtr,
+        rewriter.getDenseI64ArrayAttr(shape4D),
+        rewriter.getI16IntegerAttr(alignedC0),
+        rewriter.getI16IntegerAttr(bitWidth));
+    return alignedAlloc;
+  }
+
 public:
-  using OpConversionPattern<mlir::mk::DotOp>::OpConversionPattern;
+  using OpConversionPattern<linalg::MatmulOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(mlir::mk::DotOp op, OpAdaptor adaptor,
+  matchAndRewrite(linalg::MatmulOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Extract dimensions from tensor types
-    MemRefType aTensorType = mlir::cast<MemRefType>(op.getA().getType());
-    MemRefType bTensorType = mlir::cast<MemRefType>(op.getB().getType());
+    MemRefType aTensorType = cast<MemRefType>(op->getOperand(0).getType());
+    MemRefType bTensorType = cast<MemRefType>(op->getOperand(1).getType());
     assert(aTensorType.getElementType() == bTensorType.getElementType() &&
            "a and b must have the same element type");
 
-    MemRefType zeroTensorType =
-        mlir::cast<MemRefType>(op.getZeroes().getType());
+    MemRefType dstType = cast<MemRefType>(op.getOutputs()[0].getType());
     Data_Format srcFmt = getFormatCode(aTensorType);
-    Data_Format dstFmt = getFormatCode(zeroTensorType);
+    Data_Format dstFmt = getFormatCode(dstType);
 
     // Get converted operands
-    auto loc = op.getLoc();
+    auto loc = op->getLoc();
+
+    auto a = adaptor.getInputs()[0];
+    auto b = adaptor.getInputs()[1];
 
     auto aShape = aTensorType.getShape();
     auto bShape = bTensorType.getShape();
@@ -465,48 +760,58 @@
     // Matrix dimensions M, K, N for GEMM
     int32_t M = aShape[0];
     int32_t K = aShape[1];
-    int32_t N = bShape[1];
+    // Notice: operand is (N, K) now.
+    int32_t N = bShape[0];
 
     // Create dimensions array attribute [M, K, N]
     auto dims = rewriter.getI32ArrayAttr({M, K, N});
 
+    auto alignedA = createChannelNorm(loc, a, rewriter);
+    auto alignedB = createChannelNorm(loc, b, rewriter);
+
     // Get operand ptr
-    auto [aPtr, aSizes, aStrides] =
-        createMetadata(rewriter, op->getLoc(), adaptor.getA());
-    auto [bPtr, bSizes, bStrides] =
-        createMetadata(rewriter, op->getLoc(), adaptor.getB());
-    auto [cPtr, cSizes, cStrides] =
-        createMetadata(rewriter, op->getLoc(), adaptor.getC());
+    auto [aPtr, aSizes, aStrides] = createMetadata(rewriter, loc, alignedA);
+    auto [bPtr, bSizes, bStrides] = createMetadata(rewriter, loc, alignedB);
+
+    auto [dstPtr, dstSizes, dstStrides] =
+        createMetadata(rewriter, loc, adaptor.getOutputs()[0]);
 
+    // Assume input types match. The Tx neural engine does not support fp32 input.
     if (aTensorType.getElementType().isF32()) {
       srcFmt = Data_Format::Fmt_TF32;
-      fp32ToTF32(rewriter, op->getLoc(), aSizes, aPtr);
-      fp32ToTF32(rewriter, op->getLoc(), bSizes, bPtr);
-      fp32ToTF32(rewriter, op->getLoc(), cSizes, cPtr);
+      fp32ToTF32(rewriter, loc, aSizes, aPtr);
+      fp32ToTF32(rewriter, loc, bSizes, bPtr);
     }
 
-    auto dst = createAddressFromMemref(rewriter, loc, adaptor.getZeroes());
+    auto zero =
+        rewriter.create<arith::ConstantIntOp>(loc, 0, rewriter.getI64Type());
 
-    auto zero = rewriter.create<arith::ConstantIntOp>(op.getLoc(), 0,
-                                                      rewriter.getI64Type());
+    // Check that N is a power of 2 and at least 4
+    if ((N & (N - 1)) != 0) { // Check if N is not a power of 2
+      return rewriter.notifyMatchFailure(op, "N must be a power of 2");
+    }
+    if (N < 4) {
+      return rewriter.notifyMatchFailure(op, "N must be at least 4");
+    }
 
     // Create GemmOp
+    // TODO: Support bias when input is int8
     rewriter.create<GemmOp>(
-        op.getLoc(), rewriter.getI64Type(),
+        loc, rewriter.getI64Type(),
        aPtr,                                     // src_a (Matrix A in SPM)
        bPtr,                                     // src_b (Matrix B in SPM)
-        cPtr,                                    // src_bias (optional accumulation)
-        dst,                                     // dst,
+        dstPtr,                                  // src_bias. Unused for now.
+        dstPtr,                                  // dst,
        dims,                                     // dimensions [M,K,N]
-        rewriter.getBoolAttr(false),             // en_psum
-        dst,                                     // WORKAROUND: psum_addr (using dst buffer)
+        rewriter.getBoolAttr(false),             // en_psum; psum is used as the accumulation buffer
+        dstPtr,                                  // The address of psum in SPM; always the same as the output
        rewriter.getBoolAttr(false),              // trans_src_a
        // NOTE: (N, K) is thought not trans in hardware
-        rewriter.getBoolAttr(true),              // trans_src_b.
+        rewriter.getBoolAttr(false),             // trans_src_b.
        rewriter.getI32IntegerAttr(1),            // batch_src_a
        rewriter.getI32IntegerAttr(1),            // batch_src_b
        rewriter.getI32IntegerAttr(ActFuncMode::None), // relu_mode.
-        rewriter.getBoolAttr(op.getC() != nullptr),   // en_bias
+        rewriter.getBoolAttr(false),             // en_bias
        rewriter.getBoolAttr(false),              // en_neg_scale
        zero,                                     // src_neg_scale
        rewriter.getBoolAttr(false),              // en_pos_scale
@@ -514,6 +819,18 @@
        rewriter.getI32IntegerAttr(srcFmt),       // src_fmt
        rewriter.getI32IntegerAttr(dstFmt)        // dst_fmt
     );
+
+    // DechannelNorm dst
+    int bitWidth = dstType.getElementType().getIntOrFloatBitWidth();
+    int alignBase = bitWidth == 32 ? 64 : 128;
+    if (N > alignBase) {
+      auto dechannelNorm = rewriter.create<DechannelNormOp>(
+          loc, TypeRange({}), dstPtr, dstPtr,
+          rewriter.getDenseI64ArrayAttr({1, 1, M, N}),
+          rewriter.getI16IntegerAttr(0) /*alignedC0*/,
+          rewriter.getI16IntegerAttr(bitWidth));
+    }
+
     // Op has no result value
     rewriter.eraseOp(op);
 
@@ -521,6 +838,174 @@
   }
 };
 
+class GatherConvertPattern : public OpConversionPattern<mlir::mk::GatherOp> {
+public:
+  using OpConversionPattern<mlir::mk::GatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(mlir::mk::GatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto indices = adaptor.getIndices();
+    auto indicesType = cast<MemRefType>(indices.getType());
+    auto shape = indicesType.getShape();
+
+    auto axis = op.getAxis();
+
+    int64_t numElems = indicesType.getNumElements();
+    auto strides = computeStrides(shape);
+    for (int64_t idx = 0; idx < numElems; idx += 1) {
+      auto tensorIdx = delinearize(idx, strides);
+
+      SmallVector<Value> idxIndex(tensorIdx.size());
+      std::transform(tensorIdx.begin(), tensorIdx.end(), idxIndex.begin(),
+                     [&](auto val) {
+                       return rewriter.create<arith::ConstantIndexOp>(loc, val);
+                     });
+      // Read the index value from indices tensor
+      Value indexValue =
+          rewriter.create<memref::LoadOp>(loc, indices, idxIndex);
+
+      // Read value from source using computed indices
+      SmallVector<Value> inputIndex = idxIndex;
+      assert(axis < inputIndex.size() && axis >= 0 &&
+             "Axis index out of bounds");
+      inputIndex[axis] = rewriter.create<arith::IndexCastOp>(
+          loc, rewriter.getIndexType(), indexValue);
+
+      Value gatheredValue =
+          rewriter.create<memref::LoadOp>(loc, adaptor.getSrc(), inputIndex);
+
+      // Write value to destination
+      rewriter.create<memref::StoreOp>(loc, gatheredValue, adaptor.getDst(),
+                                       idxIndex);
+    }
+
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
+class MKSigmoidToTx81SigmoidOpConversion
+    : public OpConversionPattern<mlir::mk::SigmoidOp> {
+public:
+  using OpConversionPattern<mlir::mk::SigmoidOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(mlir::mk::SigmoidOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto [input, sizes, strides] =
+        createMetadata(rewriter, loc, adaptor.getSrc());
+    auto [dst, dstSizes, dstStrides] =
+        createMetadata(rewriter, loc, adaptor.getZeroes());
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes);
+
+    // The Tx neural engine does not support fp32 input.
+    auto inputType = dyn_cast<MemRefType>(op.getSrc().getType());
+    Data_Format srcFmt = getFormatCode(inputType);
+
+    rewriter.create<Sigmoid>(loc, rewriter.getI64Type(), input, dst,
                              elemCount, rewriter.getI16IntegerAttr(srcFmt));
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
+template <typename TxOp>
+class ArgMinMaxBaseConversion : public OpConversionPattern<linalg::ReduceOp> {
+  using OpConversionPattern<linalg::ReduceOp>::OpConversionPattern;
+
+public:
+  bool isArgMin;
+
+  ArgMinMaxBaseConversion(MLIRContext *context)
+      : OpConversionPattern<linalg::ReduceOp>(context) {}
+
+  LogicalResult
+  matchAndRewrite(linalg::ReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getBody()->getNumArguments() != 4) {
+      return failure();
+    }
+
+    auto block = op.getBody();
+    auto ops = block->without_terminator();
+
+    Value currValue = block->getArgument(0);
+    Value currIndex = block->getArgument(1);
+    Value reduceValue = block->getArgument(2);
+    Value reduceIndex = block->getArgument(3);
+
+    auto opsIter = ops.begin();
+    Value indexSelectOp, valueSelectOp;
+    if (failed(matchArgMinMax(currValue, currIndex, reduceValue, reduceIndex,
+                              opsIter, indexSelectOp, valueSelectOp,
+                              isArgMin))) {
+      return failure();
+    }
+
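+    // Illustrative shape of a combine body this matcher accepts (argmax over
+    // f32 values with i32 indices; SSA names are made up):
+    //   %tie = arith.cmpf oeq, %cur, %acc : f32
+    //   %lt  = arith.cmpi slt, %curIdx, %accIdx : i32
+    //   %and = arith.andi %tie, %lt : i1
+    //   %gt  = arith.cmpf ogt, %cur, %acc : f32
+    //   %upd = arith.ori %gt, %and : i1
+    //   %v   = arith.select %upd, %cur, %acc : f32
+    //   %i   = arith.select %upd, %curIdx, %accIdx : i32
+    //   linalg.yield %v, %i : f32, i32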
+ // matching: linalg.yield %18, %19 : f32, i32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIter << "\n"); + auto termOp = dyn_cast(*opsIter++); + if (termOp && termOp == block->getTerminator()) { + auto opnds = termOp.getOperands(); + if (opnds != ArrayRef{valueSelectOp, indexSelectOp}) { + return failure(); + } + } else { + return failure(); + } + + // Rewrite with tx81 operation + auto loc = op.getLoc(); + auto dims = op.getDimensions(); + if (dims.size() != 1) + return rewriter.notifyMatchFailure(op, "Only support one dim reduce."); + auto dim = dims[0]; + + auto input = + createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInputs()[0]); + auto value = + createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInits()[0]); + auto index = + createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInits()[1]); + auto inputType = dyn_cast(op.getInputs()[0].getType()); + auto valueType = dyn_cast(adaptor.getInits()[0].getType()); + + // TODO: Support any rank + auto inputShape = inputType.getShape(); + if (inputShape.size() > 1) + return rewriter.notifyMatchFailure(op, "Rank > 1 unsupported yet."); + + int64_t inputSize = inputShape.empty() ? 1 : inputShape[0]; + auto tx81Op = rewriter.create( + op->getLoc(), TypeRange{}, input, + // TODO: get output value and index + value, index, rewriter.getI32IntegerAttr(inputSize), + rewriter.getI16IntegerAttr(getFormatCode(valueType))); + + rewriter.replaceOp(op, tx81Op); + + return success(); + } +}; + +struct ArgMinConversion : public ArgMinMaxBaseConversion { + ArgMinConversion(MLIRContext *context) : ArgMinMaxBaseConversion(context) { + isArgMin = true; + } +}; + +struct ArgMaxConversion : public ArgMinMaxBaseConversion { + ArgMaxConversion(MLIRContext *context) : ArgMinMaxBaseConversion(context) { + isArgMin = false; + } +}; + struct ElementwiseConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -566,8 +1051,17 @@ struct ElementwiseConversion : public OpConversionPattern { // Data format after conversion Data_Format srcFmt = insertConvertTypeOp(input0, inputType, elemCount, rewriter, loc); - insertConvertTypeOp(input1, inputType, elemCount, rewriter, loc); - insertConvertTypeOp(output, inputType, elemCount, rewriter, loc); + if (adaptor.getInputs()[0] != adaptor.getInputs()[1]) { + // If input0 and input1 are not the same, we need to convert input1 type + insertConvertTypeOp(input1, inputType, elemCount, rewriter, loc); + } + + if (adaptor.getInputs()[0] != adaptor.getOutputs()[0] && + adaptor.getInputs()[1] != adaptor.getOutputs()[0]) { + // If input and output are not the same, we need to convert output type + Data_Format dstFmt = + insertConvertTypeOp(output, inputType, elemCount, rewriter, loc); + } // Create the elementwise operation // TODO: Fix attribute @@ -577,8 +1071,13 @@ struct ElementwiseConversion : public OpConversionPattern { rewriter.getI16IntegerAttr(srcFmt)); insertRestoreTypeOp(input0, inputType, elemCount, rewriter, loc); - insertRestoreTypeOp(input1, inputType, elemCount, rewriter, loc); - insertRestoreTypeOp(output, inputType, elemCount, rewriter, loc); + if (adaptor.getInputs()[0] != adaptor.getInputs()[1]) { + insertRestoreTypeOp(input1, inputType, elemCount, rewriter, loc); + } + if (adaptor.getInputs()[0] != adaptor.getOutputs()[0] && + adaptor.getInputs()[1] != adaptor.getOutputs()[0]) { + insertRestoreTypeOp(output, inputType, elemCount, rewriter, loc); + } rewriter.eraseOp(op); return success(); @@ -634,13 +1133,20 @@ struct ElementwiseConversion : public 
OpConversionPattern { auto inputType = dyn_cast(op.getInputs()[0].getType()); + Data_Format srcFmt = + insertConvertTypeOp(input0, inputType, elemCount, rewriter, loc); + insertConvertTypeOp(input1, inputType, elemCount, rewriter, loc); + // Create the elementwise operation // TODO: Fix attribute - rewriter.create( - loc, rewriter.getI64Type(), input0, input1, output, elemCount, - rewriter.getI16IntegerAttr(getFormatCode(inputType)) // Format + rewriter.create(loc, rewriter.getI64Type(), input0, input1, output, + elemCount, + rewriter.getI16IntegerAttr(srcFmt) // Format ); + insertRestoreTypeOp(input0, inputType, elemCount, rewriter, loc); + insertRestoreTypeOp(input1, inputType, elemCount, rewriter, loc); + rewriter.eraseOp(op); return success(); } @@ -672,6 +1178,207 @@ struct ElementwiseConversion : public OpConversionPattern { return success(); } + LogicalResult SelectConvertOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input0 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto input1 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]); + auto input2 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[2]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = dyn_cast(op.getInputs()[1].getType()); + assert(inputType.getNumElements() % 8 == 0 && + "Total elements need be multiple of 8\n"); + + Data_Format srcFmt = + insertConvertTypeOp(input1, inputType, elemCount, rewriter, loc); + insertConvertTypeOp(input2, inputType, elemCount, rewriter, loc); + + // Add zero const value + auto zero = rewriter.create(op.getLoc(), 0, + rewriter.getI32Type()); + + assert(adaptor.getInputs()[2] != adaptor.getOutputs()[0]); + // Maskmove mask only support int8/fp, here mask is memref + auto maskCast = rewriter.create(loc, inputType); + auto maskCastAddr = createAddressFromMemref(rewriter, loc, maskCast); + rewriter.create(op.getLoc(), rewriter.getI64Type(), input0, + maskCastAddr, elemCount, + rewriter.getI16IntegerAttr(srcFmt)); + + // Input1 and output are same address + if (adaptor.getInputs()[1] == adaptor.getOutputs()[0]) { + // Create memref::allocOp + auto temp = rewriter.create(loc, inputType); + auto tempAddr = createAddressFromMemref(rewriter, loc, temp); + auto mid = rewriter.create( + op.getLoc(), rewriter.getI64Type(), input2, zero, tempAddr, elemCount, + rewriter.getI16IntegerAttr(0), // round_mode + rewriter.getI16IntegerAttr(srcFmt)); + + rewriter.create(loc, rewriter.getI64Type(), input1, + tempAddr, elemCount, maskCastAddr, + rewriter.getI32IntegerAttr(srcFmt)); + // Res = input2 + 0; + rewriter.create(op.getLoc(), rewriter.getI64Type(), tempAddr, + zero, output, elemCount, + rewriter.getI16IntegerAttr(0), // round_mode + rewriter.getI16IntegerAttr(srcFmt)); + + } else { + insertConvertTypeOp(output, inputType, elemCount, rewriter, loc); + + // Res = input2 + 0; + auto mid = rewriter.create( + op.getLoc(), rewriter.getI64Type(), input2, zero, output, elemCount, + rewriter.getI16IntegerAttr(0), // round_mode + rewriter.getI16IntegerAttr(srcFmt)); + + // if input0 = 1, Res = input1; + // if input0 = 0, Res = input2; + rewriter.create(loc, rewriter.getI64Type(), input1, + output, elemCount, maskCastAddr, + rewriter.getI32IntegerAttr(srcFmt)); + insertRestoreTypeOp(output, inputType, elemCount, rewriter, loc); + } + 
+    insertRestoreTypeOp(input1, inputType, elemCount, rewriter, loc);
+    insertRestoreTypeOp(input2, inputType, elemCount, rewriter, loc);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  template
+  LogicalResult convertMinMaxOp(linalg::GenericOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+
+    auto lhs = createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]);
+    auto rhs = createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]);
+
+    auto [output, sizes, strides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes);
+
+    auto inputType = dyn_cast(op.getInputs()[0].getType());
+
+    auto fmt = getFormatCode(inputType);
+
+    auto isANanBuffer = rewriter.create(loc, inputType);
+
+    auto isANanBufferAddr =
+        createAddressFromMemref(rewriter, loc, isANanBuffer);
+
+    // auto isANan = UnEqualVV(lhs, lhs)
+    // auto result = lhs
+    // result = maskmove(isANan, rhs)
+    // auto isBNan = UnEqualVV(rhs, rhs)
+    // auto shouldApplyMinMax = EqualVS(isBNan, 0)
+    // auto minMaxValue = maxvv/minvv(result, rhs)
+    // result = maskmove(shouldApplyMinMax, minMaxValue)
+
+    auto isANan =
+        rewriter.create(loc,                            // loc
+                        rewriter.getI64Type(),          // result type
+                        lhs,                            // input0
+                        lhs,                            // input1
+                        isANanBufferAddr,               // out
+                        elemCount,                      // elem_count
+                        rewriter.getI16IntegerAttr(fmt) // fmt
+        );
+
+    auto constValue = rewriter.create(
+        op.getLoc(), 0, rewriter.getI32Type());
+    rewriter.create(op.getLoc(), rewriter.getI64Type(), lhs,
+                    constValue, output, elemCount,
+                    rewriter.getI16IntegerAttr(0), // round_mode
+                    rewriter.getI16IntegerAttr(fmt));
+
+    rewriter.create(loc, rewriter.getI64Type(), rhs, output,
+                    elemCount, isANanBufferAddr,
+                    rewriter.getI32IntegerAttr(fmt));
+
+    auto isBNanBuffer = rewriter.create(loc, inputType);
+    auto isBNanBufferAddr =
+        createAddressFromMemref(rewriter, loc, isBNanBuffer);
+    auto isBNan =
+        rewriter.create(loc,                            // loc
+                        rewriter.getI64Type(),          // result type
+                        rhs,                            // input0
+                        rhs,                            // input1
+                        isBNanBufferAddr,               // out
+                        elemCount,                      // elem_count
+                        rewriter.getI16IntegerAttr(fmt) // fmt
+        );
+
+    auto shouldApplyMinMaxBuffer =
+        rewriter.create(loc, inputType);
+    auto shouldApplyMinMaxBufferAddr =
+        createAddressFromMemref(rewriter, loc, shouldApplyMinMaxBuffer);
+    auto shouldApplyMinMax = rewriter.create(
+        op.getLoc(), rewriter.getI64Type(), isBNanBufferAddr, constValue,
+        shouldApplyMinMaxBufferAddr, elemCount,
+        rewriter.getI16IntegerAttr(fmt));
+
+    auto minMaxValueBuffer = rewriter.create(loc, inputType);
+    // Take the address of minMaxValueBuffer (not isBNanBuffer, which is
+    // already consumed above).
+    auto minMaxValueBufferAddr =
+        createAddressFromMemref(rewriter, loc, minMaxValueBuffer);
+    auto minMaxValue =
+        rewriter.create(loc, rewriter.getI64Type(), output, rhs,
+                        minMaxValueBufferAddr, elemCount,
+                        rewriter.getI16IntegerAttr(0), // Round mode
+                        rewriter.getI16IntegerAttr(fmt));
+
+    auto result = rewriter.create(
+        loc, rewriter.getI64Type(), minMaxValueBufferAddr, output, elemCount,
+        shouldApplyMinMaxBufferAddr, rewriter.getI32IntegerAttr(fmt));
+
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+
+  LogicalResult
+  convertCeilAndFloorOp(linalg::GenericOp op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    auto input = createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]);
+    auto [output, sizes, strides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes);
+    auto *body = op.getBody();
+    auto &operation = body->front();
+    auto roundMode = llvm::dyn_cast(&operation)
+                         ? RND_MODE::RND_POS_INF
+                         : RND_MODE::RND_NEG_INF;
+    // TODO: Fix attribute
+    // ceil rounds to positive infinity; floor rounds to negative infinity
+    auto fpToInt = rewriter.create(
+        loc,
+        rewriter.getI64Type(),                // Result type
+        input,                                // Input
+        output,                               // Output
+        elemCount,                            // Element count
+        rewriter.getI16IntegerAttr(roundMode) // Round mode
+    );
+    auto intToFp = rewriter.create(
+        loc,
+        rewriter.getI64Type(),        // Result type
+        output,                       // Input
+        output,                       // Output
+        elemCount,                    // Element count
+        rewriter.getI16IntegerAttr(0) // Round mode
+    );
+
+    rewriter.eraseOp(op);
+    return success();
+  }
 
   LogicalResult matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor,
                 ConversionPatternRewriter &rewriter) const override {
@@ -682,7 +1389,15 @@ struct ElementwiseConversion : public OpConversionPattern {
     if (op.getIteratorTypesArray().front() != utils::IteratorType::parallel)
       return rewriter.notifyMatchFailure(op, "Only support elementwise op.");
 
-    if (regionOps.size() != 1) {
+    // WORKAROUND: select's input0 and cmp results are bool (i1); ops with
+    // i64/f64 element types are lowered to loops (and later to LLVM) instead.
+    if (regionOps.size() != 1 ||
+        (dyn_cast(op.getOutputs()[0].getType())
+             .getElementType()
+             .getIntOrFloatBitWidth() == 64) ||
+        (dyn_cast(op.getInputs()[0].getType())
+             .getElementType()
+             .getIntOrFloatBitWidth() == 64)) {
       if (failed(linalg::linalgOpToLoops(rewriter, op)))
         return rewriter.notifyMatchFailure(op,
                                            "Element-wise op not yet supported");
@@ -706,17 +1421,26 @@ struct ElementwiseConversion : public OpConversionPattern {
           [&](auto elemWiseOp) {
             return convertBinaryOp(op, adaptor, rewriter);
           })
-        .Case([&](auto elemWiseOp) {
-          return convertBinaryOp(op, adaptor, rewriter);
+        .Case(
+            [&](auto elemWiseOp) {
+              return convertBinaryOp(op, adaptor, rewriter);
+            })
+        .Case(
+            [&](auto elemWiseOp) {
+              return convertBinaryOp(op, adaptor, rewriter);
+            })
+        .Case([&](auto elemWiseOp) {
+          return convertMinMaxOp(op, adaptor, rewriter);
         })
-        .Case([&](auto elemWiseOp) {
-          return convertBinaryOp(op, adaptor, rewriter);
+        .Case([&](auto elemWiseOp) {
+          return convertMinMaxOp(op, adaptor, rewriter);
         })
         .Case([&](auto elemWiseOp) {
           return convertUnaryOp(op, adaptor, rewriter);
         })
+        .Case([&](auto elemWiseOp) {
+          return convertCeilAndFloorOp(op, adaptor, rewriter);
+        })
         .Case([&](auto elemWiseOp) {
           return convertUnaryOp(op, adaptor, rewriter);
         })
@@ -747,11 +1471,14 @@ struct ElementwiseConversion : public OpConversionPattern {
         .Case([&](auto elemWiseOp) {
           return FmaConvertOp(op, adaptor, rewriter);
         })
+        .Case([&](auto elemWiseOp) {
+          return SelectConvertOp(op, adaptor, rewriter);
+        })
         .Case([&](auto elemWiseOp) {
           // TODO: Need add more int to fp convert.
-          auto inputType = mlir::cast(op.getInputs()[0].getType())
+          auto inputType = dyn_cast(op.getInputs()[0].getType())
                                .getElementType();
-          auto outputType = mlir::cast(op.getOutputs()[0].getType())
+          auto outputType = dyn_cast(op.getOutputs()[0].getType())
                                 .getElementType();
           if (inputType.isInteger(16) && outputType.isF32()) {
             return RoundConvertOp(op, adaptor, rewriter);
@@ -769,9 +1496,9 @@ struct ElementwiseConversion : public OpConversionPattern {
         })
         .Case([&](auto elemWiseOp) {
           // TODO: Need add more int to fp convert.
-          auto inputType = mlir::cast(op.getInputs()[0].getType())
+          auto inputType = dyn_cast(op.getInputs()[0].getType())
                                .getElementType();
-          auto outputType = mlir::cast(op.getOutputs()[0].getType())
+          auto outputType = dyn_cast(op.getOutputs()[0].getType())
                                 .getElementType();
           if (inputType.isF16() && outputType.isInteger(8)) {
             return RoundConvertOp(op, adaptor, rewriter);
@@ -791,10 +1518,17 @@ struct ElementwiseConversion : public OpConversionPattern {
                                          "integer conversion");
           }
         })
-// FIXME: Now BoolLessThenOp run fail on board. Need more op information from
-// Tx81
-#if 0
         .Case([&](auto elemWiseOp) {
+          // WORKAROUND: Tx8 bool relation ops need the element count to be a
+          // multiple of 8.
+          if (dyn_cast(op.getOperandTypes()[0]).getNumElements() %
+                  8 !=
+              0) {
+            if (failed(linalg::linalgOpToLoops(rewriter, op)))
+              return rewriter.notifyMatchFailure(
+                  op, "Element-wise op not yet supported");
+            rewriter.eraseOp(op);
+            return success();
+          }
           arith::CmpIPredicate predicate = elemWiseOp.getPredicate();
           switch (predicate) {
           case arith::CmpIPredicate::eq:
@@ -815,7 +1549,43 @@ struct ElementwiseConversion : public OpConversionPattern {
             break;
           }
         })
-#endif
+        .Case([&](auto elemWiseOp) {
+          // WORKAROUND: Tx8 bool relation ops need the element count to be a
+          // multiple of 8.
+          if (dyn_cast(op.getOperandTypes()[0]).getNumElements() %
+                  8 !=
+              0) {
+            if (failed(linalg::linalgOpToLoops(rewriter, op)))
+              return rewriter.notifyMatchFailure(
+                  op, "Element-wise op not yet supported");
+            rewriter.eraseOp(op);
+            return success();
+          }
+          arith::CmpFPredicate predicate = elemWiseOp.getPredicate();
+          switch (predicate) {
+          case arith::CmpFPredicate::OEQ:
+          case arith::CmpFPredicate::UEQ:
+            return BoolRelationVVOp(op, adaptor, rewriter);
+          case arith::CmpFPredicate::ONE:
+          case arith::CmpFPredicate::UNE:
+            return BoolRelationVVOp(op, adaptor, rewriter);
+          case arith::CmpFPredicate::OGE:
+          case arith::CmpFPredicate::UGE:
+            return BoolRelationVVOp(op, adaptor,
+                                    rewriter);
+          case arith::CmpFPredicate::OGT:
+          case arith::CmpFPredicate::UGT:
+            return BoolRelationVVOp(op, adaptor, rewriter);
+          case arith::CmpFPredicate::OLE:
+          case arith::CmpFPredicate::ULE:
+            return BoolRelationVVOp(op, adaptor, rewriter);
+          case arith::CmpFPredicate::OLT:
+          case arith::CmpFPredicate::ULT:
+            return BoolRelationVVOp(op, adaptor, rewriter);
+          default:
+            llvm_unreachable("Not yet supported");
+            break;
+          }
+        })
         .Case([&](auto elemWiseOp) {
           if (resultType.isF16())
             return RoundConvertOp(op, adaptor, rewriter);
@@ -832,8 +1602,8 @@ struct ElementwiseConversion : public OpConversionPattern {
 
     // other unsupported case now (eg: arith::extf)
     // TODO: Lower ops to tx81 if is supported
-    // Affine dialect should handled before this pass. So here lower it to
-    // scf.for
+    // The affine dialect should be handled before this pass, so lower the op
+    // to scf.for here.
     if (failed(linalg::linalgOpToLoops(rewriter, op)))
       return rewriter.notifyMatchFailure(
           op, "Element-wise op not yet supported");
@@ -850,8 +1620,8 @@ struct ReduceConversion : public OpConversionPattern {
 
   bool isReductionOpSupported(Operation *redOp) const {
     return isa(
-        redOp);
+        arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp, arith::MaxUIOp,
+        arith::XOrIOp>(redOp);
   }
 
   template
@@ -861,29 +1631,106 @@ struct ReduceConversion : public OpConversionPattern {
     auto dims = op.getDimensions();
     if (dims.size() != 1)
       return rewriter.notifyMatchFailure(op, "Only support one dim reduce.");
+    auto dim = dims[0];
 
-    auto input =
-        createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInputs()[0]);
-    auto output =
-        createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInits()[0]);
+    auto loc = op->getLoc();
     auto inputType = dyn_cast(op.getInputs()[0].getType());
     auto inputShape = inputType.getShape();
 
-    // TODO: Support any rank
-    if (inputShape.size() > 1)
-      return rewriter.notifyMatchFailure(op, "Rank > 1 unsupported yet.");
-
-    if (dim && dim >= inputShape.size())
-      return rewriter.notifyMatchFailure(op,
-                                         "Dimensions attribute > input rank !");
+    auto newShape4D = reshapeReduceShapeTo4d(inputShape, dim);
+
+    // Triton always assumes shapes are powers of 2, so channel norm may not
+    // be needed.
+    int bitWidth = inputType.getElementType().getIntOrFloatBitWidth();
+    int alignBase = bitWidth == 8 ? 128 : 64;
+
+    auto input = op.getInputs()[0];
+    auto srcPtr = createAddressFromMemref(rewriter, loc, input);
+
+    int c = newShape4D.back();
+    int alignedC = c;
+    int alignedC0 = 0;
+    // * c < 4: aligned to 4
+    // * c < align base: aligned to the next power of 2
+    // * c > align base: aligned to nc'hwc_alignbase + nhwc_alignc0
+    bool needChannelNorm =
+        !(c >= 4 && c <= alignBase && c == next_power_of_two_64(c));
+    bool reduceCDim = dim == inputShape.size() - 1;
+    // The channel count needs alignment
+    if (needChannelNorm) {
+
+      int cx = c / alignBase;
+      int c0 = c % alignBase;
+
+      alignedC0 = c0 ? next_power_of_two_64(c0) : 0;
+      if (c0 < 4 && c0 > 0) {
+        // If c0 is nonzero but less than 4, align it up to 4
+        alignedC0 = 4;
+      }
+
+      alignedC = cx * alignBase + alignedC0;
+      SmallVector alignedShape(newShape4D.begin(),
+                               newShape4D.end());
+      alignedShape.back() = alignedC;
+
+      auto alignedMemType =
+          MemRefType::get(alignedShape, inputType.getElementType());
+      auto alignedAlloc = rewriter.create(loc, alignedMemType);
+      auto alignedPtr = createAddressFromMemref(rewriter, loc, alignedAlloc);
+
+      auto channelNorm = rewriter.create(
+          loc, TypeRange({}), srcPtr, alignedPtr,
+          rewriter.getDenseI64ArrayAttr(newShape4D),
+          rewriter.getI16IntegerAttr(alignedC0),
+          rewriter.getI16IntegerAttr(bitWidth));
+      srcPtr = alignedPtr;
+    }
 
-    int64_t inputSize = inputShape.empty() ? 1 : inputShape[0];
+    auto output = adaptor.getInits()[0];
+    Value alignOutputPtr;
+    bool needDeChannelNorm = reduceCDim || needChannelNorm;
+    SmallVector outputShape4D;
+    if (needDeChannelNorm) {
+      auto outputType = dyn_cast(output.getType());
+      SmallVector alignedOutputShape4D;
+      if (outputType.getRank() == 0) {
+        outputShape4D =
+            SmallVector{1, 1, 1, outputType.getNumElements()};
+        alignedOutputShape4D = SmallVector{1, 1, 1, 4};
+      } else {
+        auto outputShape = outputType.getShape();
+        outputShape4D = reshapeReduceShapeTo4d(outputShape, dim);
+        alignedOutputShape4D =
+            reduceCDim
+                ?
SmallVector{1, 1, outputShape4D[2], 4} + : SmallVector{1, 1, outputShape4D[1], alignedC}; + } + + auto alignedMemType = + MemRefType::get(alignedOutputShape4D, inputType.getElementType()); + auto alignedAlloc = rewriter.create(loc, alignedMemType); + auto alignedPtr = createAddressFromMemref(rewriter, loc, alignedAlloc); + alignOutputPtr = alignedPtr; + } else { + alignOutputPtr = createAddressFromMemref(rewriter, loc, output); + } - SmallVector reduceShape = {1, 1, 1, inputSize}; auto format = getFormatCode(inputType); auto reduceOp = rewriter.create( - op->getLoc(), TypeRange{}, input, output, - rewriter.getUI32IntegerAttr(dim), rewriter.getI64ArrayAttr(reduceShape), + op->getLoc(), TypeRange{}, srcPtr, alignOutputPtr, + rewriter.getUI32IntegerAttr(reduceCDim ? 0 /*reduce C dim*/ + : 1 /*reduce W dim*/), + rewriter.getI64ArrayAttr(newShape4D), rewriter.getI16IntegerAttr(format)); + + if (needDeChannelNorm) { + auto outputPtr = createAddressFromMemref(rewriter, loc, output); + alignedC0 = reduceCDim ? 4 : alignedC0; + auto dechannelNorm = rewriter.create( + loc, TypeRange({}), alignOutputPtr, outputPtr, + rewriter.getDenseI64ArrayAttr(outputShape4D), + rewriter.getI16IntegerAttr(alignedC0) /*alignedC0*/, + rewriter.getI16IntegerAttr(bitWidth)); + } + rewriter.replaceOp(op, reduceOp); return success(); } @@ -902,6 +1749,16 @@ struct ReduceConversion : public OpConversionPattern { } auto redOp = reductionOps[0]; + auto inputType = dyn_cast(op.getInputs()[0].getType()); + + if (!isSupportedType(inputType)) { + if (failed(linalg::linalgOpToLoops(rewriter, op))) + return rewriter.notifyMatchFailure(op, "operation not supported yet."); + rewriter.eraseOp(op); + return success(); + } + + // TODO: Convert integer to float return llvm::TypeSwitch(redOp) .Case([&](auto redOp) { return convertToReduceOp(op, adaptor, rewriter); @@ -914,18 +1771,393 @@ struct ReduceConversion : public OpConversionPattern { arith::MinUIOp>([&](auto redOp) { return convertToReduceOp(op, adaptor, rewriter); }) - .Default([](Operation *op) { - op->dump(); - llvm_unreachable("Reduction op not yet supported"); - return failure(); + .Default([&](auto redOp) { + // For other operation, we don't have specific tx81 op, + // so we need to convert it to loops. + if (failed(linalg::linalgOpToLoops(rewriter, op))) + return rewriter.notifyMatchFailure(op, + "operation not supported yet."); + rewriter.eraseOp(op); + return success(); }); } }; +struct BarrierConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mk::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + rewriter.create(loc); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct PrintConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mk::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // printf scalar value. + if (printScalar(op)) { + if (op.getNumOperands() == 0) { + createRuntimePrintScalarCall(rewriter, op.getPrefix(), std::nullopt); + } else { + createRuntimePrintScalarCall(rewriter, op.getPrefix(), + adaptor.getOperands()[0], op.getHex(), + op.getIsSigned()[0]); + } + rewriter.eraseOp(op); + return success(); + } + + // print memref value. 
+ createPrintMemrefCall(op, rewriter); + + rewriter.eraseOp(op); + return success(); + } + +private: + static std::string getFormatSubstr(Type type, bool hex = false, + std::optional width = std::nullopt, + bool isSigned = false) { + // If the `value` is a pointer, just return %p. + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + (isSigned ? "lli" : "llu"); + else + return prefix + (isSigned ? "i" : "u"); + } + assert(false && "not supported type"); + return ""; + } + + // For printf, need to extend int32 or float64. + static Value printfPromoteValue(RewriterBase &rewriter, Value value) { + auto *context = rewriter.getContext(); + auto type = value.getType(); + auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + bool isUnsigned = type.isUnsignedInteger(); + if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { + if (isUnsigned) { + return b.zext(ui32_ty, value); + } else { + return b.sext(i32_ty, value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + return b.fpext(f64_ty, value); + } + + return value; + } + + static LLVM::LLVMFuncOp + getOrAddPrintFuncDecl(ConversionPatternRewriter &rewriter, + StringRef funcName = "__Print") { + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType = {ptr_ty(ctx)}; + auto funcType = + LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ true); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); + } + + static bool printScalar(mk::PrintOp op) { + // Simply use printf if no operand or the operand is scalar. 
+    if (op.getNumOperands() == 0)
+      return true;
+
+    assert(op.getNumOperands() == 1);
+    Type oprType = op.getOperands()[0].getType();
+    return (oprType.isIntOrIndexOrFloat() || isa(oprType));
+  }
+
+  static void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter,
+                                           StringRef prefix,
+                                           std::optional arg,
+                                           bool hex = false,
+                                           bool isSigned = false) {
+    assert(!prefix.empty() && "printf with empty string not supported");
+    auto loc = UnknownLoc::get(rewriter.getContext());
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+
+    std::string formatStr;
+    llvm::raw_string_ostream os(formatStr);
+    os << prefix;
+    if (arg.has_value())
+      os << getFormatSubstr(arg.value().getType(), hex, std::nullopt, isSigned);
+
+    llvm::SmallString<64> formatStrNewline(formatStr);
+    formatStrNewline.push_back('\n');
+    formatStrNewline.push_back('\0');
+    Value formatStrValue = LLVM::addStringToModule(
+        loc, rewriter, "printfFormat_", formatStrNewline);
+
+    SmallVector allArgs{formatStrValue};
+    if (arg.has_value())
+      allArgs.push_back(printfPromoteValue(rewriter, arg.value()));
+    b.call(getOrAddPrintFuncDecl(rewriter), allArgs);
+  }
+
+  static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
+    auto llvmI32Ty = IntegerType::get(context, 32);
+    auto llvmPtr = LLVM::LLVMPointerType::get(context);
+    return LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtr, true);
+  }
+
+  static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                             ModuleOp module,
+                                             StringRef funcName = "__Print") {
+    auto *context = module.getContext();
+    if (module.lookupSymbol(funcName))
+      return SymbolRefAttr::get(context, funcName);
+
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create(module.getLoc(), funcName,
+                    getPrintfType(context));
+    return SymbolRefAttr::get(context, funcName);
+  }
+
+  static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
+                                       StringRef name, StringRef value,
+                                       ModuleOp module) {
+    LLVM::GlobalOp global;
+    if (!(global = module.lookupSymbol(name))) {
+      OpBuilder::InsertionGuard insertGuard(builder);
+      builder.setInsertionPointToStart(module.getBody());
+      auto type = LLVM::LLVMArrayType::get(
+          IntegerType::get(builder.getContext(), 8), value.size());
+      global = builder.create(loc, type, true,
+                              LLVM::Linkage::Internal, name,
+                              builder.getStringAttr(value), 0);
+    }
+
+    Value globalPtr = builder.create(loc, global);
+    Value cst0 = builder.create(loc, builder.getI64Type(),
+                                builder.getIndexAttr(0));
+    return builder.create(
+        loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
+        globalPtr, ArrayRef({cst0, cst0}));
+  }
+
+  static void createPrintMemrefCall(mk::PrintOp op,
+                                    ConversionPatternRewriter &rewriter) {
+    auto loc = op->getLoc();
+    auto context = rewriter.getContext();
+    auto memRefType = llvm::cast(*op->operand_type_begin());
+    auto memRefShape = memRefType.getShape();
+    Type memElementType = memRefType.getElementType();
+    ModuleOp parentModule = op->getParentOfType();
+
+    auto printfRef = getOrInsertPrintf(rewriter, parentModule);
+    std::string formatSpecifierStr = getFormatSubstr(
+        memElementType, op.getHex(), std::nullopt, op.getIsSigned()[0]);
+    // Note: appending the literal " \0" via operator+= would stop at the NUL;
+    // append the space here and count c_str()'s trailing NUL in the
+    // StringRef lengths below.
+    formatSpecifierStr += ' ';
+    auto prefix = op.getPrefix();
+    std::string prefixNewline = "\n" + prefix.str();
+    Value prefixValue = getOrCreateGlobalString(
+        loc, rewriter, "frmt_prefix" + prefix.str(),
+        StringRef(prefixNewline.c_str(), prefixNewline.size() + 1),
+        parentModule);
+    Value formatSpecifierCst = getOrCreateGlobalString(
+        loc, rewriter, "frmt_spec" + formatSpecifierStr,
+        StringRef(formatSpecifierStr.c_str(), formatSpecifierStr.size() + 1),
+        parentModule);
+    Value newLineCst = getOrCreateGlobalString(
+        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
+
+    // Print the prefix first.
+    rewriter.create(loc, getPrintfType(context), printfRef,
+                    prefixValue);
+
+    SmallVector loopIvs;
+    for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
+      auto lowerBound = rewriter.create(loc, 0);
+      auto upperBound =
+          rewriter.create(loc, memRefShape[i]);
+      auto step = rewriter.create(loc, 1);
+      auto loop =
+          rewriter.create(loc, lowerBound, upperBound, step);
+      for (Operation &nested : *loop.getBody())
+        rewriter.eraseOp(&nested);
+      loopIvs.push_back(loop.getInductionVar());
+
+      rewriter.setInsertionPointToEnd(loop.getBody());
+
+      if (i != e - 1)
+        rewriter.create(loc, getPrintfType(context), printfRef,
+                        newLineCst);
+      rewriter.create(loc);
+      rewriter.setInsertionPointToStart(loop.getBody());
+    }
+
+    Value elementLoad =
+        rewriter.create(loc, op.getOperands()[0], loopIvs);
+    if (elementLoad.getType() == rewriter.getF32Type())
+      elementLoad = rewriter.create(
+          loc, rewriter.getF64Type(), elementLoad);
+    else if (elementLoad.getType() == rewriter.getI8Type())
+      elementLoad = rewriter.create(
+          loc, rewriter.getI32Type(), elementLoad);
+    rewriter.create(
+        loc, getPrintfType(context), printfRef,
+        ArrayRef({formatSpecifierCst, elementLoad}));
+  }
+};
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Legalize magic kernel operations to be convertible to Tx81 operations
 // patterns
 //===----------------------------------------------------------------------===//
 namespace {
+struct ElementwiseRewrite : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  void initialize() {
+    // Register conversions from SIOp to FPOp
+    registerSIOpMapFPOp();
+    registerSIOpMapFPOp();
+    registerSIOpMapFPOp();
+    registerSIOpMapFPOp();
+    registerSIOpMapFPOp();
+    registerSIOpMapFPOp();
+  }
+
+  template void registerSIOpMapFPOp() {
+    OperationName SIOpName(SIOp::getOperationName(), getContext());
+    assert(!SIOpMapFPOpFns.contains(SIOpName) &&
+           "SIOp already registered for conversion to FPOp");
+    SIOpMapFPOpFns[SIOpName] = [](OpBuilder &b, Location loc, ValueRange args) {
+      // Create the floating-point operation
+      Value fpVal = b.create(loc, b.getF32Type(), args.drop_back());
+      b.create(loc, fpVal);
+    };
+  }
+
+  LogicalResult convertSIOpToFPOp(linalg::GenericOp op, Operation *elemWiseOp,
+                                  PatternRewriter &rewriter) const {
+    OperationName SIOpName = elemWiseOp->getName();
+    if (!SIOpMapFPOpFns.contains(SIOpName))
+      return failure();
+
+    Location loc = op->getLoc();
+    auto inputs = op.getInputs();
+    auto output = op.getOutputs()[0];
+    auto outputTy = cast(output.getType());
+    auto outputEleTy = outputTy.getElementType();
+
+    assert(op.getOutputs().size() == 1 &&
+           "Elementwise conversion only supports a single output");
+    assert((outputEleTy.isInteger(16) || outputEleTy.isInteger(32) ||
+            outputEleTy.isInteger(64)) &&
+           "Output type must be an integer type (16, 32 or 64 bits)");
+    assert(
+        llvm::all_of(
+            inputs, [&outputTy](Value v) { return v.getType() == outputTy; }) &&
+        "All inputs must have the same type as the output");
+
+    MemRefType fpMemrefTy =
+        MemRefType::get(outputTy.getShape(), rewriter.getF32Type(),
+                        outputTy.getLayout(), outputTy.getMemorySpace());
+    auto rank = fpMemrefTy.getRank();
+    auto id = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext());
+    SmallVector idMaps =
{id, id}; + SmallVector iterators(rank, + utils::IteratorType::parallel); + SmallVector fpInputs; + for (auto input : inputs) { + auto fpInput = rewriter.create(loc, fpMemrefTy); + rewriter.create( + loc, ValueRange{input}, ValueRange{fpInput}, idMaps, iterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + // Convert integer to float + Value fpVal = + b.create(loc, b.getF32Type(), args[0]); + b.create(loc, fpVal); + }); + fpInputs.push_back(fpInput); + } + + auto fpOutput = rewriter.create(loc, fpMemrefTy); + rewriter.create( + loc, fpInputs, ValueRange{fpOutput}, op.getIndexingMapsArray(), + op.getIteratorTypesArray(), SIOpMapFPOpFns.at(SIOpName)); + rewriter.replaceOpWithNewOp( + op, ValueRange{fpOutput}, ValueRange{output}, idMaps, iterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + // Convert float to integer + Value intVal = b.create(loc, outputEleTy, args[0]); + b.create(loc, intVal); + }); + return success(); + } + + LogicalResult matchAndRewrite(linalg::GenericOp op, + PatternRewriter &rewriter) const override { + auto regionOps = getRegionOps(op); + if (regionOps.size() != 1) + return failure(); + + return convertSIOpToFPOp(op, regionOps[0], rewriter); + } + +private: + // Map from SIOp to FPOp conversion functions + llvm::DenseMap> + SIOpMapFPOpFns; +}; } // namespace void mlir::triton::populateMKToTx81CanonicalizationPatterns( - RewritePatternSet &patterns) {} + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} void mlir::triton::populateMKToTx81ConversionPatterns( RewritePatternSet &patterns) { @@ -941,9 +2173,16 @@ void mlir::triton::populateMKToTx81ConversionPatterns( // clang-format off patterns.add( + MKSigmoidToTx81SigmoidOpConversion, + ArgMinConversion, + ArgMaxConversion, + GatherConvertPattern, + ElementwiseConversion, + BarrierConversion, + PrintConversion>( patterns.getContext()); // clang-format on } diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp index 371c2faeb..1151a4211 100644 --- a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp @@ -8,6 +8,7 @@ #include "magic-kernel/Dialect/IR/MagicKernelDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" @@ -39,8 +40,9 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } bool isOperandMemorySpaceSPM(Value operand) { @@ -50,6 +52,8 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base { do { if (isa(op)) return true; + else if (isa(op)) + return false; else if (auto forOp = dyn_cast(op)) { // Here we assume that yieldResults (inner loop region) and // loopResults (outer loop region) correspond one-to-one to obtain the @@ -59,12 +63,13 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base { auto yieldResults = forOp.getYieldedValues(); mlir::ResultRange loopResults = forOp.getLoopResults().value(); assert(yieldResults.size() == loopResults.size()); - auto idx = std::distance( loopResults.begin(), std::find(loopResults.begin(), loopResults.end(), operand)); operand = yieldResults[idx]; - + if (operand.getDefiningOp() == nullptr) { + operand = 
forOp.getInitArgs()[idx];
+          }
         } else {
           operand = op->getOperand(0);
         }
@@ -77,6 +82,13 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base {
   void runOnOperation() override {
     auto moduleOp = getOperation();
 
+    RewritePatternSet canonicalizePatterns(&getContext());
+    triton::populateMKToTx81CanonicalizationPatterns(canonicalizePatterns);
+    if (failed(
+            applyPatternsGreedily(moduleOp, std::move(canonicalizePatterns)))) {
+      signalPassFailure();
+    }
+
     // Use to memory::CopyOp to tx dialect op
     moduleOp->walk([&](Operation *op) {
       if (isa(op)) {
@@ -92,6 +104,24 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base {
       }
     });
 
+    // Transpose operand b from (K, N) to (N, K)
+    moduleOp->walk([&](linalg::MatmulOp op) {
+      OpBuilder builder(op);
+
+      auto b = op->getOperand(1);
+
+      auto memType = cast(b.getType());
+      auto oldShape = memType.getShape();
+      llvm::SmallVector newShape({oldShape[1], oldShape[0]});
+      auto transposeInit = builder.create(
+          op->getLoc(), MemRefType::get(newShape, memType.getElementType()));
+
+      SmallVector perm({1, 0});
+      auto linalgTranspose = builder.create(
+          op->getLoc(), b, transposeInit, perm);
+      op->setOperand(1, transposeInit);
+    });
+
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
 
@@ -100,12 +130,19 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base {
                         bufferization::BufferizationDialect,
                         mk::MagicKernelDialect>();
 
-    target.addLegalDialect();
+    target.addLegalDialect<
+        func::FuncDialect, arith::ArithDialect, math::MathDialect,
+        affine::AffineDialect, scf::SCFDialect, memref::MemRefDialect,
+        cf::ControlFlowDialect, tx::Tx81Dialect, LLVM::LLVMDialect>();
+
+    // FIXME: Support copies of rank > 4. SPM-to-SPM copies are already
+    // supported.
+    target.addDynamicallyLegalOp([&](memref::CopyOp op) {
+      auto shape = op.getSource().getType().getShape();
+      return shape.size() > 4 &&
+             !(op->getAttrOfType("srcSpm").getInt() &&
+               op->getAttrOfType("dstSpm").getInt());
+    });
 
-    target.addIllegalOp();
     target.addLegalOp();
 
     triton::populateMKToTx81ConversionPatterns(patterns);
diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
index b5e1165a7..36d336c28 100644
--- a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
+++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
@@ -110,6 +110,9 @@ struct MakeTensorPtrConverter
       auto strideIntAttr = getIntAttr(stride);
       if (size == 1 && strideIntAttr && strideIntAttr.value() == 0) {
         strides.push_back(b.getIndexAttr(accumulate));
+      } else if (auto v = llvm::dyn_cast_if_present(stride)) {
+        OpFoldResult result = getAsOpFoldResult(v);
+        strides.push_back(result);
       } else {
         strides.push_back(stride);
       }
@@ -179,8 +182,9 @@ struct MakeTensorPtrConverter
 
         /* result shape */
         SmallVector{
-            // Row stays the same
-            resultShape[0],
+            // Row stays the same, but MLIR no longer allows a static size
+            // here. Put dynamic.
+            ShapedType::kDynamic,
 
             // Column is dynamic, in most cases, this
             // should be the same as the original column.
@@ -288,9 +292,9 @@ struct MakeTensorPtrConverter
           // around.
           ShapedType::kDynamic,
 
-          // Col stays the same.
-          resultShape[1],
-      });
+          // Col stays the same (resultShape[1]), but MLIR no longer allows a
+          // static size here. So we put dynamic instead.
+          ShapedType::kDynamic});
 
     Value rowSize = rewriter.create(
         loc, rewriter.getIndexAttr(op.getSizes()[0]));
diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
index 7decf7148..54d0e34e0 100644
--- a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
@@ -164,9 +164,12 @@ struct ScalarAddptrConverter
   }
 };
 
-static std::optional>
-buildCastAndOffsetOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                      Location loc) {
+static SmallVector buildCastAndOffsetOps(OpBuilder &builder,
+                                         TypeRange resultTypes,
+                                         ValueRange inputs,
+                                         Location loc) {
+  assert(inputs.size() == 1 && "Unexpected number of inputs when converting");
+  Value input = inputs[0];
   assert(resultTypes.size() == 2 && isa(resultTypes[0]) &&
          isa(resultTypes[1]) &&
          "Unexpected result types when converting addptr");
@@ -334,9 +337,7 @@ class StructuredToMemrefPass
 
     // Compute the target materialization, given a value with the pointer type,
     // convert that value to a pair of {memref, index} type.
-#if 0 // FIXME: Incompatible MILR interface
     converter.addTargetMaterialization(buildCastAndOffsetOps);
-#endif
 
     patterns.add(converter, context);
 
diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
index 637732fc7..95f27794c 100644
--- a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
+++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
@@ -55,6 +55,7 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
   if (assertToCf) {
     patterns.add(patterns.getContext());
   }
+  patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
@@ -72,8 +73,11 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
-  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
 
   populateExternElementwiseOpToMLIROps(patterns);
 
diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
index bae1bd6ba..d07798eb5 100644
--- a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
@@ -150,8 +150,9 @@ class TritonArithToLinalgPass
     });
 
     if (pidsToFuncArgs) {
-      target
-          .addIllegalOp();
+      // Need to use the tx interface to get the pid.
+ target.addIllegalOp< + /* triton::GetProgramIdOp, */ triton::GetNumProgramsOp>(); } if (addptrToLinalg) { diff --git a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp index ea6d32593..5c939b7d2 100644 --- a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -67,6 +67,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns( patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); populateExternElementwiseOpToMLIROps(patterns); diff --git a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index 71d694290..4c8ce0c35 100644 --- a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -144,15 +144,14 @@ class TritonToStructuredPass // Compute the target materialization, given a value with the pointer type, // convert that value to a tuple type. -#if 0 // FIXME: Incompatible MILR interface - converter.addTargetMaterialization( - [](OpBuilder &builder, TypeRange resultTypes, Value input, - Location loc) -> std::optional> { - return builder - .create(loc, resultTypes, input) - ->getResults(); - }); -#endif + converter.addTargetMaterialization([](OpBuilder &builder, + TypeRange resultTypes, + ValueRange inputs, + Location loc) -> SmallVector { + return builder + .create(loc, resultTypes, inputs.front()) + ->getResults(); + }); scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); @@ -209,16 +208,14 @@ class TritonToStructuredPass // The return values for this op will be used as the init-args for scf.for. // At the end of pointer analysis, we will use the PtrState to create the // correct offsets, strides, and remove these ops. 
-#if 0 // FIXME: Incompatible MILR interface
     converter.addTargetMaterialization([](OpBuilder &builder,
-                                          TypeRange resultTypes, Value input,
-                                          Location loc) {
+                                          TypeRange resultTypes,
+                                          ValueRange inputs, Location loc) {
       auto placeholder = builder.create(
-          loc, input.getDefiningOp()->getOperand(0));
+          loc, inputs.front().getDefiningOp()->getOperand(0));
       assert(llvm::equal(placeholder.getResultTypes(), resultTypes));
       return placeholder.getResults();
     });
-#endif
 
     RewritePatternSet patterns(&getContext());
     scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt
index 9a20a8c10..893fdbd83 100644
--- a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt
+++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt
@@ -4,8 +4,10 @@ add_triton_library(Tx81MemrefToLLVM
 
   DEPENDS
   Tx81MemrefToLLVMConversionPassIncGen
+  TritonUtils
 
   LINK_LIBS PUBLIC
+  TritonUtils
   MLIRDialectUtils
   MLIRIR
   MLIRPass
diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp
index a857fbb3e..5deb37124 100644
--- a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp
+++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h"
+#include "utils/utils.h"
 #include
 #include
@@ -24,8 +25,6 @@ using namespace mlir;
 uint64_t spmPointer = 0x10000;
 
 namespace {
-// Used for kcore load/store data from/to spm
-const int64_t spmMappingOffset = 0x30400000;
 
 //===----------------------------------------------------------------------===//
 // Tx81 Custom MemRef Op Conversion Patterns
@@ -110,21 +109,29 @@ struct MemrefLoadOrStoreOpLowering : public ConvertOpToLLVMPattern {
 
     Value adjustedPtr = dataPtr;
     if (isSpm) {
-      auto spmMemoryOffset = rewriter.create(
-          op.getLoc(), rewriter.getI64Type(),
-          rewriter.getI64IntegerAttr(spmMappingOffset));
-      auto spmMemoryAddr = rewriter.create(
-          op.getLoc(), rewriter.getI64Type(),
-          SmallVector({ptrValue, spmMemoryOffset}));
-
-      auto ptrTy = LLVM::LLVMPointerType::get(
+      // Get the module for function declarations
+      auto module = op->template getParentOfType();
+      // Types for function declaration
+      SmallVector argTypes = {
+          rewriter.getI64Type() // offset
+      };
+
+      auto i8PtrTy = LLVM::LLVMPointerType::get(
           rewriter.getContext(),
           *ConvertToLLVMPattern::getTypeConverter()->getMemRefAddressSpace(
               type));
-      auto spmMemoryAddrPtr =
-          rewriter.create(op.getLoc(), ptrTy, spmMemoryAddr);
-
-      adjustedPtr = spmMemoryAddrPtr;
+      // Declare the function
+      Value funcPtr = triton::utils::declareTx81Function(
+          module, rewriter, op.getLoc(), "get_spm_memory_mapping_wrapper",
+          i8PtrTy, argTypes);
+
+      // Call get_spm_memory_mapping_wrapper to map the SPM offset
+      auto spmMemoryAddrPtr = rewriter.create(
+          op.getLoc(), TypeRange{i8PtrTy},
+          "get_spm_memory_mapping_wrapper", // funcPtr,
+          ValueRange{ptrValue});
+
+      adjustedPtr = spmMemoryAddrPtr.getResult();
     }
 
     // Wether need memoryspace cast
@@ -311,11 +318,10 @@ class ConvertExtractAlignedPointerAsIndex
                   ConversionPatternRewriter &rewriter) const override {
     BaseMemRefType sourceTy = extractOp.getSource().getType();
 
-    // FIXME: We want allocated ptr instead of aligned ptr.
Value alignedPtr; if (sourceTy.hasRank()) { MemRefDescriptor desc(adaptor.getSource()); - alignedPtr = desc.allocatedPtr(rewriter, extractOp->getLoc()); + alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc()); } else { auto elementPtrTy = LLVM::LLVMPointerType::get( rewriter.getContext(), sourceTy.getMemorySpaceAsInt()); @@ -323,8 +329,9 @@ class ConvertExtractAlignedPointerAsIndex UnrankedMemRefDescriptor desc(adaptor.getSource()); Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc()); - alignedPtr = UnrankedMemRefDescriptor::allocatedPtr( - rewriter, extractOp->getLoc(), descPtr, elementPtrTy); + alignedPtr = UnrankedMemRefDescriptor::alignedPtr( + rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr, + elementPtrTy); } rewriter.replaceOpWithNewOp( diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp index 9458fdc95..cacf20687 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp @@ -34,6 +34,10 @@ class KernelArgBufferPass : public mlir::triton::KernelArgBufferPassBase { using KernelArgBufferPassBase::KernelArgBufferPassBase; +private: + // Check if the function is a kernel function + bool isKernelFunction(LLVM::LLVMFuncOp func); + public: StringRef getArgument() const final { return "kernel-arg-buffer"; } StringRef getDescription() const final { @@ -53,6 +57,11 @@ class KernelArgBufferPass Type argType, int64_t ¤tOffset); }; +bool KernelArgBufferPass::isKernelFunction(LLVM::LLVMFuncOp func) { + return !(func.getSymName().contains("__Print") || + func.getSymName().contains("get_spm_memory_mapping_wrapper")); +} + Value KernelArgBufferPass::insertKernelArgLoad(OpBuilder &builder, Location loc, Value argsBuffer, Type argType, int64_t ¤tOffset) { @@ -82,6 +91,8 @@ void KernelArgBufferPass::runOnOperation() { // Collect functions to process SmallVector kernelFuncs; for (auto func : module.getOps()) { + if (!isKernelFunction(func)) + continue; kernelFuncs.push_back(func); } // NOTE: We move this pass before tx81-to-llvm pass. 
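
Note on the DMA runtime interface targeted below: the rewritten RdmaWdmaOpConversion in Tx81ToLLVM.cpp lowers tx81.rdma/tx81.wdma to a single rank-agnostic CRT call. The following is a minimal sketch of the C declarations this lowering appears to assume, reconstructed only from the pattern's argTypes list ({i8*, i8*, i32*, i32*, i32*, i32*, i32, i32}); the parameter names are illustrative, and the authoritative prototypes live in the Tx81 CRT headers, which are not part of this diff:

    /* Hypothetical reconstruction -- not taken from the CRT headers. */
    void *__Rdma(void *src, void *target,
                 int32_t *src_shape, int32_t *src_strides,
                 int32_t *dst_shape, int32_t *dst_strides,
                 int32_t elem_bytes, int32_t fmt);
    void *__Wdma(void *src, void *target,
                 int32_t *src_shape, int32_t *src_strides,
                 int32_t *dst_shape, int32_t *dst_strides,
                 int32_t elem_bytes, int32_t fmt);

The shape/stride arrays are materialized on the stack by createInt32ValueArray (an llvm.alloca plus per-element stores), which is what replaces the old fixed shape_n/h/w/c scalar arguments.
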
diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp index 68c7e75ca..2ecf0b4d4 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp @@ -38,6 +38,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "utils/utils.h" #include "llvm/ADT/TypeSwitch.h" #define DEBUG_TYPE "tx81-to-llvm" @@ -52,6 +53,8 @@ namespace { // Helper Functions //===----------------------------------------------------------------------===// // Crt func name +const char rdmaFuncName[] = "__Rdma"; +const char wdmaFuncName[] = "__Wdma"; const char addVVFuncName[] = "__AddVV"; const char subVVFuncName[] = "__SubVV"; const char mulVVFuncName[] = "__MulVV"; @@ -69,6 +72,8 @@ const char addVSFuncName[] = "__AddVS"; const char subVSFuncName[] = "__SubVS"; const char mulVSFuncName[] = "__MulVS"; const char divVSFuncName[] = "__DivVS"; +const char argMinFuncName[] = "__ArgMin"; +const char argMaxFuncName[] = "__ArgMax"; const char reduceSumFuncName[] = "__ReduceSum"; const char reduceMaxFuncName[] = "__ReduceMax"; const char reduceMinFuncName[] = "__ReduceMin"; @@ -88,7 +93,25 @@ const char boolUnEqualVVFuncName[] = "__BoolUnEqualVV"; const char boolGreaterEqualVVFuncName[] = "__BoolGreaterEqualVV"; const char boolGreaterVVFuncName[] = "__BoolGreaterVV"; const char boolLessEqualVVFuncName[] = "__BoolLessEqualVV"; -const char boolLessVVFuncName[] = "__BoolLessThenVV"; +const char boolLessThenVVFuncName[] = "__BoolLessThenVV"; +const char equalVVFuncName[] = "__EqualVV"; +const char unEqualVVFuncName[] = "__UnEqualVV"; +const char greaterEqualVVFuncName[] = "__GreaterEqualVV"; +const char greaterVVFuncName[] = "__GreaterVV"; +const char lessEqualVVFuncName[] = "__LessEqualVV"; +const char lessThenVVFuncName[] = "__LessThenVV"; +const char boolEqualVSFuncName[] = "__BoolEqualVS"; +const char boolUnEqualVSFuncName[] = "__BoolUnEqualVS"; +const char boolGreaterEqualVSFuncName[] = "__BoolGreaterEqualVS"; +const char boolGreaterVSFuncName[] = "__BoolGreaterVS"; +const char boolLessEqualVSFuncName[] = "__BoolLessEqualVS"; +const char boolLessThenVSFuncName[] = "__BoolLessThenVS"; +const char equalVSFuncName[] = "__EqualVS"; +const char unEqualVSFuncName[] = "__UnEqualVS"; +const char greaterEqualVSFuncName[] = "__GreaterEqualVS"; +const char greaterVSFuncName[] = "__GreaterVS"; +const char lessEqualVSFuncName[] = "__LessEqualVS"; +const char lessThenVSFuncName[] = "__LessThenVS"; const char fp32ToFp16FuncName[] = "__FP32_FP16"; const char fp32ToBf16FuncName[] = "__FP32_BF16"; const char fp32ToTF32FuncName[] = "__FP32_TF32"; @@ -97,34 +120,9 @@ const char orVVFuncName[] = "__OrVV"; const char xorVVFuncName[] = "__XorVV"; const char MaxVVFuncName[] = "__MaxVV"; const char MinVVFuncName[] = "__MinVV"; - -// Function to declare Tx81 runtime function -Value declareTx81Function(ModuleOp module, OpBuilder &builder, Location loc, - StringRef name, Type resultType, - ArrayRef argumentTypes) { - // Check if the function already exists - Operation *funcOp = module.lookupSymbol(name); - if (funcOp) - return builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), name); - - // Create function type - Type funcType = LLVM::LLVMFunctionType::get(resultType, argumentTypes, - /*isVarArg=*/false); - - // Create a function declaration - auto ip = 
builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(module.getBody());
-
-  builder.create(loc, name, funcType,
-                 LLVM::Linkage::External);
-
-  builder.restoreInsertionPoint(ip);
-
-  // Return function pointer
-  return builder.create(
-      loc, LLVM::LLVMPointerType::get(builder.getContext()), name);
-}
+const char transposeFuncName[] = "__Transpose";
+const char nchw2nhwcFuncName[] = "__Nchw2nhwc";
+const char nhwc2nchwFuncName[] = "__Nhwc2nchw";
 
 static Value adjustElemCountType(ConversionPatternRewriter &rewriter,
                                  Location loc, Value elemCount) {
@@ -147,6 +145,60 @@ static Value castIndexToInt32(ConversionPatternRewriter &rewriter,
                               Location loc,
                               indexOp);
 }
 
+static Value createInt32ValueArray(ConversionPatternRewriter &rewriter,
+                                   Location loc, SmallVector array) {
+  auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+  auto i32Ty = rewriter.getI32Type();
+  auto i64Ty = rewriter.getI64Type();
+
+  // Allocate memory for the array
+  Value rank = rewriter.create(
+      loc, i64Ty, rewriter.getI64IntegerAttr(array.size()));
+  auto allocaOp = rewriter.create(loc, i32PtrTy, i32Ty, rank);
+
+  // Store each dimension in the array
+  for (size_t i = 0; i < array.size(); i++) {
+    // Create the index (an i64 constant needs an i64 attribute)
+    Value idx = rewriter.create(
+        loc, i64Ty, rewriter.getI64IntegerAttr(i));
+
+    // Create GEP to get a pointer to the array element
+    Value elemPtr = rewriter.create(loc, i32PtrTy, i32Ty, allocaOp,
+                                    ArrayRef{idx});
+
+    // Store the value
+    rewriter.create(loc, array[i], elemPtr);
+  }
+  return allocaOp;
+}
+
+static Value
+indexValueArrayToInt32ValueArray(ConversionPatternRewriter &rewriter,
+                                 Location loc, ValueRange array) {
+
+  SmallVector arrayValues;
+  for (size_t i = 0; i < array.size(); i++) {
+    // Create the dimension value
+    arrayValues.push_back(castIndexToInt32(rewriter, loc, array[i]));
+  }
+
+  return createInt32ValueArray(rewriter, loc, arrayValues);
+}
+
+static Value int32ArrayToInt32ValueArray(ConversionPatternRewriter &rewriter,
+                                         Location loc,
+                                         ArrayRef array) {
+
+  SmallVector arrayValues;
+  auto i32Ty = rewriter.getI32Type();
+  for (size_t i = 0; i < array.size(); i++) {
+    // Create the dimension value
+    arrayValues.push_back(rewriter.create(
+        loc, i32Ty, rewriter.getI32IntegerAttr(array[i])));
+  }
+  return createInt32ValueArray(rewriter, loc, arrayValues);
+}
+
 //===----------------------------------------------------------------------===//
 // Arith Operation Conversion Patterns
 //===----------------------------------------------------------------------===//
@@ -165,15 +217,15 @@ struct ConstantOpConversion : public OpConversionPattern {
     auto resultType = getTypeConverter()->convertType(op.getResult().getType());
 
     // Handle different attribute types
-    if (auto intAttr = mlir::dyn_cast(constAttr)) {
+    if (auto intAttr = dyn_cast(constAttr)) {
       // Convert integer attribute
       rewriter.replaceOpWithNewOp(op, resultType, intAttr);
       return success();
-    } else if (auto floatAttr = mlir::dyn_cast(constAttr)) {
+    } else if (auto floatAttr = dyn_cast(constAttr)) {
       // Convert float attribute
       rewriter.replaceOpWithNewOp(op, resultType, floatAttr);
      return success();
-    } else if (auto boolAttr = mlir::dyn_cast(constAttr)) {
+    } else if (auto boolAttr = dyn_cast(constAttr)) {
       // Convert bool attribute to i1
       rewriter.replaceOpWithNewOp(
           op, resultType,
@@ -197,25 +249,23 @@ struct IndexCastOpConversion : public OpConversionPattern {
     auto dstType = getTypeConverter()->convertType(op.getResult().getType());
 
     // Convert from index to specific integer type
-    if (mlir::isa(srcType) &&
-        mlir::isa(dstType)) {
+    if (isa(srcType) && isa(dstType)) {
       rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn());
       return success();
     }
 
     // Convert from specific integer type to index
-    if (mlir::isa(srcType) &&
-        mlir::isa(dstType)) {
+    if (isa(srcType) && isa(dstType)) {
       rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn());
       return success();
     }
 
     // Handle integer to integer casts
-    if (mlir::isa(srcType) && mlir::isa(dstType)) {
-      unsigned srcWidth = mlir::cast(srcType).getWidth();
-      unsigned dstWidth = mlir::cast(dstType).getWidth();
+    if (isa(srcType) && isa(dstType)) {
+      unsigned srcWidth = cast(srcType).getWidth();
+      unsigned dstWidth = cast(dstType).getWidth();
 
       if (srcWidth < dstWidth) {
         // Sign extend if source is signed, zero extend otherwise
@@ -265,72 +315,104 @@ struct MulIOpConversion : public OpConversionPattern {
 // Tx81 Operation Conversion Patterns
 //===----------------------------------------------------------------------===//
 
-// Convert tx81.rdma to LLVM call to crt __Rdma function
-struct RdmaOpConversion : public OpConversionPattern {
-  using OpConversionPattern::OpConversionPattern;
+struct BarrierConversion : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(tx::RdmaOp op, OpAdaptor adaptor,
+  matchAndRewrite(tx::BarrierOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Get the module for function declarations
     auto module = op->getParentOfType();
 
     // Declare the __Rdma runtime function if not already declared
     /*
-      void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int
-      shape_w, int shape_c, int stride_n, int stride_h, int stride_w, int
-      *strides, uint32_t fmt)
+      void __Barrier()
     */
+    auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    // Declare the function
+    Value funcPtr = triton::utils::declareTx81Function(
+        module, rewriter, op.getLoc(), "__Barrier", i8PtrTy, {});
+
+    // Create the call to __Barrier
+    auto call = rewriter.create(op.getLoc(), TypeRange{i8PtrTy},
+                                "__Barrier", // funcPtr,
+                                ValueRange{});
+
+    // Replace the op with the call
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
+template
+struct RdmaWdmaOpConversion : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+  using OpAdaptor = typename Tx81Op::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(Tx81Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctx = rewriter.getContext();
+    // Get the module for function declarations
+    auto module = op->template getParentOfType();
+
+    // Declare the __Rdma/__Wdma runtime function if not already declared
+    auto i8PtrTy = LLVM::LLVMPointerType::get(ctx);
     auto i32Ty = rewriter.getI32Type();
-    auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+    auto i32PtrTy = LLVM::LLVMPointerType::get(ctx);
 
     // Types for function declaration
     SmallVector argTypes = {
-        i8PtrTy, // src
        ...
-        i32Ty    // fmt
+        i8PtrTy,  // src
+        i8PtrTy,  // target
+        i32PtrTy, // src_shape array
+        i32PtrTy, // src_strides array
+        i32PtrTy, // dst_shape array
+        i32PtrTy, // dst_strides array
+        i32Ty,    // elemBytes
+        i32Ty     // fmt
     };
 
     // Declare the function
-    Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Rdma",
-                                        i8PtrTy, argTypes);
+    Value funcPtr = triton::utils::declareTx81Function(
+        module, rewriter, loc, funcPrefix, i8PtrTy, argTypes);
 
     // Get the operands
     Value src = adaptor.getSource();
-    src = rewriter.create(op.getLoc(), i8PtrTy, src);
+    src = rewriter.create(loc, i8PtrTy, src);
 
-    // Get the operands
     Value target = adaptor.getTarget();
-    target = rewriter.create(op.getLoc(), i8PtrTy, target);
+    target = rewriter.create(loc, i8PtrTy, target);
 
-    ValueRange shape = adaptor.getShape();
-    Value shape0 = castIndexToInt32(rewriter, op->getLoc(), shape[0]);
-    Value shape1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]);
-    Value shape2 = castIndexToInt32(rewriter, op->getLoc(), shape[2]);
-    Value shape3 = castIndexToInt32(rewriter, op->getLoc(), shape[3]);
 
-    ValueRange strides = adaptor.getStrides();
-    Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[0]);
-    Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]);
-    Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[2]);
+    // Create arrays for shapes and strides
+    Value srcShapeArray =
+        indexValueArrayToInt32ValueArray(rewriter, loc, adaptor.getSrcShape());
+    Value srcStridesArray = indexValueArrayToInt32ValueArray(
+        rewriter, loc, adaptor.getSrcStrides());
+    Value dstShapeArray =
+        indexValueArrayToInt32ValueArray(rewriter, loc, adaptor.getDstShape());
+    Value dstStridesArray = indexValueArrayToInt32ValueArray(
+        rewriter, loc, adaptor.getDstStrides());
+
+    // Handle the elem_bytes attribute
+    Value elemBytes = rewriter.create(
+        loc, i32Ty, rewriter.getI32IntegerAttr(op.getElemBytes()));
 
     // Handle format attribute
     Value fmt = rewriter.create(
-        op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt()));
+        loc, i32Ty, rewriter.getI32IntegerAttr(op.getFmt()));
 
     // Create the call to __Rdma
     auto call = rewriter.create(
-        op.getLoc(), TypeRange{i8PtrTy}, "__Rdma", // funcPtr,
-        ValueRange{src, target, shape0, shape1, shape2, shape3, stride0,
-                   stride1, stride2, fmt});
+        loc, TypeRange{i8PtrTy}, funcPrefix,
+        ValueRange{src, target, srcShapeArray, srcStridesArray, dstShapeArray,
+                   dstStridesArray, elemBytes, fmt});
 
     // Replace the op with the result of the call
     rewriter.replaceOp(op, call.getResult());
@@ -339,43 +421,35 @@ struct RdmaOpConversion : public OpConversionPattern {
   }
 };
 
-// Convert tx81.wdma to LLVM call to __Wdma function
-struct WdmaOpConversion : public OpConversionPattern {
-  using OpConversionPattern::OpConversionPattern;
+// Convert tx81.mask_move to LLVM call to __MaskMove function
+struct MaskMoveOpConversion : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(tx::WdmaOp op, OpAdaptor adaptor,
+  matchAndRewrite(tx::MaskMoveOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Get the module for function declarations
     auto module = op->getParentOfType();
 
-    // Declare the __Wdma runtime function if not already declared
-    /*
-      void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int
-      shape_w, int shape_c, int stride_n, int stride_h, int stride_w, int
-      *strides, uint32_t fmt)
-    */
+    // Declare the __MaskMove runtime function if not already declared
+    // Signature: void* __MaskMove(void* source, void* target, uint32_t
+    // elem_count, int32_t* masks, uint32_t fmt);
     auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
     auto i32Ty = rewriter.getI32Type();
     auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
 
     // Types for function declaration
     SmallVector argTypes = {
-        i8PtrTy, // src
-        i8PtrTy, // target
-// Convert tx81.wdma to LLVM call to __Wdma function -struct WdmaOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +// Convert tx81.mask_move to LLVM call to __MaskMove function +struct MaskMoveOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tx::WdmaOp op, OpAdaptor adaptor, + matchAndRewrite(tx::MaskMoveOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Get the module for function declarations auto module = op->getParentOfType(); - // Declare the __Wdma runtime function if not already declared - /* - void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int - shape_w, int shape_c, int stride_n, int stride_h, int stride_w, int - *strides, uint32_t fmt) - */ + // Declare the __MaskMove runtime function if not already declared + // Signature: void* __MaskMove(void* source, void* target, uint32_t + // elem_count, int32_t* masks, uint32_t fmt); auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); auto i32Ty = rewriter.getI32Type(); auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); // Types for function declaration SmallVector argTypes = { - i8PtrTy, // src - i8PtrTy, // target - i32Ty, // shape_n - i32Ty, // shape_h - i32Ty, // shape_w - i32Ty, // shape_c - i32Ty, // stride_n - i32Ty, // stride_h - i32Ty, // stride_w - i32Ty // fmt + i8PtrTy, // source + i8PtrTy, // target + i32Ty, // elem_count + i32PtrTy, // masks + i32Ty // fmt }; // Declare the function - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Wdma", - i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__MaskMove", i8PtrTy, argTypes); // Get the operands Value src = adaptor.getSource(); @@ -383,32 +457,27 @@ struct WdmaOpConversion : public OpConversionPattern { // Need to bitcast src to i8* src = rewriter.create(op.getLoc(), i8PtrTy, src); - // Get the operands Value target = adaptor.getTarget(); - // Need to bitcast src to i8* + // Need to bitcast target to i8* target = rewriter.create(op.getLoc(), i8PtrTy, target); + Value elemCount = adaptor.getElemCount(); + elemCount = castIndexToInt32(rewriter, op->getLoc(), elemCount); - ValueRange shape = adaptor.getShape(); - Value shape0 = castIndexToInt32(rewriter, op->getLoc(), shape[0]); - Value shape1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]); - Value shape2 = castIndexToInt32(rewriter, op->getLoc(), shape[2]); - Value shape3 = castIndexToInt32(rewriter, op->getLoc(), shape[3]); + // Handle mask arrays + Value mask = adaptor.getMask(); - ValueRange strides = adaptor.getStrides(); - Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[0]); - Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]); - Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[2]); + // Need to bitcast mask to i8* + mask = rewriter.create(op.getLoc(), i8PtrTy, mask); // Handle format attribute Value fmt = rewriter.create( op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); - // Create the call to __Wdma + // Create the call to __MaskMove auto call = rewriter.create( - op.getLoc(), i8PtrTy, "__Wdma", // funcPtr, - ArrayRef{src, target, shape0, shape1, shape2, shape3, stride0, - stride1, stride2, fmt}); + op.getLoc(), i8PtrTy, "__MaskMove", // funcPtr, + ArrayRef{src, target, elemCount, mask, fmt}); // Replace the op with the result of the call rewriter.replaceOp(op, call.getResult()); @@ -417,64 +486,223 @@ struct WdmaOpConversion : public OpConversionPattern { } }; -// Convert tx81.mask_move to LLVM call to __MaskMove function -struct MaskMoveOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +template +struct TransformOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; LogicalResult - matchAndRewrite(tx::MaskMoveOp op, OpAdaptor adaptor, + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature (e.g. for __Transpose): + // __Transpose(uint64_t *src, uint64_t *dst, int32_t *src_shape, int32_t + // *dst_shape, uint16_t fmt) + + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32PtrTy, i32PtrTy, + i16Ty}; + + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy,
argTypes); + + // Convert operands + Value src = adaptor.getSource(); + // Need to bitcast src to i8* + src = rewriter.create(op.getLoc(), i8PtrTy, src); + Value dst = adaptor.getTarget(); + // Need to bitcast dst to i8* + dst = rewriter.create(op.getLoc(), i8PtrTy, dst); + + // Convert shape attribute to Value + ArrayRef srcShape = adaptor.getSrcShape(); + ArrayRef dstShape = adaptor.getDstShape(); + + // Get shape llvm array + auto srcArray = + int32ArrayToInt32ValueArray(rewriter, op.getLoc(), srcShape); + auto dstArray = + int32ArrayToInt32ValueArray(rewriter, op.getLoc(), dstShape); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{src, dst, srcArray, dstArray, fmt}); + + // Erase the old op + rewriter.eraseOp(op); + + return success(); + } +}; + +struct GatherScatterOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::GatherScatter op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + // Get the module for function declarations auto module = op->getParentOfType(); - // Declare the __MaskMove runtime function if not already declared - // Signature: void* __MaskMove(void* source, void* target, uint32_t - // elem_count, int32_t* masks, uint32_t fmt); + // Declare the __GatherScatter runtime function if not already declared + /* + void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t bytes, + uint32_t src_strideN, uint32_t src_strideH, + uint32_t src_strideW, uint32_t src_iterN, + uint32_t src_iterH, uint32_t src_iterW, + uint32_t dst_strideN, uint32_t dst_strideH, + uint32_t dst_strideW, uint32_t dst_iterN, + uint32_t dst_iterH, uint32_t dst_iterW) + */ auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); auto i32Ty = rewriter.getI32Type(); auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); // Types for function declaration SmallVector argTypes = { - i8PtrTy, // source - i8PtrTy, // target - i32Ty, // elem_count - i32PtrTy, // masks - i32Ty // fmt + i8PtrTy, // src + i8PtrTy, // dst + i32Ty, // bytes + i32Ty, // src_StrideN + i32Ty, // src_StrideH + i32Ty, // src_StrideW + i32Ty, // src_IterN + i32Ty, // src_IterH + i32Ty, // src_IterW + i32Ty, // dst_StrideN + i32Ty, // dst_StrideH + i32Ty, // dst_StrideW + i32Ty, // dst_IterN + i32Ty, // dst_IterH + i32Ty // dst_IterW }; // Declare the function - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - "__MaskMove", i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, loc, "__GatherScatter", i8PtrTy, argTypes); // Get the operands Value src = adaptor.getSource(); + src = rewriter.create(loc, i8PtrTy, src); - // Need to bitcast src to i8* - src = rewriter.create(op.getLoc(), i8PtrTy, src); + Value dst = adaptor.getTarget(); + dst = rewriter.create(loc, i8PtrTy, dst); + + // Get bytes + auto bytes = + rewriter.create(loc, i32Ty, adaptor.getBytes()); + + // Get strides + auto srcStrideN = + rewriter.create(loc, i32Ty, adaptor.getSrcStrideN()); + auto srcStrideH = + rewriter.create(loc, i32Ty, adaptor.getSrcStrideH()); + auto srcStrideW = + rewriter.create(loc, i32Ty, adaptor.getSrcStrideW()); + auto dstStrideN = + rewriter.create(loc, i32Ty, adaptor.getDstStrideN()); + auto dstStrideH = + rewriter.create(loc,
i32Ty, adaptor.getDstStrideH()); + auto dstStrideW = + rewriter.create(loc, i32Ty, adaptor.getDstStrideW()); + + // Get iterator + auto srcIterN = + rewriter.create(loc, i32Ty, adaptor.getSrcIterN()); + auto srcIterH = + rewriter.create(loc, i32Ty, adaptor.getSrcIterH()); + auto srcIterW = + rewriter.create(loc, i32Ty, adaptor.getSrcIterW()); + auto dstIterN = + rewriter.create(loc, i32Ty, adaptor.getDstIterN()); + auto dstIterH = + rewriter.create(loc, i32Ty, adaptor.getDstIterH()); + auto dstIterW = + rewriter.create(loc, i32Ty, adaptor.getDstIterW()); - Value target = adaptor.getTarget(); + // Create the call to __GatherScatter + auto call = rewriter.create( + loc, TypeRange{i8PtrTy}, "__GatherScatter", // funcPtr, + ValueRange{src, dst, bytes, srcStrideN, srcStrideH, srcStrideW, + srcIterN, srcIterH, srcIterW, dstStrideN, dstStrideH, + dstStrideW, dstIterN, dstIterH, dstIterW}); + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +template +struct ArgMinMaxOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: + + // __ArgMinMax(uint64_t *src, uint64_t *dst0, uint64_t *dst1, + // uint32_t elem_count, uint16_t fmt) + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value src = adaptor.getSrc(); // Need to bitcast src to i8* - target = rewriter.create(op.getLoc(), i8PtrTy, target); - Value elemCount = adaptor.getElemCount(); - elemCount = castIndexToInt32(rewriter, op->getLoc(), elemCount); + src = rewriter.create(op.getLoc(), i8PtrTy, src); - // Handle mask arrays - // For simplicity, we'll create empty arrays - Value nullPtr = rewriter.create(op.getLoc(), i32PtrTy); + // Convert results + Value value = adaptor.getValue(); + Value index = adaptor.getIndex(); + // Need to bitcast `value` and `index` to i8* + value = rewriter.create(op.getLoc(), i8PtrTy, value); + index = rewriter.create(op.getLoc(), i8PtrTy, index); + + // Materialize the elem_count attribute as an i32 constant + Value elemCount = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getElemCount())); // Handle format attribute Value fmt = rewriter.create( - op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); - // Create the call to __MaskMove + // Create the call auto call = rewriter.create( - op.getLoc(), i8PtrTy, "__MaskMove", // funcPtr, - ArrayRef{src, target, elemCount, nullPtr, fmt}); + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{src, value, index, elemCount, fmt}); - // Replace the op with the result of the call - rewriter.replaceOp(op, call.getResult()); + // Erase the old op + rewriter.eraseOp(op); return success(); } }; @@ -504,8 +732,8 @@ struct ReduceOpConversion : public OpConversionPattern { SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty,
i16Ty, i16Ty, i16Ty, i16Ty}; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value src = adaptor.getSrc(); @@ -571,8 +799,8 @@ struct ElementWiseOpConversion : public OpConversionPattern { i32Ty, i32Ty, i32Ty}; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value srcA = adaptor.getInput0(); @@ -630,8 +858,8 @@ struct UnaryOpConversion : public OpConversionPattern { // Types for function declaration SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value input = adaptor.getInput(); @@ -684,8 +912,8 @@ struct BinaryVSOpConversion : public OpConversionPattern { SmallVector argTypes = {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value srcA = adaptor.getInput0(); @@ -748,8 +976,8 @@ struct BinaryLogicVVOpConversion : public OpConversionPattern { i32Ty // fmt }; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value srcA = adaptor.getInput0(); @@ -782,15 +1010,14 @@ struct BinaryLogicVVOpConversion : public OpConversionPattern { } }; -template -struct BoolRelationVVOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename BoolRelationVVOp::Adaptor; +template +struct RelationVVOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename RelationVVOp::Adaptor; // using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(BoolRelationVVOp op, OpAdaptor adaptor, + matchAndRewrite(RelationVVOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Get the module for function declarations auto module = op->template getParentOfType(); @@ -805,8 +1032,8 @@ struct BoolRelationVVOpConversion // Types for function declaration SmallVector argTypes = {i8PtrTy, i8PtrTy, i8PtrTy, i32Ty, i16Ty}; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value srcA = adaptor.getInput0(); @@ -839,6 +1066,62 @@ struct BoolRelationVVOpConversion } }; +// FIXME: Use trait to refactor the RelationVSOpConversion and +// ElementWiseOpConversion +template +struct RelationVSOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) 
const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature (e.g.): void __BoolEqualVS(uint64_t *src0, uint32_t src1, uint64_t + // *dst, uint32_t elem_count, uint16_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty}; + + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast srcA to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + + Value srcB = adaptor.getValue(); + + Value out = adaptor.getOut(); + // Need to bitcast out to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + // Convert tx81.NormalConvertOp op to LLVM template struct NormalConvertOpConversion : public OpConversionPattern { @@ -860,8 +1143,8 @@ struct NormalConvertOpConversion : public OpConversionPattern { // Types for function declaration SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty}; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value input = adaptor.getInput(); @@ -907,8 +1190,8 @@ struct RoundConvertOpConversion : public OpConversionPattern { // Types for function declaration SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - funcPrefix, i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); // Convert operands Value input = adaptor.getInput(); @@ -934,6 +1217,186 @@ } }; +struct BitToFPOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = tx::Bit2FpOp::Adaptor; + + LogicalResult + matchAndRewrite(tx::Bit2FpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void __Bit2Fp(uint64_t *src, uint64_t *target, uint32_t + // elem_count, uint16_t fmt) + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__Bit2Fp", i8PtrTy, argTypes); + + // Convert operands + Value input = adaptor.getSrc(); + Value output =
adaptor.getTarget(); + Value elemCount = adaptor.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Bitcast all pointers to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + output = rewriter.create(op.getLoc(), i8PtrTy, output); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Bit2Fp", // funcPtr, + ArrayRef{input, output, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.channel_norm op +struct ChannelNormOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename tx::ChannelNormOp::Adaptor; + + LogicalResult + matchAndRewrite(tx::ChannelNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: + // __ChannelNorm(uint64_t *src, uint64_t *dst, uint16_t n, + // uint16_t h, uint16_t w, uint16_t c, uint16_t c0, uint16_t + // dtype_size) + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i16Ty, i16Ty, + i16Ty, i16Ty, i16Ty, i16Ty}; + + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__ChannelNorm", i8PtrTy, argTypes); + + // Convert operands + Value src = adaptor.getSrc(); + // Need to bitcast src to i8* + src = rewriter.create(op.getLoc(), i8PtrTy, src); + Value dst = adaptor.getDst(); + // Need to bitcast dst to i8* + dst = rewriter.create(op.getLoc(), i8PtrTy, dst); + + // Convert shape attribute to Value + Value shape_n = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[0]); + Value shape_h = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[1]); + Value shape_w = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[2]); + Value shape_c = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[3]); + + // Convert c0_align attribute to Value + Value c0Align = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getC0Align())); + + // Convert dtype_size attribute to Value + Value dtypeSize = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getDtypeSize())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__ChannelNorm", // funcPtr, + ArrayRef{src, dst, shape_n, shape_h, shape_w, shape_c, c0Align, + dtypeSize}); + + // Erase the old op + rewriter.eraseOp(op); + + return success(); + } +}; + +struct DechannelNormOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename tx::DechannelNormOp::Adaptor; + + LogicalResult + matchAndRewrite(tx::DechannelNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: + // __DechannelNorm(uint64_t *src, uint64_t *dst, uint16_t n, + // uint16_t h, uint16_t w, uint16_t c, uint16_t c0, uint16_t + // dtype_size) + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto
i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i16Ty, i16Ty, + i16Ty, i16Ty, i16Ty, i16Ty}; + + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__DechannelNorm", i8PtrTy, argTypes); + + // Convert operands + Value src = adaptor.getSrc(); + // Need to bitcast src to i8* + src = rewriter.create(op.getLoc(), i8PtrTy, src); + Value dst = adaptor.getDst(); + // Need to bitcast dst to i8* + dst = rewriter.create(op.getLoc(), i8PtrTy, dst); + + // Convert shape attribute to Value + Value shape_n = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[0]); + Value shape_h = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[1]); + Value shape_w = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[2]); + Value shape_c = + rewriter.create(op.getLoc(), i16Ty, op.getShape()[3]); + + // Convert c0_align attribute to Value + Value c0Align = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getC0Align())); + + // Convert dtype_size attribute to Value + Value dtypeSize = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getDtypeSize())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__DechannelNorm", // funcPtr, + ArrayRef{src, dst, shape_n, shape_h, shape_w, shape_c, c0Align, + dtypeSize}); + + // Erase the old op + rewriter.eraseOp(op); + + return success(); + } +}; + // Convert tx81.gemm to LLVM call to __Gemm function struct GemmOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -981,8 +1444,8 @@ struct GemmOpConversion : public OpConversionPattern { }; // Declare the function - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Gemm", - i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__Gemm", i8PtrTy, argTypes); // Convert operands Value srcA = adaptor.getSrcA(); @@ -1010,15 +1473,15 @@ struct GemmOpConversion : public OpConversionPattern { auto dimsAttr = op.getDims(); SmallVector dimsValues; for (auto dimAttr : dimsAttr) - dimsValues.push_back(mlir::cast(dimAttr).getInt()); + dimsValues.push_back(cast(dimAttr).getInt()); - // Allocate memory for the dims array + // Materialize the number of dims (rank) - Value dimsArraySize = rewriter.create( + Value rank = rewriter.create( op.getLoc(), i64Ty, rewriter.getI64IntegerAttr(dimsValues.size())); // Use alloc to allocate memory for dims array auto dimsArrayI32Ptr = rewriter.create( - op.getLoc(), i32PtrTy, rewriter.getI32Type(), dimsArraySize, + op.getLoc(), i32PtrTy, rewriter.getI32Type(), rank, /*alignment=*/0); // Store each dimension in the array @@ -1080,6 +1543,54 @@ struct GemmOpConversion : public OpConversionPattern { } }; +struct SigmoidOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::Sigmoid op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Sigmoid runtime function if not already declared + // Signature: void __Sigmoid(int64_t* src, int64_t *dst, + // uint32_t elem_count, uint16_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr =
triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__Sigmoid", i8PtrTy, argTypes); + + // Convert operands + Value input = adaptor.getInput(); + Value output = adaptor.getOut(); + Value elemCount = adaptor.getElemCount(); + + // Bitcast all pointers to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + output = rewriter.create(op.getLoc(), i8PtrTy, output); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Sigmoid", // funcPtr, + ArrayRef{input, output, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + // Convert tx81.memset to LLVM call to __Memset function struct MemsetOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1114,8 +1625,8 @@ struct MemsetOpConversion : public OpConversionPattern { }; // Declare the function - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - "__Memset", i8PtrTy, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__Memset", i8PtrTy, argTypes); // Get operands Value src = adaptor.getSrc(); @@ -1152,321 +1663,6 @@ struct MemsetOpConversion : public OpConversionPattern { } }; -// Conversion pattern for linalg.fill operation with tensor arguments -struct LinalgFillOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LinalgFillOpConversion(TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - - LogicalResult - matchAndRewrite(linalg::FillOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // The operation should have tensor as output - if (op.getOutputs().size() != 1) { - return rewriter.notifyMatchFailure(op, "expects single output tensor"); - } - - // Check if the output is a tensor type - Value outputTensor = op.getOutputs()[0]; - auto tensorType = mlir::dyn_cast(outputTensor.getType()); - if (!tensorType) { - return rewriter.notifyMatchFailure(op, "expects ranked tensor type"); - } - - // Check for static shape - if (!tensorType.hasStaticShape()) { - return rewriter.notifyMatchFailure(op, - "dynamic shapes not yet supported"); - } - - auto context = rewriter.getContext(); - auto loc = op.getLoc(); - Value value = adaptor.getInputs()[0]; - - // Get the element type - Type elemType = tensorType.getElementType(); - - // Convert the tensor type to the LLVM pointer type - auto llvmPtrType = mlir::dyn_cast( - typeConverter->convertType(tensorType)); - if (!llvmPtrType) { - return rewriter.notifyMatchFailure( - op, "failed to convert tensor type to LLVM pointer type"); - } - - // Calculate total number of elements - int64_t totalElements = 1; - for (int64_t dim : tensorType.getShape()) { - totalElements *= dim; - } - - // Get index type - auto indexType = rewriter.getI64Type(); - - // Implement the following steps: - // 1. Allocate memory for the tensor - // 2. Fill it using memset if applicable - // 3. 
Return the pointer as the result - - // Calculate element size in bytes - int64_t elemSizeInBytes = 0; - if (auto intType = mlir::dyn_cast(elemType)) { - elemSizeInBytes = - (intType.getWidth() + 7) / 8; // Round up to nearest byte - } else if (auto floatType = mlir::dyn_cast(elemType)) { - elemSizeInBytes = - (floatType.getWidth() + 7) / 8; // Round up to nearest byte - } else { - return rewriter.notifyMatchFailure(op, "unsupported element type"); - } - - // Calculate total size in bytes - auto totalSizeInBytes = totalElements * elemSizeInBytes; - auto totalSizeVal = rewriter.create( - loc, indexType, rewriter.getI64IntegerAttr(totalSizeInBytes)); - - // Allocate memory - auto mallocFunc = - getOrInsertMalloc(rewriter, op->getParentOfType()); - auto allocated = rewriter.create( - loc, LLVM::LLVMPointerType::get(context), mallocFunc, - ArrayRef{totalSizeVal}); - - auto llvmVoidPtr = LLVM::LLVMPointerType::get(context); - - // Cast the allocated memory to the appropriate pointer type - auto castPtr = rewriter.create(loc, llvmPtrType, - allocated.getResult()); - - // Check if we can use memset for filling - bool useMemset = false; - Value byteValue; - - // For memset to work correctly, we need to have a consistent byte pattern - if (auto constOp = value.getDefiningOp()) { - if (auto intAttr = mlir::dyn_cast(constOp.getValue())) { - // For integer constants - auto intVal = intAttr.getInt(); - // Check if all bytes in the pattern are the same - bool allBytesEqual = true; - uint8_t firstByte = intVal & 0xFF; - for (unsigned i = 1; i < elemSizeInBytes; i++) { - if (((intVal >> (i * 8)) & 0xFF) != firstByte) { - allBytesEqual = false; - break; - } - } - - if (allBytesEqual) { - useMemset = true; - byteValue = rewriter.create( - loc, rewriter.getIntegerType(8), - rewriter.getIntegerAttr(rewriter.getIntegerType(8), firstByte)); - } - } else if (auto floatAttr = - mlir::dyn_cast(constOp.getValue())) { - // For floating point constants - if (floatAttr.getValue().isZero()) { - // Zero float can use memset with zero byte value - useMemset = true; - byteValue = rewriter.create( - loc, rewriter.getIntegerType(8), rewriter.getI8IntegerAttr(0)); - } - } - } - - if (useMemset) { - // Use memset for filling - auto memsetFunc = - getOrInsertMemset(rewriter, op->getParentOfType()); - rewriter.create( - loc, llvmVoidPtr, memsetFunc, - ArrayRef{castPtr, byteValue, totalSizeVal}); - } else { - // Create a loop to manually fill the tensor with the value - // We'll use SCF dialect for structured loops - auto llvmElemType = typeConverter->convertType(elemType); - - // Create loop initialization - auto zero = rewriter.create( - loc, indexType, rewriter.getI64IntegerAttr(0)); - auto upperBound = rewriter.create( - loc, indexType, rewriter.getI64IntegerAttr(totalElements)); - auto one = rewriter.create( - loc, indexType, rewriter.getI64IntegerAttr(1)); - - // Create the fill loop - auto loopOp = - rewriter.create(loc, zero, upperBound, one, ValueRange{}); - - // Set insertion point inside the loop - rewriter.setInsertionPointToStart(loopOp.getBody()); - - // Calculate pointer for the current element - auto currentPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(context), - LLVM::LLVMPointerType::get(context), castPtr, - ArrayRef({loopOp.getInductionVar()})); - - // Store the fill value to the current memory location - rewriter.create(loc, value, currentPtr); - - // Reset insertion point after the loop - rewriter.setInsertionPointAfter(loopOp); - } - - // Replace the original op with the casted pointer - 
rewriter.replaceOp(op, castPtr); - return success(); - } - -private: - // Helper to get or insert malloc function declaration - FlatSymbolRefAttr getOrInsertMalloc(PatternRewriter &rewriter, - ModuleOp module) const { - auto context = rewriter.getContext(); - auto mallocName = "malloc"; - if (module.lookupSymbol(mallocName)) { - return SymbolRefAttr::get(rewriter.getContext(), mallocName); - } - - // Create malloc function declaration - auto llvmVoidPtr = LLVM::LLVMPointerType::get(context); - auto mallocType = - LLVM::LLVMFunctionType::get(llvmVoidPtr, {rewriter.getI64Type()}, - /*isVarArg=*/false); - - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module->getLoc(), mallocName, mallocType); - - return SymbolRefAttr::get(rewriter.getContext(), mallocName); - } - - // Helper to get or insert memset function declaration - FlatSymbolRefAttr getOrInsertMemset(PatternRewriter &rewriter, - ModuleOp module) const { - auto context = rewriter.getContext(); - auto memsetName = "memset"; - if (module.lookupSymbol(memsetName)) { - return SymbolRefAttr::get(rewriter.getContext(), memsetName); - } - - // Create memset function declaration - auto voidPtrType = LLVM::LLVMPointerType::get(context); - auto memsetType = LLVM::LLVMFunctionType::get( - context, - voidPtrType, // memset returns the destination pointer - ArrayRef{voidPtrType, rewriter.getI8Type(), - rewriter.getI64Type()}, - /*isVarArg=*/false); - - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module->getLoc(), memsetName, memsetType); - - return SymbolRefAttr::get(rewriter.getContext(), memsetName); - } -}; - -// Conversion pattern for tensor.empty operation -class TensorEmptyOpConversion : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - TensorEmptyOpConversion(TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - - LogicalResult - matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Get the result tensor type - TensorType resultType = op.getType(); - - // Verify we can handle this tensor type - if (!resultType.hasStaticShape()) { - return rewriter.notifyMatchFailure(op, - "dynamic shapes not yet supported"); - } - - // Convert the tensor type to LLVM pointer type - auto llvmPtrType = mlir::dyn_cast( - getTypeConverter()->convertType(resultType)); - - if (!llvmPtrType) { - return rewriter.notifyMatchFailure( - op, "failed to convert tensor type to LLVM pointer type"); - } - - // Get element type - Type elementType = resultType.getElementType(); - - // Create LLVM operations to allocate memory - // 1. Calculate the total allocation size in bytes - auto loc = op.getLoc(); - int64_t totalElements = 1; - for (int64_t dim : resultType.getShape()) { - totalElements *= dim; - } - - auto elementSize = rewriter.create( - loc, rewriter.getI64Type(), - rewriter.getI64IntegerAttr(getElementTypeSize(elementType))); - - auto totalSize = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(totalElements)); - - auto allocSize = rewriter.create(loc, rewriter.getI64Type(), - totalSize, elementSize); - - // 2. 
Allocate memory using malloc - auto mallocFunc = - getOrInsertMalloc(rewriter, op->getParentOfType()); - auto allocated = rewriter.create(loc, llvmPtrType, mallocFunc, - ArrayRef{allocSize}); - - // Replace the tensor.empty operation with our allocation - rewriter.replaceOp(op, allocated.getResult()); - return success(); - } - -private: - // Helper to get element type size in bytes - int64_t getElementTypeSize(Type type) const { - if (auto floatType = mlir::dyn_cast(type)) { - return floatType.getWidth() / 8; - } else if (auto intType = mlir::dyn_cast(type)) { - return intType.getWidth() / 8; - } - // Default for other types - return 1; - } - - // Helper to get or insert malloc function declaration - FlatSymbolRefAttr getOrInsertMalloc(PatternRewriter &rewriter, - ModuleOp module) const { - auto mallocName = "malloc"; - if (module.lookupSymbol(mallocName)) { - return SymbolRefAttr::get(rewriter.getContext(), mallocName); - } - - // Create malloc function declaration - auto llvmVoidPtr = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto mallocType = - LLVM::LLVMFunctionType::get(llvmVoidPtr, {rewriter.getI64Type()}, - /*isVarArg=*/false); - - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module->getLoc(), mallocName, mallocType); - - return SymbolRefAttr::get(rewriter.getContext(), mallocName); - } -}; - // Convert tt.get_program_id to LLVM call to __get_pid function // Treat this as a Tx81-specific action; it could be split into a separate // pass or modeled with a tx81.get_program_id op @@ -1492,8 +1688,8 @@ struct GetProgramIDConversion }; // Declare the function - Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), - "__get_pid", i32Ty, argTypes); + Value funcPtr = triton::utils::declareTx81Function( + module, rewriter, op.getLoc(), "__get_pid", i32Ty, argTypes); // Get operands auto axis = (uint32_t)op.getAxis(); @@ -1572,6 +1768,8 @@ class Tx81ToLLVMPass : public Tx81ToLLVMBase { RoundConvertOpConversion, RoundConvertOpConversion, RoundConvertOpConversion, + ArgMinMaxOpConversion, + ArgMinMaxOpConversion, ReduceOpConversion, ReduceOpConversion, ReduceOpConversion, @@ -1594,21 +1792,48 @@ class Tx81ToLLVMPass : public Tx81ToLLVMBase { BinaryVSOpConversion, BinaryVSOpConversion, BinaryVSOpConversion, - BoolRelationVVOpConversion, - BoolRelationVVOpConversion, - BoolRelationVVOpConversion, - BoolRelationVVOpConversion, - BoolRelationVVOpConversion, - BoolRelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVVOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, + RelationVSOpConversion, BinaryLogicVVOpConversion, BinaryLogicVVOpConversion, BinaryLogicVVOpConversion, - RdmaOpConversion, - WdmaOpConversion, + RdmaWdmaOpConversion, + RdmaWdmaOpConversion, + TransformOpConversion, + TransformOpConversion, + TransformOpConversion, MaskMoveOpConversion, + GatherScatterOpConversion, + BitToFPOpConversion, + ChannelNormOpConversion, + DechannelNormOpConversion, GemmOpConversion, + SigmoidOpConversion,
MemsetOpConversion, - GetProgramIDConversion>( + GetProgramIDConversion, + BarrierConversion>( context); // clang-format on diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp index d71761179..2cb706195 100644 --- a/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp @@ -11,6 +11,12 @@ using namespace mlir; using namespace mlir::mk; +LogicalResult PrintOp::verify() { + if (getOperands().size() > 1) + return emitOpError("expects at most one operand"); + return success(); +} + /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. void MagicKernelDialect::initialize() { @@ -19,8 +25,8 @@ void MagicKernelDialect::initialize() { #include "magic-kernel/Dialect/IR/MagicKernelOps.cpp.inc" >(); // TODO: Add BufferizableOpInterface to all ops that can be bufferized - declarePromisedInterfaces(); + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp index f0c256956..552603b77 100644 --- a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp @@ -118,5 +118,8 @@ void mlir::mk::registerBufferizableOpInterfaceExternalModels( +[](MLIRContext *ctx, mlir::mk::MagicKernelDialect *dialect) { // TODO: Register all mk ops. MKOpInterfaceHelper::registerOpInterface(ctx); + MKOpInterfaceHelper::registerOpInterface(ctx); + MKOpInterfaceHelper::registerOpInterface(ctx); + MKOpInterfaceHelper::registerOpInterface(ctx); }); } diff --git a/third_party/tsingmicro/lib/Utils/CMakeLists.txt b/third_party/tsingmicro/lib/Utils/CMakeLists.txt new file mode 100644 index 000000000..6b6dca760 --- /dev/null +++ b/third_party/tsingmicro/lib/Utils/CMakeLists.txt @@ -0,0 +1,3 @@ +add_triton_library(TritonUtils + utils.cpp +) diff --git a/third_party/tsingmicro/lib/Utils/utils.cpp b/third_party/tsingmicro/lib/Utils/utils.cpp new file mode 100644 index 000000000..efec7eb63 --- /dev/null +++ b/third_party/tsingmicro/lib/Utils/utils.cpp @@ -0,0 +1,41 @@ +//===------------------- utils.cpp ----------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "utils/utils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir::triton::utils { + +// Function to declare Tx81 runtime function +Value declareTx81Function(ModuleOp module, OpBuilder &builder, Location loc, + StringRef name, Type resultType, + ArrayRef argumentTypes) { + // Check if the function already exists + Operation *funcOp = module.lookupSymbol(name); + if (funcOp) + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), name); + + // Create function type + Type funcType = LLVM::LLVMFunctionType::get(resultType, argumentTypes, + /*isVarArg=*/false); + + // Create a function declaration + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(module.getBody()); + + builder.create(loc, name, funcType, + LLVM::Linkage::External); + + builder.restoreInsertionPoint(ip); + + // Return function pointer + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), name); +} + +} // namespace mlir::triton::utils diff --git a/third_party/tsingmicro/python/triton_tsingmicro.cc b/third_party/tsingmicro/python/triton_tsingmicro.cc index 608918898..ff232d947 100644 --- a/third_party/tsingmicro/python/triton_tsingmicro.cc +++ b/third_party/tsingmicro/python/triton_tsingmicro.cc @@ -1,45 +1,7 @@ #include -#include "mlir/Conversion/Passes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "passes.h" - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" - -#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" -#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" -#include "triton-shared/Conversion/StructuredToMemref/Passes.h" -#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" -#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h" -#include "triton-shared/Conversion/TritonToLinalg/Passes.h" -#include "triton-shared/Conversion/TritonToStructured/Passes.h" -#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" -#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" -#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" - -#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h" -#include "magic-kernel/Conversion/LinalgToMK/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "triton-shared/Conversion/StructuredToMemref/Passes.h" -#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h" -#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h" -#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h" - -#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" - -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllExtensions.h" -#include "mlir/InitAllPasses.h" - -#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" -#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" - namespace py = pybind11; -using namespace mlir; +// The TsingMicro backend with ztc doesn't do compilation from within python +// but rather externally through ztc-opt, so we leave this function blank. 
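+// For example (hypothetical invocation; the exact tool name and flags depend
+// on the ztc distribution): ztc-opt kernel.ttir <lowering-pass-flags> -o out.mlir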
void init_triton_tsingmicro(py::module &&m) {} diff --git a/third_party/tsingmicro/requirements.txt b/third_party/tsingmicro/requirements.txt new file mode 100644 index 000000000..a4a24b3f4 --- /dev/null +++ b/third_party/tsingmicro/requirements.txt @@ -0,0 +1,4 @@ +gitpython +nanobind +torch==2.7.0 +torchvision diff --git a/third_party/tsingmicro/scripts/build_llvm.sh b/third_party/tsingmicro/scripts/build_llvm.sh index d76a4ef1c..a1ee7e5d5 100755 --- a/third_party/tsingmicro/scripts/build_llvm.sh +++ b/third_party/tsingmicro/scripts/build_llvm.sh @@ -22,6 +22,7 @@ build_llvm() { -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU;RISCV" \ -DLLVM_USE_LINKER=lld \ -DMLIR_ENABLE_BINDINGS_PYTHON=1 \ + -DPython3_EXECUTABLE="$(which python3)" \ ../llvm ninja } diff --git a/third_party/tsingmicro/scripts/build_tsingmicro.sh b/third_party/tsingmicro/scripts/build_tsingmicro.sh index 1e093cf75..3ecc21642 100755 --- a/third_party/tsingmicro/scripts/build_tsingmicro.sh +++ b/third_party/tsingmicro/scripts/build_tsingmicro.sh @@ -8,17 +8,17 @@ if [ -z "${WORKSPACE+x}" ]; then WORKSPACE=$(realpath "$project_dir/..") fi -TX8_HOME=$WORKSPACE/tx8_deps +TX8_DEPS_ROOT=$WORKSPACE/tx8_deps LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 -if [ ! -d $TX8_HOME ] || [ ! -d $LLVM ]; then - WORKSPACE="${HOME}/.flagtree/tsingmicro/" - TX8_HOME=$WORKSPACE/tx8_deps +if [ ! -d $TX8_DEPS_ROOT ] || [ ! -d $LLVM ]; then + WORKSPACE="${HOME}/.triton/tsingmicro/" + TX8_DEPS_ROOT=$WORKSPACE/tx8_deps LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 fi -if [ ! -d $TX8_HOME ]; then - echo "Error: $TX8_HOME not exist!" 1>&2 +if [ ! -d $TX8_DEPS_ROOT ]; then + echo "Error: $TX8_DEPS_ROOT does not exist!" 1>&2 exit 1 fi @@ -29,23 +29,36 @@ fi BUILD_TYPE=Release -export TX8_HOME=$TX8_HOME +build_triton() { + if [ "x$BUILD_TYPE" == "xDebug" ]; then + export DEBUG=ON + else + export REL_WITH_DBG_INFO=ON + fi + + export TRITON_BUILD_WITH_CLANG_LLD=true + export TRITON_BUILD_WITH_CCACHE=true + export TRITON_OFFLINE_BUILD=ON + export TRITON_BUILD_PROTON=OFF + + echo "export TRITON_OFFLINE_BUILD=$TRITON_OFFLINE_BUILD" + echo "export TRITON_BUILD_WITH_CLANG_LLD=$TRITON_BUILD_WITH_CLANG_LLD" + echo "export TRITON_BUILD_WITH_CCACHE=$TRITON_BUILD_WITH_CCACHE" + echo "export TRITON_BUILD_PROTON=$TRITON_BUILD_PROTON" + + cd $project_dir/python + python3 -m pip install . --no-build-isolation -v --verbose +} + +if [ -f $project_dir/.venv/bin/activate ]; then + source $project_dir/.venv/bin/activate +fi + export LLVM_SYSPATH=$LLVM +export TX8_DEPS_ROOT=$TX8_DEPS_ROOT export FLAGTREE_BACKEND=tsingmicro -export TRITON_OFFLINE_BUILD=ON -export TRITON_BUILD_WITH_CLANG_LLD=true -export TRITON_BUILD_WITH_CCACHE=true -export TRITON_BUILD_PROTON=OFF - -echo "export TX8_HOME=$TX8_HOME" +echo "export TX8_DEPS_ROOT=$TX8_DEPS_ROOT" echo "export LLVM_SYSPATH=$LLVM_SYSPATH" -echo "export FLAGTREE_BACKEND=$FLAGTREE_BACKEND" - -echo "export TRITON_OFFLINE_BUILD=$TRITON_OFFLINE_BUILD" -echo "export TRITON_BUILD_WITH_CLANG_LLD=$TRITON_BUILD_WITH_CLANG_LLD" -echo "export TRITON_BUILD_WITH_CCACHE=$TRITON_BUILD_WITH_CCACHE" -echo "export TRITON_BUILD_PROTON=$TRITON_BUILD_PROTON" -cd python -python3 -m pip install .
--no-build-isolation -v --verbose +build_triton diff --git a/third_party/tsingmicro/scripts/build_tx8_deps.sh b/third_party/tsingmicro/scripts/build_tx8_deps.sh new file mode 100755 index 000000000..7c440911a --- /dev/null +++ b/third_party/tsingmicro/scripts/build_tx8_deps.sh @@ -0,0 +1,233 @@ +#!/bin/bash + +# Clone a Git repository and check out the given branch or commit. +clone_and_checkout() { + local git_url="$1" + local target_dir="$2" + local ref_type="$3" # "branch" or "commit" + local ref_value="$4" # branch name or commit ID + + # Create the target directory if it does not exist yet + if [ ! -d "$target_dir" ]; then + mkdir -p "$target_dir" + if [ $? -ne 0 ]; then + echo "Error: Failed to create target directory: $target_dir" + return 1 + fi + fi + + if [ "$(ls -A $target_dir)" ]; then + echo "Skipping clone: $target_dir is not empty" + return 1 + fi + + # Enter the target directory with pushd + pushd "$target_dir" >/dev/null || return 1 + + # Clone the repository + git clone "$git_url" . + if [ $? -ne 0 ]; then + echo "Error: Failed to clone repository: $git_url" + popd >/dev/null + return 1 + fi + + # Check out the branch or commit according to ref_type + if [ "$ref_type" == "branch" ]; then + # Make sure the branch exists + if ! git branch -r | grep -q "origin/$ref_value"; then + echo "Error: Branch '$ref_value' does not exist in the repository." + popd >/dev/null + return 1 + fi + # Switch to the requested branch + git checkout "$ref_value" + if [ $? -ne 0 ]; then + echo "Error: Failed to switch to branch: $ref_value" + popd >/dev/null + return 1 + fi + echo "Successfully cloned and switched to branch '$ref_value' for repository: $git_url" + elif [ "$ref_type" == "commit" ]; then + # Make sure the commit exists + if ! git rev-parse "$ref_value" >/dev/null 2>&1; then + echo "Error: Commit '$ref_value' does not exist in the repository." + popd >/dev/null + return 1 + fi + # Switch to the requested commit + git checkout "$ref_value" + if [ $? -ne 0 ]; then + echo "Error: Failed to switch to commit: $ref_value" + popd >/dev/null + return 1 + fi + echo "Successfully cloned and switched to commit '$ref_value' for repository: $git_url" + else + echo "Error: Invalid ref_type. Use 'branch' or 'commit'." + popd >/dev/null + return 1 + fi + + # Leave the directory with popd + popd >/dev/null + return 0 +}
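+# Usage sketch (illustrative values only):
+#   clone_and_checkout "git@example.com:org/repo.git" "$WORKSPACE/repo" "commit" "<sha>"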
+ +download_and_extract() { + local url="$1" # download URL + local target_dir="$2" # target directory + local temp_dir="$3" # temporary directory + local tag_name="$4" # tag used to namespace the temp dir + + temp_dir=$temp_dir/$tag_name + + # Make sure the target and temporary directories exist + mkdir -p "$target_dir" + mkdir -p "$temp_dir" + + # Only download and extract when the target directory is empty + if [ -z "$(ls -A "$target_dir")" ]; then + echo "Target directory $target_dir is empty; downloading and extracting..." + + # Download the file into the temporary directory + local temp_file="$temp_dir/$(basename "$url")" + + if [[ ! -f $temp_file ]]; then + wget -O "$temp_file" "$url" + else + echo "File $temp_file already exists; skipping download." + fi + + # Check whether the download succeeded + if [ $? -eq 0 ]; then + # Extract into the temporary directory + unzip_dir=$temp_dir/$(date +"%Y_%m_%d") + mkdir -p $unzip_dir + tar -xz -C "$unzip_dir" -f "$temp_file" + echo "Extracted to: $unzip_dir" + + cp -r $unzip_dir/* $target_dir + + # Check the extracted contents + # local extracted_dir + # extracted_dir=$(ls -d "$temp_dir"/*/ | head -n 1) + # if [ -d "$extracted_dir" ]; then + # if [ -d "$target_dir" ]; then + # rm -rf $target_dir + # fi + # # Move the extracted directory to the target directory + # mv "$extracted_dir" "$target_dir" + # echo "Download and extraction finished; $extracted_dir was moved to $target_dir." + # else + # echo "No directory found after extraction; nothing to move." + # exit 1 + # fi + else + echo "Download failed; exiting." + exit 1 + fi + else + echo "Target directory $target_dir is not empty; skipping download." + fi +} + +script_path=$(realpath "$0") +script_dir=$(dirname "$script_path") +project_dir=$(realpath "$script_dir/../../..") + +if [ -z "${WORKSPACE+x}" ]; then + WORKSPACE=$(realpath "$project_dir/..") +fi + +build_tx8_deps=OFF + +if [ $# -gt 0 ]; then + if [[ "${1,,}" == "build" ]]; then + build_tx8_deps=ON + fi +fi + +if [ "x$build_tx8_deps" == "xON" ]; then + download_dir=$WORKSPACE/download + tx8fw_dir=$download_dir/triton-tx8fw + # download_and_extract "http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81fw/tx81fw_2025-0606_bbe682.tar.gz" \ + # "$tx8fw_dir" "$WORKSPACE/download" "tx8fw" + download_and_extract "http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81fw/tx81fw_2025-0617_acd719.tar.gz" \ + "$tx8fw_dir" "$download_dir" "tx8fw" + + host_runtime_dir=$download_dir/host_runtime + download_and_extract "http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx8-host/master/host_runtime_v5.2.0_daily_2025-0605_7a6768.tar.gz" \ + "$host_runtime_dir" "$download_dir" "runtime" + + xuantie_sdk_dir=$download_dir/tx8fw-xuantie-sdk + clone_and_checkout "git@gitlab.tsingmicro.com:tx8_developers/tx8fw-xuantie-sdk.git" \ + "$xuantie_sdk_dir" "branch" "master" + + kcore_fw_bin=$tx8fw_dir/bin/FW/kcore_fw.bin + if [ ! -f $kcore_fw_bin ]; then + echo "Error: cannot find $kcore_fw_bin" + fi + instr_tx81_lib=$tx8fw_dir/lib/libinstr_tx81.a + if [ ! -f $instr_tx81_lib ]; then + echo "Error: cannot find $instr_tx81_lib" + fi + instr_tx81_inc=$tx8fw_dir/include/instr_tx81/include + if [ ! -d $instr_tx81_inc ]; then + echo "Error: cannot find $instr_tx81_inc" + fi + # instr_tx81_lib=$WORKSPACE/download/libinstr_tx81.a + xuantie_dir=$xuantie_sdk_dir/Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2 + if [ ! -d $xuantie_dir ]; then + echo "Error: cannot find $xuantie_dir" + fi + + tx8_depends_dir=$WORKSPACE/tx8_deps + if [ -d $tx8_depends_dir ]; then + rm -rf $tx8_depends_dir + fi + mkdir $tx8_depends_dir + pushd $tx8_depends_dir + cp -r $xuantie_dir ./ + cp -r $host_runtime_dir/**/* ./ + + if [ ! -d chip_out ]; then + mkdir chip_out + fi + cp $kcore_fw_bin chip_out + + if [ ! -d lib ]; then + mkdir lib + fi + cp $instr_tx81_lib lib + cp $instr_tx81_inc/* include + + # Optional + lib_log_h=$tx8fw_dir/include/components/oplib_tx81/riscv/riscv/include/lib_log.h + echo "lib_log_h:$lib_log_h" + if [ -f $lib_log_h ]; then + cp $lib_log_h ./include + fi + popd + + pushd $WORKSPACE + current_time=$(date +%Y%m%d_%H%M%S) + pkg_file=download/tx8_depends_$current_time.tar.gz
+ if [ ! -d download ]; then + mkdir download + fi + if [ -f $pkg_file ]; then + rm -f $pkg_file + fi + tar -zcvf $pkg_file tx8_deps + popd +else + echo "Nothing to do: pass 'build' to rebuild the tx8_deps package." + # tx8_deps_base=$WORKSPACE/tx8_deps + # # clone_and_checkout "git@gitlab.tsingmicro.com:triton-based-projects/llvm-project.git" "$WORKSPACE/llvm-project-for-ztc" "branch" "ztc" + # clone_and_checkout "git@gitlab.tsingmicro.com:triton-based-projects/llvm-project.git" "$WORKSPACE/llvm-project" "commit" "a66376b0dc3b2ea8a84fda26faca287980986f78" + + # download_and_extract "http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx8/triton/tx8_depends_20250512_145415.tar.gz" \ + # "$tx8_deps_base" "$WORKSPACE/download" + # clone_and_checkout "ssh://192.168.100.107:29418/tx8_toolchain/tx8be-oplib" "third_party/tx8be-oplib" "commit" "b5651a734f1a6a8943765c83bee1e80d6a2c6a37" +fi diff --git a/third_party/tsingmicro/scripts/install.sh b/third_party/tsingmicro/scripts/install.sh index b0d3346b4..00a51b449 100755 --- a/third_party/tsingmicro/scripts/install.sh +++ b/third_party/tsingmicro/scripts/install.sh @@ -1,9 +1,33 @@ #!/bin/bash +PROXY=http://192.168.100.225:8889 +setup_proxy() { + # A proxy is needed to download the Python requirements. + export https_proxy=$PROXY http_proxy=$PROXY all_proxy=$PROXY + export HTTPS_PROXY=$PROXY HTTP_PROXY=$PROXY ALL_PROXY=$PROXY +} + +script_path=$(realpath "$0") +script_dir=$(dirname "$script_path") +project_dir=$(realpath "$script_dir/../../..") + +use_venv=OFF +if [ $# -gt 0 ]; then + if [[ "${1,,}" == "venv" ]]; then + use_venv=ON + fi +fi + +if [ "x$use_venv" == "xON" ]; then + python3 -m venv $project_dir/.venv --prompt flagtree + source $project_dir/.venv/bin/activate +fi + +setup_proxy + apt install git apt install lld -pip uninstall triton +pip3 install -r $project_dir/third_party/tsingmicro/requirements.txt -pip install gitpython -pip install torch==2.7.0 torchvision +pip3 install -r $project_dir/python/requirements.txt diff --git a/third_party/tsingmicro/scripts/run_tsingmicro.sh b/third_party/tsingmicro/scripts/run_tsingmicro.sh index 13e3ed38c..ddb648076 100755 --- a/third_party/tsingmicro/scripts/run_tsingmicro.sh +++ b/third_party/tsingmicro/scripts/run_tsingmicro.sh @@ -8,17 +8,17 @@ if [ -z "${WORKSPACE+x}" ]; then WORKSPACE=$(realpath "$project_dir/..") fi -TX8_HOME=$WORKSPACE/tx8_deps +TX8_DEPS_ROOT=$WORKSPACE/tx8_deps LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 -if [ ! -d $TX8_HOME ] || [ ! -d $LLVM ]; then - WORKSPACE="${HOME}/.flagtree/tsingmicro/" - TX8_HOME=$WORKSPACE/tx8_deps +if [ ! -d $TX8_DEPS_ROOT ] || [ ! -d $LLVM ]; then + WORKSPACE="${HOME}/.triton/tsingmicro/" + TX8_DEPS_ROOT=$WORKSPACE/tx8_deps LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 fi -if [ ! -d $TX8_HOME ]; then - echo "Error: $TX8_HOME not exist!" 1>&2 +if [ ! -d $TX8_DEPS_ROOT ]; then + echo "Error: $TX8_DEPS_ROOT does not exist!" 1>&2 exit 1 fi @@ -27,16 +27,26 @@ if [ ! -d $LLVM ]; then
exit 1 fi -export TX8_HOME=$TX8_HOME +if [ -f $project_dir/.venv/bin/activate ]; then + source $project_dir/.venv/bin/activate +fi + +export TX8_DEPS_ROOT=$TX8_DEPS_ROOT export LLVM_SYSPATH=$LLVM -export LD_LIBRARY_PATH=$TX8_HOME/lib:$LD_LIBRARY_PATH -export TRITON_ALWAYS_COMPILE=1 +export PYTHONPATH=$LLVM/python_packages/mlir_core:$PYTHONPATH + +export LD_LIBRARY_PATH=$TX8_DEPS_ROOT/lib:$LD_LIBRARY_PATH +export VENDOR_VERSION=1 # export TRITON_DUMP_PATH=$project_dir/dump +export TRITON_ALWAYS_COMPILE=1 -echo "export TX8_HOME=$TX8_HOME" +echo "export TX8_DEPS_ROOT=$TX8_DEPS_ROOT" echo "export LLVM_SYSPATH=$LLVM_SYSPATH" +echo "export PYTHONPATH=$PYTHONPATH" echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH" +echo "export VENDOR_VERSION=$VENDOR_VERSION" +# echo "export TRITON_DUMP_PATH=$TRITON_DUMP_PATH" echo "export TRITON_ALWAYS_COMPILE=$TRITON_ALWAYS_COMPILE" -python3 $@ +USE_SIM_MODE=${USE_SIM_MODE} python3 "$@"
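Usage sketch (illustrative, not part of the patch): with the scripts above in place, a kernel can be run against the simulator link path by exporting USE_SIM_MODE before invoking the run script; the test path below is a placeholder.

    ./third_party/tsingmicro/scripts/build_tsingmicro.sh
    USE_SIM_MODE=1 ./third_party/tsingmicro/scripts/run_tsingmicro.sh path/to/test_kernel.py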