From b03ac8682392dafd928541bca0d1575112ce356f Mon Sep 17 00:00:00 2001 From: tsingmicro-public Date: Fri, 9 May 2025 17:50:36 +0800 Subject: [PATCH 01/12] [BACKEND] tsingmicro back init --- third_party/tsingmicro/CMakeLists.txt | 18 + third_party/tsingmicro/backend/compiler.py | 334 +++ third_party/tsingmicro/backend/cpu_driver.py | 389 +++ third_party/tsingmicro/backend/driver.cpp | 624 +++++ third_party/tsingmicro/backend/driver.py | 994 ++++++++ third_party/tsingmicro/crt/CMakeLists.txt | 113 + third_party/tsingmicro/crt/README.md | 2 + .../tsingmicro/crt/gcc_flash_smartl.ld | 244 ++ .../tsingmicro/crt/gcc_flash_xiaohui.ld | 250 ++ third_party/tsingmicro/crt/gcc_tx8_smarth.ld | 279 +++ .../crt/include/Tx81/instr_adapter.h | 61 + .../crt/include/Tx81/instr_adapter_plat.h | 569 +++++ .../tsingmicro/crt/include/Tx81/instr_def.h | 924 +++++++ .../crt/include/Tx81/runtime/hrt_common.h | 474 ++++ .../crt/include/Tx81/runtime/hrt_interface.h | 98 + .../tsingmicro/crt/include/Tx81/tx81.h | 22 + third_party/tsingmicro/crt/lib/Tx81/argmax.c | 26 + third_party/tsingmicro/crt/lib/Tx81/argmin.c | 26 + third_party/tsingmicro/crt/lib/Tx81/arith.c | 165 ++ .../tsingmicro/crt/lib/Tx81/bf16_fp16.c | 25 + .../tsingmicro/crt/lib/Tx81/bf16_fp32.c | 25 + .../tsingmicro/crt/lib/Tx81/bf16_int16.c | 26 + .../tsingmicro/crt/lib/Tx81/bf16_int32.c | 26 + .../tsingmicro/crt/lib/Tx81/bf16_int8.c | 25 + .../tsingmicro/crt/lib/Tx81/bf16_tf32.c | 25 + .../tsingmicro/crt/lib/Tx81/bilinear.c | 32 + third_party/tsingmicro/crt/lib/Tx81/bit2fp.c | 26 + third_party/tsingmicro/crt/lib/Tx81/common.c | 25 + third_party/tsingmicro/crt/lib/Tx81/concat.c | 34 + third_party/tsingmicro/crt/lib/Tx81/conv.c | 61 + third_party/tsingmicro/crt/lib/Tx81/cos.c | 26 + third_party/tsingmicro/crt/lib/Tx81/count.c | 28 + third_party/tsingmicro/crt/lib/Tx81/exp.c | 26 + third_party/tsingmicro/crt/lib/Tx81/explp.c | 26 + .../tsingmicro/crt/lib/Tx81/fp16_bf16.c | 26 + .../tsingmicro/crt/lib/Tx81/fp16_fp32.c | 25 + .../tsingmicro/crt/lib/Tx81/fp16_int16.c | 26 + .../tsingmicro/crt/lib/Tx81/fp16_int32.c | 26 + .../tsingmicro/crt/lib/Tx81/fp16_int8.c | 26 + .../tsingmicro/crt/lib/Tx81/fp16_tf32.c | 25 + .../tsingmicro/crt/lib/Tx81/fp32_bf16.c | 26 + .../tsingmicro/crt/lib/Tx81/fp32_fp16.c | 26 + .../tsingmicro/crt/lib/Tx81/fp32_int16.c | 26 + .../tsingmicro/crt/lib/Tx81/fp32_int32.c | 26 + .../tsingmicro/crt/lib/Tx81/fp32_int8.c | 26 + .../tsingmicro/crt/lib/Tx81/fp32_tf32.c | 26 + .../tsingmicro/crt/lib/Tx81/gatherscatter.c | 33 + third_party/tsingmicro/crt/lib/Tx81/gemm.c | 47 + third_party/tsingmicro/crt/lib/Tx81/img2col.c | 36 + .../tsingmicro/crt/lib/Tx81/int16_bf16.c | 26 + .../tsingmicro/crt/lib/Tx81/int16_fp16.c | 25 + .../tsingmicro/crt/lib/Tx81/int16_fp32.c | 26 + .../tsingmicro/crt/lib/Tx81/int16_tf32.c | 26 + .../tsingmicro/crt/lib/Tx81/int32_bf16.c | 26 + .../tsingmicro/crt/lib/Tx81/int32_fp16.c | 27 + .../tsingmicro/crt/lib/Tx81/int32_fp32.c | 26 + .../tsingmicro/crt/lib/Tx81/int32_tf32.c | 26 + .../tsingmicro/crt/lib/Tx81/int8_bf16.c | 25 + .../tsingmicro/crt/lib/Tx81/int8_fp16.c | 25 + .../tsingmicro/crt/lib/Tx81/int8_fp32.c | 25 + .../tsingmicro/crt/lib/Tx81/int8_tf32.c | 25 + .../tsingmicro/crt/lib/Tx81/leakyrelu.c | 27 + third_party/tsingmicro/crt/lib/Tx81/ln.c | 26 + third_party/tsingmicro/crt/lib/Tx81/log2.c | 26 + third_party/tsingmicro/crt/lib/Tx81/lut16.c | 27 + third_party/tsingmicro/crt/lib/Tx81/lut32.c | 27 + .../tsingmicro/crt/lib/Tx81/mask_move.c | 25 + third_party/tsingmicro/crt/lib/Tx81/memset.c | 29 + 
third_party/tsingmicro/crt/lib/Tx81/mirror.c | 31 + .../tsingmicro/crt/lib/Tx81/nchw2nhwc.c | 31 + .../tsingmicro/crt/lib/Tx81/nhwc2nchw.c | 31 + third_party/tsingmicro/crt/lib/Tx81/pad.c | 33 + third_party/tsingmicro/crt/lib/Tx81/pow2.c | 26 + third_party/tsingmicro/crt/lib/Tx81/randgen.c | 28 + third_party/tsingmicro/crt/lib/Tx81/rdma.c | 41 + third_party/tsingmicro/crt/lib/Tx81/reduce.c | 103 + third_party/tsingmicro/crt/lib/Tx81/relu.c | 26 + .../tsingmicro/crt/lib/Tx81/rotate180.c | 31 + .../tsingmicro/crt/lib/Tx81/rotate270.c | 31 + .../tsingmicro/crt/lib/Tx81/rotate90.c | 30 + third_party/tsingmicro/crt/lib/Tx81/satrelu.c | 26 + third_party/tsingmicro/crt/lib/Tx81/sigmoid.c | 26 + third_party/tsingmicro/crt/lib/Tx81/sin.c | 26 + .../tsingmicro/crt/lib/Tx81/softplus.c | 27 + third_party/tsingmicro/crt/lib/Tx81/tanh.c | 26 + .../tsingmicro/crt/lib/Tx81/tensornorm.c | 31 + .../tsingmicro/crt/lib/Tx81/tf32_bf16.c | 26 + .../tsingmicro/crt/lib/Tx81/tf32_fp16.c | 25 + .../tsingmicro/crt/lib/Tx81/tf32_fp32.c | 25 + .../tsingmicro/crt/lib/Tx81/tf32_int16.c | 26 + .../tsingmicro/crt/lib/Tx81/tf32_int32.c | 26 + .../tsingmicro/crt/lib/Tx81/tf32_int8.c | 26 + .../tsingmicro/crt/lib/Tx81/transpose.c | 31 + third_party/tsingmicro/crt/lib/Tx81/wdma.c | 41 + third_party/tsingmicro/include/CMakeLists.txt | 6 + .../include/ExecutionEngine/CRunnerUtils.cpp | 192 ++ .../include/ExecutionEngine/CRunnerUtils.h | 499 ++++ .../tsingmicro/include/ExecutionEngine/Msan.h | 35 + .../include/ExecutionEngine/version.txt | 1 + .../include/magic-kernel-func/CMakeLists.txt | 2 + .../magic-kernel-func/Dialect/CMakeLists.txt | 1 + .../Dialect/IR/MagicKernelFuncOps.td | 19 + .../include/magic-kernel-instr/CMakeLists.txt | 2 + .../magic-kernel-instr/Dialect/CMakeLists.txt | 1 + .../Dialect/IR/MagicKernelInstrOps.td | 13 + .../include/magic-kernel/CMakeLists.txt | 2 + .../magic-kernel/Conversion/CMakeLists.txt | 2 + .../CoreDialectsToMK/CMakeLists.txt | 10 + .../CoreDialectsToMK/CoreDialectsToMK.h | 27 + .../Conversion/CoreDialectsToMK/Passes.h | 26 + .../Conversion/CoreDialectsToMK/Passes.td | 18 + .../Conversion/LinalgToMK/CMakeLists.txt | 3 + .../Conversion/LinalgToMK/LinalgToMK.h | 36 + .../Conversion/LinalgToMK/Passes.h | 22 + .../Conversion/LinalgToMK/Passes.td | 19 + .../magic-kernel/Dialect/CMakeLists.txt | 1 + .../magic-kernel/Dialect/IR/CMakeLists.txt | 11 + .../Dialect/IR/MagicKernelAttrDefs.td | 15 + .../Dialect/IR/MagicKernelDialect.h | 34 + .../Dialect/IR/MagicKernelDialect.td | 44 + .../magic-kernel/Dialect/IR/MagicKernelOps.td | 284 +++ .../Dialect/IR/MagicKernelTypes.td | 102 + .../Transforms/BufferizableOpInterfaceImpl.h | 26 + .../triton-shared/Analysis/MaskAnalysis.h | 163 ++ .../Analysis/OpFoldResultUtils.h | 67 + .../triton-shared/Analysis/PtrAnalysis.h | 271 +++ .../triton-shared/Analysis/UseAnalysis.h | 119 + .../AnalysisStructured/PtrAnalysis.h | 274 +++ .../include/triton-shared/CMakeLists.txt | 2 + .../triton-shared/Conversion/CMakeLists.txt | 5 + .../StructuredToMemref/CMakeLists.txt | 3 + .../Conversion/StructuredToMemref/Passes.h | 15 + .../Conversion/StructuredToMemref/Passes.td | 10 + .../StructuredToMemref/StructuredToMemref.h | 24 + .../TritonArithToLinalg/CMakeLists.txt | 3 + .../TritonArithToLinalg/ConversionPatterns.h | 2119 +++++++++++++++++ .../Conversion/TritonArithToLinalg/Passes.h | 15 + .../Conversion/TritonArithToLinalg/Passes.td | 20 + .../TritonArithToLinalg/TritonArithToLinalg.h | 28 + .../TritonToCoreDialects/CMakeLists.txt | 9 + 
.../Conversion/TritonToCoreDialects/Passes.h | 22 + .../Conversion/TritonToCoreDialects/Passes.td | 18 + .../TritonToCoreDialects.h | 27 + .../Conversion/TritonToLinalg/CMakeLists.txt | 9 + .../Conversion/TritonToLinalg/Passes.h | 22 + .../Conversion/TritonToLinalg/Passes.td | 18 + .../TritonToLinalg/TritonToLinalg.h | 33 + .../TritonToStructured/CMakeLists.txt | 3 + .../Conversion/TritonToStructured/Passes.h | 15 + .../Conversion/TritonToStructured/Passes.td | 19 + .../TritonToStructured/TritonToStructured.h | 17 + .../triton-shared/Dialect/CMakeLists.txt | 2 + .../Dialect/TritonStructured/CMakeLists.txt | 1 + .../TritonStructured/IR/CMakeLists.txt | 8 + .../IR/TritonStructuredDialect.h | 27 + .../IR/TritonStructuredDialect.td | 213 ++ .../Dialect/TritonTilingExt/CMakeLists.txt | 1 + .../Dialect/TritonTilingExt/IR/CMakeLists.txt | 11 + .../IR/TritonTilingExtDialect.h | 107 + .../IR/TritonTilingExtInterfaces.td | 102 + .../TritonTilingExt/IR/TritonTilingExtOps.td | 242 ++ .../include/tsingmicro-tx81/CMakeLists.txt | 2 + .../tsingmicro-tx81/Conversion/CMakeLists.txt | 3 + .../Conversion/MKToTx81/CMakeLists.txt | 3 + .../Conversion/MKToTx81/MKToTx81.h | 37 + .../Conversion/MKToTx81/Passes.h | 22 + .../Conversion/MKToTx81/Passes.td | 18 + .../Tx81MemrefToLLVM/CMakeLists.txt | 3 + .../Conversion/Tx81MemrefToLLVM/Passes.h | 22 + .../Conversion/Tx81MemrefToLLVM/Passes.td | 19 + .../Tx81MemrefToLLVM/Tx81MemrefToLLVM.h | 40 + .../Conversion/Tx81ToLLVM/CMakeLists.txt | 3 + .../Tx81ToLLVM/KernelArgBufferPass.h | 31 + .../Tx81ToLLVM/KernelArgBufferPass.td | 32 + .../Conversion/Tx81ToLLVM/Passes.h | 22 + .../Conversion/Tx81ToLLVM/Passes.td | 38 + .../Conversion/Tx81ToLLVM/Tx81ToLLVM.h | 33 + .../tsingmicro-tx81/Dialect/CMakeLists.txt | 1 + .../tsingmicro-tx81/Dialect/IR/CMakeLists.txt | 14 + .../Dialect/IR/Tx81AttrDefs.td | 24 + .../tsingmicro-tx81/Dialect/IR/Tx81Dialect.h | 33 + .../tsingmicro-tx81/Dialect/IR/Tx81Dialect.td | 43 + .../tsingmicro-tx81/Dialect/IR/Tx81Ops.h | 26 + .../tsingmicro-tx81/Dialect/IR/Tx81Ops.td | 775 ++++++ .../tsingmicro-tx81/Dialect/IR/Tx81Types.td | 107 + .../tsingmicro/lib/Analysis/CMakeLists.txt | 14 + .../tsingmicro/lib/Analysis/MaskAnalysis.cpp | 559 +++++ .../lib/Analysis/OpFoldResultUtils.cpp | 292 +++ .../tsingmicro/lib/Analysis/PtrAnalysis.cpp | 1375 +++++++++++ .../tsingmicro/lib/Analysis/UseAnalysis.cpp | 220 ++ .../lib/AnalysisStructured/CMakeLists.txt | 13 + .../lib/AnalysisStructured/PtrAnalysis.cpp | 1395 +++++++++++ third_party/tsingmicro/lib/CMakeLists.txt | 4 + .../tsingmicro/lib/Conversion/CMakeLists.txt | 10 + .../CoreDialectsToMK/CMakeLists.txt | 23 + .../CoreDialectsToMK/CoreDialectsToMKPass.cpp | 60 + .../lib/Conversion/LinalgToMK/CMakeLists.txt | 19 + .../lib/Conversion/LinalgToMK/LinalgToMK.cpp | 58 + .../Conversion/LinalgToMK/LinalgToMKPass.cpp | 72 + .../lib/Conversion/MKToTx81/CMakeLists.txt | 19 + .../lib/Conversion/MKToTx81/MKToTx81.cpp | 642 +++++ .../lib/Conversion/MKToTx81/MKToTx81Pass.cpp | 75 + .../StructuredToMemref/CMakeLists.txt | 22 + .../StructuredToMemref/StructuredToMemref.cpp | 859 +++++++ .../StructuredToMemrefPass.cpp | 416 ++++ .../TritonArithToLinalg/CMakeLists.txt | 22 + .../TritonArithToLinalg.cpp | 96 + .../TritonArithToLinalgPass.cpp | 227 ++ .../TritonToCoreDialects/CMakeLists.txt | 31 + .../TritonToCoreDialectsPass.cpp | 73 + .../Conversion/TritonToLinalg/CMakeLists.txt | 28 + .../TritonToLinalg/TritonToLinalg.cpp | 94 + .../TritonToLinalg/TritonToLinalgPass.cpp | 229 ++ .../TritonToStructured/CMakeLists.txt | 22 + 
.../TritonToStructuredPass.cpp | 344 +++ .../Tx81MemrefToLLVM/CMakeLists.txt | 17 + .../Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp | 335 +++ .../Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp | 88 + .../lib/Conversion/Tx81ToLLVM/CMakeLists.txt | 23 + .../Tx81ToLLVM/KernelArgBufferPass.cpp | 208 ++ .../lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp | 1326 +++++++++++ .../Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp | 80 + .../tsingmicro/lib/Dialect/CMakeLists.txt | 4 + .../lib/Dialect/MagicKernel/CMakeLists.txt | 10 + .../MagicKernel/IR/MagicKernelDialect.cpp | 33 + .../BufferizableOpInterfaceImpl.cpp | 122 + .../Dialect/TritonStructured/CMakeLists.txt | 1 + .../TritonStructured/IR/CMakeLists.txt | 11 + .../IR/TritonStructuredDialect.cpp | 22 + .../IR/TritonStructuredOps.cpp | 179 ++ .../Dialect/TritonTilingExt/CMakeLists.txt | 1 + .../IR/BufferizableOpInterfaceImpl.cpp | 134 ++ .../Dialect/TritonTilingExt/IR/CMakeLists.txt | 17 + .../lib/Dialect/TritonTilingExt/IR/CumSum.cpp | 112 + .../IR/TritonTilingExtDialect.cpp | 404 ++++ .../lib/Dialect/TsingMicroTx81/CMakeLists.txt | 10 + .../Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp | 30 + .../lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp | 10 + third_party/tsingmicro/name.conf | 1 + .../tsingmicro/python/triton_tsingmicro.cc | 7 + 240 files changed, 25165 insertions(+) create mode 100644 third_party/tsingmicro/CMakeLists.txt create mode 100644 third_party/tsingmicro/backend/compiler.py create mode 100644 third_party/tsingmicro/backend/cpu_driver.py create mode 100644 third_party/tsingmicro/backend/driver.cpp create mode 100644 third_party/tsingmicro/backend/driver.py create mode 100644 third_party/tsingmicro/crt/CMakeLists.txt create mode 100644 third_party/tsingmicro/crt/README.md create mode 100644 third_party/tsingmicro/crt/gcc_flash_smartl.ld create mode 100644 third_party/tsingmicro/crt/gcc_flash_xiaohui.ld create mode 100644 third_party/tsingmicro/crt/gcc_tx8_smarth.ld create mode 100644 third_party/tsingmicro/crt/include/Tx81/instr_adapter.h create mode 100644 third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h create mode 100644 third_party/tsingmicro/crt/include/Tx81/instr_def.h create mode 100644 third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h create mode 100644 third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h create mode 100644 third_party/tsingmicro/crt/include/Tx81/tx81.h create mode 100644 third_party/tsingmicro/crt/lib/Tx81/argmax.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/argmin.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/arith.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bilinear.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/bit2fp.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/common.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/concat.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/conv.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/cos.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/count.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/exp.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/explp.c create 
mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/gemm.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/img2col.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/ln.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/log2.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/lut16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/lut32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/mask_move.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/memset.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/mirror.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/pad.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/pow2.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/randgen.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/rdma.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/reduce.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/relu.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/rotate180.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/rotate270.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/rotate90.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/satrelu.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/sigmoid.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/sin.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/softplus.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tanh.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tensornorm.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c 
create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/transpose.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/wdma.c create mode 100644 third_party/tsingmicro/include/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp create mode 100644 third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h create mode 100644 third_party/tsingmicro/include/ExecutionEngine/Msan.h create mode 100644 third_party/tsingmicro/include/ExecutionEngine/version.txt create mode 100644 third_party/tsingmicro/include/magic-kernel-func/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel-func/Dialect/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td create mode 100644 third_party/tsingmicro/include/magic-kernel-instr/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel-instr/Dialect/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td create mode 100644 third_party/tsingmicro/include/magic-kernel/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.h create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.td create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.h create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.td create mode 100644 third_party/tsingmicro/include/magic-kernel/Dialect/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel/Dialect/IR/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td create mode 100644 third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h create mode 100644 third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td create mode 100644 third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td create mode 100644 third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td create mode 100644 third_party/tsingmicro/include/magic-kernel/Transforms/BufferizableOpInterfaceImpl.h create mode 100644 third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h create mode 100644 third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h create mode 100644 third_party/tsingmicro/include/triton-shared/Analysis/PtrAnalysis.h create mode 100644 third_party/tsingmicro/include/triton-shared/Analysis/UseAnalysis.h create mode 100644 third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h create mode 100644 
third_party/tsingmicro/include/triton-shared/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.td create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.td create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.td create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.td create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.h create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.td create mode 100644 third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td create mode 100644 third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.td create mode 100644 
third_party/tsingmicro/include/tsingmicro-tx81/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.td create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.td create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.td create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.td create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td create mode 100644 third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td create mode 100644 third_party/tsingmicro/lib/Analysis/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp create mode 100644 third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp create mode 100644 third_party/tsingmicro/lib/Analysis/PtrAnalysis.cpp create mode 100644 third_party/tsingmicro/lib/Analysis/UseAnalysis.cpp create mode 100644 third_party/tsingmicro/lib/AnalysisStructured/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp create mode 100644 third_party/tsingmicro/lib/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/LinalgToMK/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt create mode 100644 
third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/StructuredToMemref/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/TritonToLinalg/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/TritonToStructured/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Dialect/MagicKernel/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/TritonStructured/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Dialect/TritonStructured/IR/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/TritonTilingExt/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/BufferizableOpInterfaceImpl.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CumSum.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/TsingMicroTx81/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp create mode 100644 third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp create mode 100644 third_party/tsingmicro/name.conf create mode 100644 
third_party/tsingmicro/python/triton_tsingmicro.cc diff --git a/third_party/tsingmicro/CMakeLists.txt b/third_party/tsingmicro/CMakeLists.txt new file mode 100644 index 000000000..6643564da --- /dev/null +++ b/third_party/tsingmicro/CMakeLists.txt @@ -0,0 +1,18 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/crt/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + # FIXME: Unify the libraries for TsingMicro into fewer ones + add_triton_plugin(TritonTsingMicro ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_tsingmicro.cc + LINK_LIBS ZTCAnalysis ZTCAnalysisStructured MagicKernelIR + Tx81IR TritonTilingExtIR TritonStructuredIR TritonToCoreDialects + TritonToLinalg TritonToStructured StructuredToMemref LinalgToMagicKernel + TritonArithToLinalg CoreDialectsToMK Tx81ToLLVM Tx81MemrefToLLVM MKToTx81) + target_link_libraries(TritonTsingMicro PRIVATE Python3::Module pybind11::headers) +endif() +#if(TRITON_BUILD_UT) +# add_subdirectory(unittest) +#endif() +#add_subdirectory(test) diff --git a/third_party/tsingmicro/backend/compiler.py b/third_party/tsingmicro/backend/compiler.py new file mode 100644 index 000000000..919961eaf --- /dev/null +++ b/third_party/tsingmicro/backend/compiler.py @@ -0,0 +1,334 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes +from dataclasses import dataclass +from typing import Any, Dict, Tuple +from types import ModuleType +import hashlib +import tempfile +import os +import re +import shutil +import subprocess +import functools +from pathlib import Path + +def _get_ztc_opt_path() -> str: + path = os.getenv("ZTC_OPT_PATH", "") + if path == "": + raise Exception("ZTC_OPT_PATH is not set.") + return path + +def _get_vendor_runtime_path() -> str: + path = os.getenv("LIB_VENDOR_RUNTIME_PATH", "") + if path == "": + raise Exception("LIB_VENDOR_RUNTIME_PATH is not set.") + return path + +def _get_llvm_bin_path(bin_name: str) -> str: + path = os.getenv("LLVM_BINARY_DIR", "") + if path == "": + raise Exception("LLVM_BINARY_DIR is not set.") + return os.path.join(path, bin_name) + +# The riscv c header files and libraries path. +def _get_libc_root() -> str: + path = os.getenv("LIB_C_ROOT", "") + if path == "": + raise Exception("LIB_C_ROOT is not set.") + return path + +def _dump_ir_if_needed(files): + path = os.getenv("ZTC_DUMP_PATH", "") + if not path: + return + + os.makedirs(path, exist_ok=True) + for f in files: + shutil.copy(f, os.path.join(path, os.path.basename(f))) + + +def _ttir_to_coreir(mod): + # Get Triton-MLIR as string + ttir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "tt.mlir") + dst_path = os.path.join(tmpdir, "core.mlir") + Path(src_path).write_text(ttir_code) + ztc_opt_path = _get_ztc_opt_path() + _dump_ir_if_needed([src_path]) + subprocess.check_call([ztc_opt_path, src_path, + "--triton-to-core-dialects", + "--one-shot-bufferize", + #"--mlir-print-debuginfo", + "-o", + dst_path]) + return Path(dst_path).read_text() + + +def _optimize_coreir(coreir: str): + # We don't apply any optimizations now, but we can add passes if needed. 
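+    # If a cleanup run ever becomes necessary, it could be spliced in here by
+    # reusing the ztc-opt driver pattern above. A minimal sketch, assuming
+    # ztc-opt accepts the standard MLIR --canonicalize/--cse flags:
+    #
+    #   with tempfile.TemporaryDirectory() as tmpdir:
+    #       src = os.path.join(tmpdir, "core.mlir")
+    #       dst = os.path.join(tmpdir, "core.opt.mlir")
+    #       Path(src).write_text(coreir)
+    #       subprocess.check_call([_get_ztc_opt_path(), src,
+    #                              "--canonicalize", "--cse", "-o", dst])
+    #       coreir = Path(dst).read_text()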
+    return coreir
+
+
+def _coreir_to_mkir(mod):
+    # Get core dialects as string
+    coreir_code = str(mod)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        src_path = os.path.join(tmpdir, "core.mlir")
+        dst_path = os.path.join(tmpdir, "mk.mlir")
+        Path(src_path).write_text(coreir_code)
+        ztc_opt_path = _get_ztc_opt_path()
+        _dump_ir_if_needed([src_path])
+        subprocess.check_call([ztc_opt_path, src_path,
+            "--core-dialects-to-mk",
+            #"--mlir-print-debuginfo",
+            "-o", dst_path])
+        return Path(dst_path).read_text()
+
+
+def _optimize_mkir(mkir: str):
+    # We don't apply any optimizations now, but we can add passes if needed.
+    return mkir
+
+
+def _coreir_to_txir(mod):
+    # Get core dialects as string
+    coreir_code = str(mod)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        src_path = os.path.join(tmpdir, "core.mlir")
+        dst_path = os.path.join(tmpdir, "tx.mlir")
+        Path(src_path).write_text(coreir_code)
+        ztc_opt_path = _get_ztc_opt_path()
+        _dump_ir_if_needed([src_path])
+        subprocess.check_call([ztc_opt_path, src_path,
+            "--expand-strided-metadata",
+            "--mk-to-tx81",
+            # Convert affine.load to memref.load. This must run before
+            # tx81-to-llvm, since we will support SPM offsets on memref.load.
+            "--lower-affine",
+            # Clean up now-unused memref.subview/memref.reinterpret_cast ops.
+            "--cse",
+            #"--mlir-print-debuginfo",
+            "-o", dst_path])
+        return Path(dst_path).read_text()
+
+
+def _optimize_txir(txir: str):
+    # We don't apply any optimizations now, but we can add passes if needed.
+    return txir
+
+
+def _txir_to_llir(mod):
+    txir_code = str(mod)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        src_path = os.path.join(tmpdir, "tx.mlir")
+        llvmir_path = os.path.join(tmpdir, "ll.mlir")
+        llir_path = os.path.join(tmpdir, "ll.ir")
+        Path(src_path).write_text(txir_code)
+        ztc_opt_path = _get_ztc_opt_path()
+        _dump_ir_if_needed([src_path])
+        # Tx81 and core dialects to LLVM-MLIR
+        subprocess.check_call([ztc_opt_path, src_path,
+            "--tx81-memref-to-llvm",
+            "--tx81-to-llvm",
+            "--convert-scf-to-cf",
+            "--convert-math-to-llvm",
+            "--convert-func-to-llvm",
+            "--convert-cf-to-llvm",
+            # Use the custom tx81-memref-to-llvm pass above for now.
+            # "--finalize-memref-to-llvm",
+            # Must run last because of the arith.constant conversion.
+            "--convert-arith-to-llvm",
+            # Remove all unrealized casts created
+            "--reconcile-unrealized-casts",
+            "--canonicalize",
+            #"--mlir-print-debuginfo",
+            "-o", llvmir_path])
+        _dump_ir_if_needed([llvmir_path])
+
+        # LLVM-MLIR to LLVM-IR
+        mlir_translate_path = _get_llvm_bin_path("mlir-translate")
+        subprocess.check_call([mlir_translate_path, llvmir_path,
+            "--mlir-to-llvmir",
+            "-o", llir_path])
+
+        _dump_ir_if_needed([llir_path])
+        return Path(llir_path).read_text()
+
+
+def _mkir_to_llir(mkir: str):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        mkir_path = os.path.join(tmpdir, "mk.mlir")
+        llvmir_path = os.path.join(tmpdir, "ll.mlir")
+        llir_path = os.path.join(tmpdir, "ll.ir")
+        Path(mkir_path).write_text(mkir)
+        mlir_opt_path = _get_llvm_bin_path("mlir-opt")
+        # MagicKernel-MLIR to LLVM-MLIR
+        subprocess.check_call([mlir_opt_path, mkir_path,
+            "--convert-linalg-to-affine-loops",
+            # Note: eliminate-empty-tensors fails when a single kernel contains
+            # multiple func.return ops resulting from early returns.
+            # See python/examples/test_early_return.py for examples.
+            # We disable this pass for now since performance on CPU isn't the main
+            # focus at the moment.
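+            # For illustration, a kernel of the problematic shape (a sketch,
+            # not taken from the test file):
+            #   @triton.jit
+            #   def kern(ptr, n):
+            #       if tl.program_id(0) >= n:
+            #           return  # early return -> a second func.return op
+            #       tl.store(ptr + tl.program_id(0), 0)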
+ # "--eliminate-empty-tensors", + "--empty-tensor-to-alloc-tensor", + "--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--lower-affine", + "--convert-linalg-to-loops", + "--expand-strided-metadata", + "--convert-scf-to-cf", + "--convert-arith-to-llvm", + "--convert-math-to-llvm", + "--convert-complex-to-llvm", + "--convert-vector-to-llvm", + "--convert-index-to-llvm", + "--memref-expand", + "--finalize-memref-to-llvm", + "--convert-func-to-llvm", + "--convert-cf-to-llvm", + # Lowering memrefs creates more affine.apply ops. + # Lowering these affine ops again creates further arith ops, + # so we have to run these two passes again here. + "--lower-affine", + "--convert-arith-to-llvm", + # Remove all unrealized casts created + "--canonicalize", + "--reconcile-unrealized-casts", + "--mlir-print-debuginfo", + "-o", + llvmir_path]) + + # LLVM-MLIR to LLVM-IR + mlir_translate_path = _get_llvm_bin_path("mlir-translate") + subprocess.check_call([mlir_translate_path, llvmir_path, + "--mlir-to-llvmir", + "-o", + llir_path]) + _dump_ir_if_needed([mkir_path, llvmir_path, llir_path]) + return Path(llir_path).read_text() + + +def _optimize_llir(llir: str): + # We don't apply any optimizations now, but we can add passes if needed. + return llir + + +def _llir_to_bin(llir: str, metadata): + pattern = r"define void @(\w+)\(.+" + matches = re.findall(pattern, llir) + assert len(matches) == 1 + metadata["name"] = matches[0] + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ll") + # FIXME: Hardcoded path + #dst_path = os.path.join(tmpdir, "kernel.so") + dst_path = "/tmp/kernel.o" + Path(src_path).write_text(llir) + clang_path = _get_llvm_bin_path("clang++") + subprocess.check_call([clang_path, src_path, + "-O2", + "-c", + "-fPIC", + "--target=riscv64-unknown-elf", + "-march=rv64imafdc", + "-o", + dst_path]) + + _dump_ir_if_needed([dst_path]) + with open(dst_path, 'rb') as f: + so = f.read() + return so + + + +@dataclass(frozen=True) +class CPUOptions: + debug: bool = False + arch: str = None + num_warps: int = 0 + num_ctas: int = 0 + num_stages: int = 1 + enable_warp_specialization: bool = False + enable_fp_fusion: bool = False + extern_libs = None + cluster_dims: tuple = (1, 1, 1) + shared: bool = False + allow_fp8e4nv: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) + + def __post_init__(self): + pass + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + + +class CPUBackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'cpu' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.binary_ext = "so" + + def parse_options(self, opts) -> Any: + args = {'arch': self.target.arch} + args.update({k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts}) + return CPUOptions(**args) + + def get_codegen_implementation(self): + codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)} + return codegen_fns + + def pack_metadata(self, metadata): + # Note: We actually don't need any of these except for the name which is + # used in the launch function in driver.py. 
Putting these in anyway to stay
+        # consistent with the other backends.
+        return (
+            metadata.num_warps,
+            metadata.num_ctas,
+            metadata.shared,
+            metadata.cluster_dims[0],
+            metadata.cluster_dims[1],
+            metadata.cluster_dims[2],
+            metadata.name
+        )
+
+    # Unlike the NVIDIA and AMD backends, our compilation pipeline is not
+    # driven from Python, so there are no dialects to load. See `ztc.cc`.
+    def load_dialects(self, ctx):
+        return
+
+    @staticmethod
+    def make_ttir(mod, metadata, opt):
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+        passes.common.add_inliner(pm)
+        passes.ttir.add_combine(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttir.add_reorder_broadcast(pm)
+        passes.common.add_cse(pm)
+        passes.common.add_licm(pm)
+        passes.common.add_symbol_dce(pm)
+        pm.run(mod)
+        return mod
+
+    def add_stages(self, stages, options):
+        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
+        stages["coreir"] = lambda src, metadata: _optimize_coreir(_ttir_to_coreir(src))
+        # stages["mkir"] = lambda src, metadata: _optimize_mkir(_coreir_to_mkir(src))
+        stages["txir"] = lambda src, metadata: _optimize_txir(_coreir_to_txir(src))
+        stages["llir"] = lambda src, metadata: _optimize_llir(_txir_to_llir(src))
+        stages["so"] = lambda src, metadata: _llir_to_bin(src, metadata)
+
+    @functools.lru_cache()
+    def hash(self):
+        return self.target
+
+    # The CPU backend does not use any extra Python modules; return an empty dictionary.
+    def get_module_map(self) -> Dict[str, ModuleType]:
+        return {}
diff --git a/third_party/tsingmicro/backend/cpu_driver.py b/third_party/tsingmicro/backend/cpu_driver.py
new file mode 100644
index 000000000..b52b6363e
--- /dev/null
+++ b/third_party/tsingmicro/backend/cpu_driver.py
@@ -0,0 +1,389 @@
+import hashlib
+import importlib.util
+import os
+import subprocess
+import sysconfig
+import tempfile
+
+from pathlib import Path
+
+from triton.runtime.cache import get_cache_manager
+from triton.backends.driver import DriverBase
+from triton.backends.compiler import GPUTarget
+
+# The RISC-V compiler toolchain path.
+def _get_llvm_bin_path() -> str:
+    path = os.getenv("LLVM_BINARY_DIR", "")
+    if path == "":
+        raise Exception("LLVM_BINARY_DIR is not set.")
+    return path
+
+# The RISC-V C header files and libraries path.
+def _get_libc_root() -> str:
+    path = os.getenv("LIB_C_ROOT", "")
+    if path == "":
+        raise Exception("LIB_C_ROOT is not set.")
+    return path
+
+
+# -------------------- Launcher ----------------------------
+def _ty_to_cpp(ty):
+    if ty[0] == '*':
+        return "void*"
+    return {
+        "i1": "int32_t",
+        "i8": "int8_t",
+        "i16": "int16_t",
+        "i32": "int32_t",
+        "i64": "int64_t",
+        "u1": "uint32_t",
+        "u8": "uint8_t",
+        "u16": "uint16_t",
+        "u32": "uint32_t",
+        "u64": "uint64_t",
+        "fp16": "float",
+        "bf16": "float",
+        "fp32": "float",
+        "f32": "float",
+        "fp64": "double",
+    }[ty]
+
+def _extracted_type(ty):
+    if ty[0] == '*':
+        return "PyObject*"
+    return _ty_to_cpp(ty)
+
+def _format_of(ty):
+    return {
+        "PyObject*": "O",
+        "float": "f",
+        "double": "d",
+        "long": "l",
+        "int8_t": "b",
+        "int16_t": "h",
+        "int32_t": "i",
+        "int64_t": "l",
+        "uint8_t": "B",
+        "uint16_t": "H",
+        "uint32_t": "I",
+        "uint64_t": "K",
+    }[ty]
+
+def _generate_launcher(constants, signature, kernel_name):
+    arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+    args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()])
+    format = "iiiOOOO" + args_format
+    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
+
+    kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants)
+    kernel_arg_decls += ', ' if kernel_arg_decls else ''
+
+    kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants)
+    kernel_parameters += ', ' if kernel_parameters else ''
+
+    return f"""
+#include <Python.h>
+#include <cstdint>
+#include "ExecutionEngine/CRunnerUtils.h"
+#include "ExecutionEngine/CRunnerUtils.cpp"
+
+extern "C" {{
+  // Pointer type (=Memref) becomes int64_t + MemRef struct
+  // FIXME: understand what this int64_t is used for.
+  void {kernel_name}({kernel_arg_decls}
+                     int, int, int, int, int, int);
+}}
+
+static void _launch(int gridX, int gridY, int gridZ, {arg_decls}) {{
+  if (gridX*gridY*gridZ > 0) {{
+    // Cast "function" to the real function type.
+    for(int x = 0; x < gridX; x++) {{
+      for(int y = 0; y < gridY; y++) {{
+        for(int z = 0; z < gridZ; z++) {{
+          // Use some random type "char" here.
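+          // Each "*" argument is passed to the kernel as an (int64_t, void*)
+          // pair, where the void* points at a rank-0 StridedMemRefType<char, 0>
+          // descriptor (basePtr, data, offset; see CRunnerUtils.h) and the
+          // int64_t is passed as 0 (it appears unused; see the FIXME above).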
+          {' '.join(f'StridedMemRefType<char, 0> ptr_arg{i} = {{static_cast<char *>(arg{i}), static_cast<char *>(arg{i}), 0}};' for i, ty in signature.items() if i not in constants and ty[0] == "*")}
+          {kernel_name}({kernel_parameters}
+                        gridX, gridY, gridZ, x, y, z);
+        }}
+      }}
+    }}
+  }}
+}}
+
+typedef struct _DevicePtrInfo {{
+  void *dev_ptr;
+  bool valid;
+}} DevicePtrInfo;
+
+static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
+  DevicePtrInfo ptr_info;
+  ptr_info.dev_ptr = 0;
+  ptr_info.valid = true;
+  if (PyLong_Check(obj)) {{
+    ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(obj));
+    return ptr_info;
+  }}
+  if (obj == Py_None) {{
+    // valid nullptr
+    return ptr_info;
+  }}
+  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
+  if(ptr){{
+    PyObject *empty_tuple = PyTuple_New(0);
+    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
+    Py_DECREF(empty_tuple);
+    Py_DECREF(ptr);
+    if (!PyLong_Check(ret)) {{
+      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
+      ptr_info.valid = false;
+      return ptr_info;
+    }}
+    ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(ret));
+    if(!ptr_info.dev_ptr)
+      return ptr_info;
+    Py_DECREF(ret); // Thanks ChatGPT!
+    return ptr_info;
+  }}
+  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
+  return ptr_info;
+}}
+
+static PyObject* launch(PyObject* self, PyObject* args) {{
+  int gridX, gridY, gridZ;
+  PyObject *launch_enter_hook = NULL;
+  PyObject *launch_exit_hook = NULL;
+  PyObject *kernel_metadata = NULL;
+  PyObject *launch_metadata = NULL;
+  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
+                       &kernel_metadata, &launch_metadata,
+                       &launch_enter_hook, &launch_exit_hook {args_list})) {{
+    return NULL;
+  }}
+
+  // [CPULauncher-specific]: We don't need the metadata below, but we keep it
+  // here anyway to be consistent with the other backends.
+  // This will make updating the driver easier in the future.
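+  // For reference, kernel_metadata is the tuple built by
+  // CPUBackend.pack_metadata in compiler.py:
+  //   (num_warps, num_ctas, shared, clusterDimX, clusterDimY, clusterDimZ, name)
+  // and only the trailing name is actually consumed by this backend.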
+ + // int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + // if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + // PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + // return NULL; + // }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + if (PyErr_Occurred()) {{ + return NULL; + }} + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__ztc_ref_cpu_kernel_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___ztc_ref_cpu_kernel_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + + +def compile_module(launcher_src, kernel_placeholder_name): + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + py_lib_dir = sysconfig.get_config_var("LIBDIR") + py_version = sysconfig.get_config_var("LDVERSION") + py_lib = '{name}{py_version}'.format(name="python", py_version=py_version) + cpu_backend_path = Path(__file__).resolve().parent + clang = os.path.join(_get_llvm_bin_path(), "clang++") + libc_inc = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "include") + libc_lib = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "lib") + include_dir = os.path.join(cpu_backend_path, "include") + + def launch( + gridX, gridY, gridZ, stream, cu_function, + kernel_metadata, launch_metadata, + launch_enter_hook, launch_exit_hook, *args): + # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries. + # Let's compile a kernel every time. + # The cu_function parameter actually contains our assembly source code. + # See CPUUtils.load_binary method. 
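+        # For context, a hedged sketch of exercising this path end to end
+        # (assumes LLVM_BINARY_DIR / LIB_C_ROOT etc. are set; `add` is a
+        # hypothetical @triton.jit kernel):
+        #   import triton
+        #   from cpu_driver import CPUDriver
+        #   triton.runtime.driver.set_active(CPUDriver())
+        #   add[(1, 1, 1)](x_ptr, y_ptr, out_ptr, n)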
+        asm_src = cu_function
+        kernel_name = kernel_metadata[6]  # see pack_metadata in compiler.py
+        src = launcher_src.replace(kernel_placeholder_name, kernel_name)
+
+        key = hashlib.md5(src.encode("utf-8") + asm_src).hexdigest()
+        cache = get_cache_manager(key)
+        name = "__ztc_ref_cpu_kernel_launcher"
+        filename = f"{name}.so"
+        cache_path = cache.get_file(filename)
+
+        if cache_path is None:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                asm_src_path = os.path.join(tmpdir, "kernel.s")
+                launcher_src_path = os.path.join(tmpdir, "main.cxx")
+                so_path = os.path.join(tmpdir, "kernel.so")
+                Path(asm_src_path).write_bytes(asm_src)
+                Path(launcher_src_path).write_text(src)
+                # Compile it together.
+                subprocess.check_call([
+                    clang, "-std=c++17", "--target=riscv64-unknown-elf",
+                    launcher_src_path, asm_src_path, f"-I{libc_inc}",
+                    f"-I{py_include_dir}", f"-I{include_dir}", f"-I{libc_lib}",
+                    f"-L{py_lib_dir}",
+                    "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path
+                ])
+
+                with open(so_path, "rb") as f:
+                    cache_path = cache.put(f.read(), filename, binary=True)
+
+        # Load and launch the compiled kernel.
+        spec = importlib.util.spec_from_file_location(name, cache_path)
+        mod = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(mod)
+        return mod.launch(gridX, gridY, gridZ,
+                          kernel_metadata, launch_metadata,
+                          launch_enter_hook, launch_exit_hook,
+                          *args)
+
+    return launch
+
+
+class CPULauncher(object):
+
+    def __init__(self, src, metadata):
+        kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER"
+
+        constants = src.constants if hasattr(src, "constants") else dict()
+        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
+        constants = {cst_key(key): value for key, value in constants.items()}
+        signature = {cst_key(key): value for key, value in src.signature.items()}
+        launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name)
+        # KERNEL_NAME_PLACEHOLDER is later replaced with the real kernel name
+        # inside the launch function returned by compile_module.
+        self.launch = compile_module(launcher_src, kernel_placeholder_name)
+
+    def __call__(self, *args, **kwargs):
+        self.launch(*args, **kwargs)
+
+
+class CPUUtils(object):
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(CPUUtils, cls).__new__(cls)
+        return cls.instance
+
+    # Note:
+    # The nvidia and amd backends have their corresponding driver.c file that
+    # exposes get_device_properties and load_binary using Python bindings.
+    # (see third_party/nvidia/backend/driver.c)
+    # These methods are then used in compiler.py to initialize handles before
+    # running the triton kernels.
+    # Since we recompile the kernel every time (see compile_module above),
+    # and the metadata generated by these functions isn't applicable to the CPU
+    # backend, we just define the same functions with dummy implementations.
+    @staticmethod
+    def get_device_properties(device):
+        return {
+            "max_shared_mem": 2 ** 20,
+            "multiprocessor_count": None,
+            "sm_clock_rate": None,
+            "mem_clock_rate": None,
+            "mem_bus_width": None
+        }
+
+    # Important note:
+    # Since we cannot easily pass function pointers around, we pass along the
+    # assembly source code so that compile_module above can recompile the
+    # module every time.
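+    # A hedged sketch of the resulting round trip, using the names from this
+    # file:
+    #   _, fn, _, _ = CPUUtils().load_binary(name, kernel_asm, 0, "cpu")
+    #   # fn is simply kernel_asm; CPULauncher's launch (see compile_module)
+    #   # compiles it together with the generated C++ launcher on demand.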
+    @staticmethod
+    def load_binary(name, kernel_asm, shared, device):
+        return (
+            None,        # module
+            kernel_asm,  # function
+            None,        # n_regs
+            None         # n_spills
+        )
+
+
+class CPUDriver(DriverBase):
+
+    def __init__(self):
+        super().__init__()
+        self.utils = CPUUtils()
+        self.launcher_cls = CPULauncher
+        self.binary_ext = "cpuasm"
+
+    # The CPU driver won't be automatically chosen unless explicitly set through
+    # triton.runtime.driver.set_active(CPUDriver())
+    @staticmethod
+    def is_active():
+        return False
+
+    def get_device_capability(self):
+        return ("cpu", 0)
+
+    def get_current_stream(self, device):
+        return None
+
+    def get_current_device(self):
+        # CPU doesn't have a device to return. Return something.
+        return "cpu"
+
+    def set_current_device(self, device):
+        # CPU doesn't have a device to set
+        assert device == "cpu"
+        return
+
+    def get_current_target(self):
+        return GPUTarget("cpu", 0, 0)
+
+    def assemble_tensormap_to_arg(self, tensormaps_info, args):
+        return args
diff --git a/third_party/tsingmicro/backend/driver.cpp b/third_party/tsingmicro/backend/driver.cpp
new file mode 100644
index 000000000..8658ecf86
--- /dev/null
+++ b/third_party/tsingmicro/backend/driver.cpp
@@ -0,0 +1,624 @@
+//===---------------------------- driver.cpp ------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Tx81 platform device side runtime interface for Python.
+//
+//===----------------------------------------------------------------------===//
+#include <memory>
+#include <string>
+#include <vector>
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include "Tx81/runtime/hrt_common.h" // DTYPE / hrt_get_dtype_size (assumed path)
+
+struct Kernel_Param // Triton kernel arguments
+{
+  uint32_t gridX;
+  uint32_t gridY;
+  uint32_t gridZ;
+  // TODO...
+};
+
+struct Kernel_Head
+{
+  uint32_t param_type;
+  uint32_t param_num;
+  uint32_t param_addr;
+  uint32_t xxxxx;
+};
+
+// Raises a Python exception and returns false if code is not RET_SUCCESS.
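+// A usage sketch (the TSM_CHECK_AND_RETURN_NULL macro defined below wraps
+// exactly this pattern; dev/addr/size are hypothetical locals):
+//   if (!tsmAssert(TsmDeviceMalloc(dev, addr, size), __FILE__, __LINE__))
+//     return NULL;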
+static bool tsmAssert(TSM_RETCODE code, const char *file, int line) {
+  if (code == RET_SUCCESS)
+    return true;
+
+  const char *prefix = "Triton Error [TX81]: ";
+  const char *str;
+
+  // Map error codes to strings
+  switch (code) {
+  case RET_ERROR:
+    str = "General error";
+    break;
+  case RET_PARAM1_ERROR:
+  case RET_PARAM2_ERROR:
+  case RET_PARAM3_ERROR:
+    str = "Parameter error";
+    break;
+  case RET_DEVICE_OFFLINE:
+    str = "Device offline";
+    break;
+  case RET_DEVICE_NOMEM:
+    str = "Device out of memory";
+    break;
+  case RET_DEVICE_IN_IDLE:
+    str = "Device in idle state";
+    break;
+  case RET_DEVICE_IN_ATTACH:
+    str = "Device already attached";
+    break;
+  case RET_DEVICE_ATTACH_SUCCESS:
+    str = "Device attach success";
+    break;
+  case RET_DEVICE_ATTACH_READY:
+    str = "Device attach ready";
+    break;
+  case RET_DEVICE_LOSE_CONNECT:
+    str = "Device connection lost";
+    break;
+  case RET_ENV_CLEAN_UP:
+    str = "Environment cleanup required";
+    break;
+  default:
+    str = "Unknown error";
+  }
+
+  char err[1024] = {0};
+  strcat(err, prefix);
+  strcat(err, str);
+  PyGILState_STATE gil_state;
+  gil_state = PyGILState_Ensure();
+  PyErr_SetString(PyExc_RuntimeError, err);
+  PyGILState_Release(gil_state);
+  return false;
+}
+
+
+static void prepare_input(std::vector<TsmDevice *> devices, uint32_t dev_index,
+                          std::shared_ptr chip_info)
+{
+  for (uint32_t i = 0; i < chip_info->input_num; ++i) {
+    chip_info->input_dev_addr.push_back(0);
+    if (TsmDeviceMalloc(devices[dev_index], chip_info->input_dev_addr[i],
+                        chip_info->input_size[i]) != RET_SUCCESS) {
+      printf("[Chip id %u] Input%d, DeviceMalloc failed!\n", devices[dev_index]->chip_id, i);
+      TsmResetDevice(devices[dev_index]);
+      return;
+    }
+
+    if (TsmMemcpyH2D((TsmDevicePtr)chip_info->input_dev_addr[i],
+                     (void *)chip_info->input_host_addr[i],
+                     chip_info->input_size[i]) != RET_SUCCESS) {
+      printf("[Chip id %u] Input%d, MemcpyH2D failed!\n", devices[dev_index]->chip_id, i);
+      TsmResetDevice(devices[dev_index]);
+      return;
+    }
+  }
+}
+
+static void prepare_output(std::vector<TsmDevice *> devices, uint32_t dev_index,
+                           std::shared_ptr chip_info) {
+  for (size_t i = 0; i < chip_info->output_num; ++i) {
+    chip_info->output_dev_addr.push_back(0);
+    printf("[Chip id %u] output[%lu] data(size: %lu)\n",
+           devices[dev_index]->chip_id, i, chip_info->output_size[i]);
+
+    if (TsmDeviceMalloc(devices[dev_index], chip_info->output_dev_addr[i],
+                        chip_info->output_size[i]) != RET_SUCCESS) {
+      printf("[Chip id %u] output[%lu], DeviceMalloc failed!\n",
+             devices[dev_index]->chip_id, i);
+      TsmResetDevice(devices[dev_index]);
+      return;
+    }
+  }
+}
+
+TSM_RETCODE kernel_result_process(std::vector<TsmDevice *> devices, uint32_t dev_index,
+                                  std::shared_ptr hostboot,
+                                  std::shared_ptr chip_info,
+                                  TsmDevicePtr bootpm_dev, std::string case_dir) {
+  for (size_t i = 0; i < chip_info->output_num; ++i) {
+    // Dynamic shapes: the real output size must be recomputed after the run.
+    if (TsmMemcpyD2H(hostboot->get_bootpmbuffer(), bootpm_dev,
+                     hostboot->get_maxlen()) != RET_SUCCESS) {
+      return RET_ERROR;
+    }
+
+    auto out_tensor = hostboot->get_dev_output_tensor_after_run(i);
+    chip_info->output[i]->dim = out_tensor->dim;
+    std::memcpy(chip_info->output[i]->shape, out_tensor->shape, sizeof(out_tensor->shape));
+    chip_info->output_size[i] = hrt_get_dtype_size((DTYPE)chip_info->output[i]->dtype);
+    for (uint32_t j = 0; j < out_tensor->dim; ++j) {
+      if (out_tensor->shape[j] > 0) {
+        chip_info->output_size[i] *= out_tensor->shape[j];
+      }
+    }
+
+    TsmHostPtr output_host_addr = (TsmHostPtr)malloc(chip_info->output_size[i]);
+    if (chip_info->output_size[i] > 0) {
+      if
(TsmMemcpyD2H((void*)output_host_addr, chip_info->output_dev_addr[i], + chip_info->output_size[i]) != RET_SUCCESS) { + return RET_ERROR; + } + } + + printf("[Chip id %u] output_dev_addr=%ld\n", devices[dev_index]->chip_id, + chip_info->output_dev_addr[i]); + + // TODO: Processing output +#if 0 + std::string file_path = case_dir + "/chip" + std::to_string(dev_index) + + "/agent/data/out" + std::to_string(i) + "_riscv.bin"; + saveDataToFile(file_path, output_host_addr, chip_info->output_size[i]); +#endif + + if (output_host_addr != 0) { + free((void *)output_host_addr); + } + } + return RET_SUCCESS; +} + +TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) { + if (bootpm_dev != 0) { + printf("[Chip id %u] bootpm dev addr: 0x%lx \n", chip_id, bootpm_dev); + if (TsmDeviceFree(bootpm_dev) != RET_SUCCESS) { + return RET_ERROR; + } + bootpm_dev = 0; + } + + return RET_SUCCESS; +} + +static void setHostBoot(std::shared_ptr &chip_info, + std::shared_ptr &hostboot) { + if (chip_info == nullptr) { + printf("chip_info is null.\n"); + return; + } + + if (hostboot == nullptr) { + printf("hostboot is null.\n"); + return; + } + + for (size_t i = 0; i < chip_info->input_dev_addr.size(); ++i) { + hostboot->set_dev_input(i, chip_info->input_dev_addr[i], chip_info->input_size[i]); + hostboot->set_dev_input_tensor(i, chip_info->input[i]); + } + + for (size_t i = 0; i < chip_info->output_dev_addr.size(); ++i) { + hostboot->set_dev_output(i, chip_info->output_dev_addr[i], chip_info->output_size[i]); + } + + for (size_t i = 0; i < chip_info->param_num; ++i) { + hostboot->set_dev_param(i, chip_info->param_dev_addr[i], chip_info->param_size[i]); + } + + return; +} + + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define TSM_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!tsmAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. 
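+// Py_BEGIN_ALLOW_THREADS opens a block that releases the GIL and declares
+// `PyThreadState *_save;`, so an early return from inside that block must
+// first call PyEval_RestoreThread(_save) to reacquire the GIL before raising.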
+#define TSM_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!tsmAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Global state for Tx81 devices +static std::vector g_tx81_devices; +static bool g_runtime_initialized = false; + +// Initialize the Tx81 runtime if not already initialized +static bool init_tx81_runtime_if_needed() { + if (g_runtime_initialized) { + return true; + } + + // Initialize the Tx81 runtime + if (TsmInitRuntime() != RET_SUCCESS) { + PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 runtime"); + return false; + } + + // Get device count + uint32_t device_num = 0; + if (TsmGetDeviceNum(device_num) != RET_SUCCESS || device_num == 0) { + PyErr_SetString(PyExc_RuntimeError, "Failed to get Tx81 device count or no devices found"); + TsmDeInitRuntime(); + return false; + } + + // Set up devices - for simplicity, we're using a 1x1 configuration + uint32_t first_phy_id = 0; + uint32_t card_x = 1; + uint32_t card_y = 1; + + if (TsmSetDevice(first_phy_id, card_x, card_y, g_tx81_devices) != RET_SUCCESS) { + PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices"); + TsmDeInitRuntime(); + return false; + } + + g_runtime_initialized = true; + return true; +} + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { +#if 0 + // FIXME: Extracting device_id + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + + // Initialize the runtime if needed + if (!init_tx81_runtime_if_needed()) { + return NULL; + } + + // Check device ID is valid + if (device_id < 0 || (size_t)device_id >= g_tx81_devices.size()) { + PyErr_SetString(PyExc_ValueError, "Invalid device ID"); + return NULL; + } + + // Get device handle + TsmDevice* device = g_tx81_devices[device_id]; + + // Get device information + TsmDeviceInfo info; + memset(&info, 0, sizeof(TsmDeviceInfo)); + TSM_CHECK_AND_RETURN_NULL(TsmGetDeviceInfo(&info)); +#endif + // Extract device properties + // Note: We're mapping Tx81 properties to fields expected by Triton + int max_shared_mem = 1024 * 1024 * 4; // Default 4MB + //int multiprocessor_count = device->tile_num; + int multiprocessor_count = 1; + int sm_clock_rate = 1000; // Placeholder + int mem_clock_rate = 2000; // Placeholder + int mem_bus_width = 256; // Placeholder + +#if 0 + // For the specified device, get more detailed info + if (device_id < (int)info.card_num) { + CardComputeInfo& card_info = info.card_compute_info[device_id]; + multiprocessor_count = card_info.all_tile_num; + } +#endif + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", + "max_shared_mem", max_shared_mem, + "multiprocessor_count", multiprocessor_count, + "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, + "mem_bus_width", mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; +#if 0 + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + + // Initialize the runtime if needed + if (!init_tx81_runtime_if_needed()) { + return NULL; + } + + // Check device ID is valid + if (device < 0 || (size_t)device >= g_tx81_devices.size()) { + PyErr_SetString(PyExc_ValueError, "Invalid device ID"); + return NULL; + } + + TsmDevice* tx81_device = g_tx81_devices[device]; + + // First, we need to write binary data to a temporary file + char temp_path[256]; + sprintf(temp_path, 
"/tmp/triton_tx81_kernel_XXXXXX"); + int fd = mkstemp(temp_path); + if (fd == -1) { + PyErr_SetString(PyExc_RuntimeError, "Failed to create temporary file"); + return NULL; + } + + // Write the kernel data to the temporary file + if (write(fd, data, data_size) != data_size) { + close(fd); + unlink(temp_path); + PyErr_SetString(PyExc_RuntimeError, + "Failed to write kernel data to temporary file"); + return NULL; + } + close(fd); + + // Create a model structure, the compiled kernel.so is specified via case_dir + // and the name of the entry function is specified via case_name. + TsmModel *model = new TsmModel(); + model->case_dir = std::string(temp_path); + model->case_name = std::string(name); + + // Set compile options + CompileOption compl_option = {}; + compl_option.comp_enable = 0; // Use precompiled kernel.so instead + compl_option.chip_x = 1; + compl_option.chip_y = 1; + compl_option.check_enable = true; + compl_option.enable_kcore_bin = 1; + compl_option.enable_kcore_so = 1; + + std::vector devices = {tx81_device}; + + // Not really compile the kernel, as kernel is already compiled, so this + // runtime API only configs the data structure of device firmware and the + // information of the program and data that runs on it. + Py_BEGIN_ALLOW_THREADS; + TSM_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + TsmCompileMultiGraph(devices, *model, "", compl_option)); + Py_END_ALLOW_THREADS; + + // For Tx81, we'll use a simpler model than CUDA + // We return a pointer to the TsmModel, which is analogous to CUmodule + // For the function pointer, we'll use model_id+0, which will be interpreted + // in the launcher code + // n_regs and n_spills are placeholders for now + int32_t n_regs = 256; // Default/placeholder value + int32_t n_spills = 0; // Default/placeholder value + + // Clean up the temporary file + unlink(temp_path); +#endif + + int32_t n_regs = 256; + int32_t n_spills = 0; + // Return values to Python including module, function, n_regs, n_spills + return Py_BuildValue("(KKii)", "module {}", "void @add_kernel() {}", n_regs, n_spills); +} + + + +static PyObject *launch(PyObject *self, PyObject* args) { + std::vector devices; + // TODO:通过参数传递获取device信息 + + + // 需要的输入信息: devices, case_dir(按固定路径存放的kernelso), input_host_addr/input_size/input_num, + // output_host_addr/output_size/output_num, param信息(如果有权重) + + + TsmModel *new_model = new TsmModel(); // 设备相关参数已在dev中 + std::string option = "-O2"; + CompileOption compl_option = {}; + compl_option.comp_enable = 0; + compl_option.chip_x = 1; //单卡 + compl_option.chip_y = 1; + compl_option.check_enable = true; + compl_option.enable_kcore_bin = 1; + compl_option.enable_kcore_so = 1; + new_model->case_dir = "/tmp/todo"; // 参数传入, kernelso路径,同streambin/kcorebin文件夹路径 + + if (TsmCompileMultiGraph(devices, *new_model, option, compl_option) != RET_SUCCESS) { + for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) { + if (TsmResetDevice(devices[dev_index]) != RET_SUCCESS) { + printf("[Chip id %u] tx_engine: tx_reset, failed!\n", dev_index); + } else { + printf("[Chip id %u] tx_engine: tx_reset, success!\n", dev_index); + } + } + printf("TsmCompile failed.\n"); + return NULL; + } + + std::vector kmodel_vec = {new_model}; + + uint32_t input_num = 2; // TODO:根据kernel参数填写 + uint32_t output_num = 1; // TODO:根据kernel参数填写 + uint32_t param_num = 0; // 权重数 + std::shared_ptr hostboot = std::make_shared(input_num, output_num, param_num); + + std::shared_ptr chip_info; + // 填充chipinfo信息 + chip_info->input_num = input_num; + chip_info->output_num = output_num; 
+  chip_info->param_num = param_num;
+  chip_info->imm_size = 0; // Scratch size, set to 0 for now; depends on the actual operator.
+  // chip_info->tile_num = 16; // unused
+  // chip_info->tile_x = 4;    // unused
+  // chip_info->tile_y = 4;    // unused
+  for (uint32_t i = 0; i < chip_info->input_num; ++i) {
+    chip_info->input_size[i] = 6; // TODO: fill in the actual input size
+    chip_info->input_host_addr = std::vector{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; // TODO: fill in the actual input addresses
+  }
+
+  for (uint32_t i = 0; i < chip_info->output_num; ++i) {
+    chip_info->output_size[i] = 1; // TODO: fill in the actual output size
+    chip_info->output_host_addr = std::vector{0x0}; // TODO: fill in the actual output addresses
+  }
+
+  // for (uint32_t i = 0; i < chip_info->param_num; ++i) {
+  //   chip_info->param_size[i] = 0; // TODO: fill in the actual weight size
+  //   chip_info->param_host_addr = 0x0;
+  // }
+
+  // prepare data / load kernel / run / unload kernel / get output data / release memory
+  for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {
+    // input prepare
+    prepare_input(devices, dev_index, chip_info);
+    // output prepare
+    prepare_output(devices, dev_index, chip_info);
+
+    uint32_t chip_id = devices[dev_index]->chip_id;
+    TsmSetMonitorInfo(devices[dev_index]);
+
+    // load kernel
+    char module_symbol[] = "main_kernel";
+    TsmLoadKernel(devices[dev_index], kmodel_vec, module_symbol);
+    printf("TsmLoadKernel finish!...\n");
+
+    printf("[Chip id %u] Set boot-params...\n", chip_id);
+    size_t dyn_mod_size = sizeof(DynMods) + sizeof(DynModule);
+    TsmDevicePtr dev_dyn_mods_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_dyn_mods_ptr, dyn_mod_size) != RET_SUCCESS) {
+      return NULL;
+    }
+    TsmDevicePtr dev_tlv_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_tlv_ptr, sizeof(DynTLV_DynMods)) != RET_SUCCESS) {
+      return NULL;
+    }
+
+    TsmDevicePtr dev_kernel_head_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_head_ptr, sizeof(Kernel_Head)) != RET_SUCCESS) {
+      return NULL;
+    }
+    TsmDevicePtr dev_kernel_param_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_param_ptr, sizeof(Kernel_Param)) != RET_SUCCESS) {
+      return NULL;
+    }
+
+    Kernel_Head *host_kernel_head_ptr = (Kernel_Head *)malloc(sizeof(Kernel_Head));
+    Kernel_Param *host_kernel_param_ptr = (Kernel_Param *)malloc(sizeof(Kernel_Param));
+
+    host_kernel_head_ptr->param_type = 1;
+    host_kernel_head_ptr->param_num = 1; // Number of kernel arguments
+    host_kernel_head_ptr->param_addr = dev_kernel_param_ptr; // Address of the parameters the kernel uses
+
+    // TODO: Set up the triton kernel arguments
+    host_kernel_param_ptr->gridX = 512;
+    host_kernel_param_ptr->gridY = 512;
+    host_kernel_param_ptr->gridZ = 512;
+
+    TsmMemcpyH2D(dev_kernel_head_ptr, host_kernel_head_ptr, sizeof(Kernel_Head));
+    TsmMemcpyH2D(dev_kernel_param_ptr, host_kernel_param_ptr, sizeof(Kernel_Param));
+
+    free(host_kernel_head_ptr);
+    free(host_kernel_param_ptr);
+
+    // TODO: No such API
+    setHostBoot(chip_info, hostboot);
+    set_multi_graph(kmodel_vec[0], hostboot, dev_dyn_mods_ptr, dev_tlv_ptr, dev_kernel_head_ptr);
+
+    TsmDevicePtr bootpm_dev;
+    if (TsmDeviceMalloc(devices[dev_index], bootpm_dev, hostboot->get_maxlen()) != RET_SUCCESS) {
+      return NULL;
+    }
+    if (TsmMemcpyH2D(bootpm_dev, hostboot->get_bootpmbuffer(), hostboot->get_maxlen()) != RET_SUCCESS) {
+      return NULL;
+    }
+
+    if (TsmRun(devices[dev_index], bootpm_dev) != RET_SUCCESS) {
+      printf("TsmRun bootpm_dev failed.\n");
+      return NULL;
+    }
+
+    // Unload the kernel
+    TsmUnloadKernel(devices[dev_index], kmodel_vec);
+
+    // Fetch the output data and process it
+    printf("[Chip id %u] Copy output from device...\n", chip_id);
+    if (kernel_result_process(devices, dev_index, hostboot, chip_info,
+                              bootpm_dev, new_model->case_dir) != RET_SUCCESS) {
+      printf("free dev memory failed.\n");
+      return NULL;
+    }
+    if (freeMemPerStep(chip_id, bootpm_dev) != RET_SUCCESS) {
+      printf("free dev memory failed.\n");
+      return NULL;
+    }
+    // Release the multi-graph related TLV buffers
+    if (TsmDeviceFree(dev_kernel_head_ptr) != RET_SUCCESS) {
+      printf("free dev_kernel_head_ptr failed.\n");
+      return NULL;
+    }
+    if (TsmDeviceFree(dev_kernel_param_ptr) != RET_SUCCESS) {
+      printf("free dev_kernel_param_ptr failed.\n");
+      return NULL;
+    }
+
+    if (TsmDeviceFree(dev_dyn_mods_ptr) != RET_SUCCESS) {
+      printf("free dev_dyn_mods_ptr failed.\n");
+      return NULL;
+    }
+    if (TsmDeviceFree(dev_tlv_ptr) != RET_SUCCESS) {
+      printf("free dev_tlv_ptr failed.\n");
+      return NULL;
+    }
+
+    printf("[dev_index %u] Set Terminal Info...\n", dev_index);
+    if (TsmSetTerminate(devices[dev_index]) != RET_SUCCESS) {
+      printf("TsmSetTerminate failed.\n");
+      return NULL;
+    }
+#if 0
+    if (freeTensorData(chip_id, chip_info) != RET_SUCCESS) {
+      printf("free tensor data dev memory failed.\n");
+    }
+#endif
+  }
+
+  Py_RETURN_NONE;
+}
+
+
+static PyMethodDef ModuleMethods[] = {
+    {"load_binary", loadBinary, METH_VARARGS,
+     "Load provided binary into Tx81 driver"},
+    {"launch", launch, METH_VARARGS, "tx8 launch kernel!"},
+    {"get_device_properties", getDeviceProperties, METH_VARARGS,
+     "Get the properties for a given Tx81 device"},
+
+    {NULL, NULL, 0, NULL} // sentinel
+};
+
+static struct PyModuleDef ModuleDef = {
+    PyModuleDef_HEAD_INIT,
+    "tx81_utils",
+    NULL, // documentation
+    -1,   // size
+    ModuleMethods
+};
+
+PyMODINIT_FUNC PyInit_tx81_utils(void) {
+  PyObject *m = PyModule_Create(&ModuleDef);
+  if (m == NULL) {
+    return NULL;
+  }
+
+  PyModule_AddFunctions(m, ModuleMethods);
+
+  return m;
+}
\ No newline at end of file
diff --git a/third_party/tsingmicro/backend/driver.py b/third_party/tsingmicro/backend/driver.py
new file mode 100644
index 000000000..bafa51d28
--- /dev/null
+++ b/third_party/tsingmicro/backend/driver.py
@@ -0,0 +1,994 @@
+#
+# This file implements the triton kernel driver interfaces which are used in
+# triton/python/triton/compiler/compiler.py.
+# For how the interfaces in the driver class are used, see the implementation
+# of the file above.
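+#
+# In short, compiler.py expects (a sketch of the contract implemented below):
+#   driver.utils.load_binary(name, binary, shared, device)
+#       -> (module, function, n_regs, n_spills)
+#   driver.utils.get_device_properties(device) -> dict of device attributes
+#   driver.launcher_cls(src, metadata) -> a callable that launches the kernel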
+# +import hashlib +import tempfile +import os +import subprocess +import importlib.util +import shutil +import sysconfig +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import GPUDriver +from triton.backends.compiler import GPUTarget + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dirs = [os.path.join(dirname, "include"), + os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"), + os.path.join(sysconfig.get_path('platlib'), "torch", "include"), + os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"), + os.path.join(sysconfig.get_path('platlib'), "numpy", "_core", "include")] +library_dirs = [os.path.join(dirname, "lib"), + os.path.join(sysconfig.get_path('platlib'), "torch", "lib")] +libraries = ['tx8_runtime', 'torch', 'torch_cpu', 'torch_python', 'c10'] + +# Path configuration for cross compilation +def _get_llvm_bin_path(bin_name: str) -> str: + path = os.getenv("LLVM_BINARY_DIR", "") + if path == "": + raise Exception("LLVM_BINARY_DIR is not set.") + return os.path.join(path, bin_name) + +def _get_libc_root() -> str: + path = os.getenv("LIB_C_ROOT", "") + if path == "": + raise Exception("LIB_C_ROOT is not set.") + return path + +def _get_vendor_runtime_path() -> str: + path = os.getenv("LIB_VENDOR_RUNTIME_PATH", "") + if path == "": + raise Exception("LIB_VENDOR_RUNTIME_PATH is not set.") + return path + +def _dump_ir_if_needed(files): + path = os.getenv("ZTC_DUMP_PATH", "") + if not path: + return + + os.makedirs(path, exist_ok=True) + for f in files: + shutil.copy(f, os.path.join(path, os.path.basename(f))) + +# Build a native ELF on the platform running this python script +def compile_native(src, name): + fname = "native_" + name + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{fname}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, f"{name}.cpp") + with open(src_path, "w") as f: + f.write(src) + _dump_ir_if_needed([src_path]) + so = _build(name, src_path, tmpdir, library_dirs, include_dirs, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{fname}.so", binary=True) + _dump_ir_if_needed([cache_path]) + + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + +# Build a accelerator controller ELF +def compile_accelerator(src, name, ext): + name = "npu_" + name + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + libc_inc = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "include") + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, f"{name}.{ext}") + # FIXME: Hardcoded path + #dst_path = os.path.join(tmpdir, "wrapper.so") + dst_path = "/tmp/wrapper.o" + with open(src_path, "w") as f: + f.write(src) + _dump_ir_if_needed([src_path]) + clang_path = _get_llvm_bin_path("clang") + # Compile + subprocess.check_call([clang_path, src_path, + "-O2", + "-c", + "-fPIC", + f"-I{libc_inc}", + "--target=riscv64-unknown-elf", + "-march=rv64imafdc", + "-o", + dst_path]) + + with tempfile.TemporaryDirectory() as tmpdir: + # FIXME: Hardcoded path + #dst_path = os.path.join(tmpdir, f"{name}.so") + dst_path = "/tmp/kernel.so" + 
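+            # Second stage (a sketch of what follows): link the freshly built
+            # wrapper object and the precompiled kernel object against the
+            # Tx81 vendor runtime (libvr) and intrinsics (libkcorert), using
+            # the bare-metal linker script shipped with the crt.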
libc_lib = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "lib", "rv64imafdc", "lp64d") + libcrt_lib = os.path.join(_get_libc_root(), "lib", "gcc", "riscv64-unknown-elf", "15.0.0", "rv64imafdc", "lp64d") + libvr_path = _get_vendor_runtime_path() + clang_path = _get_llvm_bin_path("clang") + # Link wrapper, kernel with Tx81 crt and intrinsics(libkcorert.a) + subprocess.check_call([clang_path, + "-nostdlib", + # FIXME: Hardcoded path + "/tmp/wrapper.o", + "/tmp/kernel.o", + "-O2", + "--target=riscv64-unknown-elf", + "-march=rv64imafdc", + "-fPIC", + # "-shared", # ELF toolchain doesn't support -shared + f"-L{libvr_path}", + f"-L{libc_lib}", + f"-L{libcrt_lib}", + # Allow libkcorert symbol overwrite libc symbols, libkcorert + # should be specified before libc + "-Wl,--allow-multiple-definition", + "-lvr", # Wrapper API of Tx81 intrinsic + "-lkcorert", # Tx81 intrinsic API + "-lc", + "-lm", + "-lgcc", + "-T", + f"{libvr_path}/gcc_tx8_smarth.ld", + "-o", + dst_path]) + + _dump_ir_if_needed([dst_path]) + with open(dst_path, 'rb') as f: + so = f.read() + return so + +# -------------------- Launcher ---------------------------- +def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + +def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return _ty_to_cpp(ty) + +def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + +# This function makes a single kernel invoker which wraps all the input args into +# a single input buffer. +def make_kernel_wrapper_v2(constants, signature, kernel_name): + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + return f""" +#include +#include + +// Triton kernel forward declaration, the last 6 arguments are: gridXYZ and xyz +// Using -convert-func-to-llvm=use-bare-ptr-memref-call-conv=true. +void {kernel_name}({arg_decls}, int, int, int, int, int, int); + +// Kernel entry point +// NOTE: Assuming the triton kernel can only take 2 kind of arguments: +// 1. 8 bytes scalar +// 2. Tensor buffer (8 bytes memory address) +// +// The input buffer has the following format: +// +--------------------------------------------------------------------------+ +// | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 8 bytes | +// | gridX | gridY | gridZ | x | y | z | karg1 | +// +--------------------------------------------------------------------------+ +// | 8 bytes | ... | 8 bytes | +// | karg2 | ... | kargn | +// +-------------------------------+ +void __{kernel_name}(void *args) {{ + void* basePtr = args; + + // Extract the kernel arguments from kernel buffer + int gridX = *((int*)basePtr); + int gridY = *((int*)basePtr+1); + int gridZ = *((int*)basePtr+2); + int x = *((int*)basePtr+3); + int y = *((int*)basePtr+4); + int z = *((int*)basePtr+5); + void* krnArgOffsets = (void*) ((int*)basePtr + 6); + + if (gridX*gridY*gridZ <= 0) + return; + + // Invoke the actual kernel. 
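+  // As an illustration, for a hypothetical 3-argument kernel
+  // "add_kernel(void* x, void* y, int32_t n, ...)" the generated call below
+  // would expand to:
+  //   add_kernel((void*) (((uint64_t*)krnArgOffsets)[0]),
+  //              (void*) (((uint64_t*)krnArgOffsets)[1]),
+  //              *(int32_t*)(((uint64_t*)krnArgOffsets)[2]),
+  //              gridX, gridY, gridZ, x, y, z);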
+  {kernel_name}({', '.join([f"(void*) (((uint64_t*)krnArgOffsets)[{i}])"
+                 if ty[0] == "*" else
+                 f"*({_ty_to_cpp(ty)}*)(((uint64_t*)krnArgOffsets)[{i}])"
+                 for i, ty in signature.items()])},
+                 gridX, gridY, gridZ, x, y, z);
+}}
+"""
+
+
+def make_kernel_wrapper(constants, signature, kernel_name):
+    arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+
+    return f"""
+#include <stdint.h>
+#include <stddef.h>
+#include <assert.h>
+
+// Tx81 target framework related definition
+typedef struct BootParamHead
+{{
+  uint32_t MaxLen;
+  uint32_t LdmemLen;
+  uint32_t InputNum;
+  uint32_t OutputNum;
+  uint32_t ParamNum;
+  uint32_t reserved;
+  uint64_t CacheMemLen;
+  uint64_t CacheMemAttr;
+  uint32_t Datalen;
+  uint32_t reserved1;
+  uint64_t DataAddr;
+}} D_BootParamHead;
+
+// Tx81 target framework related definition
+typedef struct BootParamDyninfo
+{{
+  uint64_t addr; // device
+  uint64_t size;
+  uint32_t dtype;
+  uint32_t dim;
+  uint32_t shape[6];
+}} D_BootParamDyninfo;
+
+// Triton kernel forward declaration, the last 6 arguments are: gridXYZ and xyz
+void {kernel_name}({arg_decls}, int, int, int, int, int, int);
+
+// Get the entry point of the kernel arg buffer
+void* getKernelArgBuffer(void *args) {{
+  // Always use the first BootParam to carry the address pointing to the
+  // kernel arguments buffer
+  D_BootParamHead *head = (D_BootParamHead *)args;
+  assert(head->InputNum == 1);
+  // Decode the first parameter from BootParam as the kernel buffer info.
+  D_BootParamDyninfo* kernelBuffer = (D_BootParamDyninfo *)((char *)args +
+                                                            sizeof(D_BootParamHead));
+  // Kernel buffer address on device DDR
+  return (void*) kernelBuffer->addr;
+}}
+
+// Kernel wrapper
+void task(void *krnArgBuf, void *krnArgOffsets,
+          int gridX, int gridY, int gridZ, int x, int y, int z) {{
+
+  // Invoke the actual kernel by passing in the triton kernel arguments stored
+  // on device DDR and the other arguments generated by the compiler.
+  {kernel_name}({', '.join([f"(void*) (krnArgBuf + ((uint64_t*)krnArgOffsets)[{i}])"
+                 if ty[0] == "*" else
+                 f"*({_ty_to_cpp(ty)}*)(krnArgBuf + ((uint64_t*)krnArgOffsets)[{i}])"
+                 for i, ty in signature.items()])},
+                 gridX, gridY, gridZ, x, y, z);
+}}
+
+// Kernel entry point; the name matches the symbol passed to TsmLoadKernel
+void __kernel_entry(void *args) {{
+  void* basePtr = getKernelArgBuffer(args);
+
+  // Extract the kernel arguments from the kernel buffer
+  int krnArgCount = *(int*)basePtr;
+  int gridX = *((int*)basePtr+1);
+  int gridY = *((int*)basePtr+2);
+  int gridZ = *((int*)basePtr+3);
+  void* krnArgOffsets = (void*) ((int*)basePtr + 4);
+  void* krnArgBuf = krnArgOffsets + krnArgCount * sizeof(uint64_t*);
+
+  if (gridX*gridY*gridZ <= 0)
+    return;
+
+  // Iterate over the whole grid and invoke the kernel once per program
+  // instance (x, y, z).
+ for(int x = 0; x < gridX; x++) {{ + for(int y = 0; y < gridY; y++) {{ + for(int z = 0; z < gridZ; z++) {{ + task (krnArgBuf, krnArgOffsets, gridX, gridY, gridZ, x, y, z); + }} + }} + }} +}} +""" + +def make_launcher(constants, signature, kernel_name): + # Basic declarations + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiOOOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # Parameters to pass to the kernel function + kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"tx81_ptr{i}, &ptr_arg{i}" for i, ty in signature.items() if i not in constants) + kernel_parameters += ', ' if kernel_parameters else '' + + return f""" +#include +#include +#include +#include +#include +#include +#include +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include +#include +#include +#include +#include +#include "hrt_interface.h" +#include "hrt_common.h" + +// The design of kernel argument buffer: +// The offset starts from the whole kernel buffer +// +------------------------------------------------------------------------+ +// | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 8 bytes | 8 bytes | +// | No. kargs | gridX | gridY | gridZ | karg1 offset | karg2 offset | +// +------------------------------------------------------------------------+ +// .......................... Metadata buffer................................ +// +// +------------------------------------------------------------------------+ +// | ... | 8 bytes | n bytes | n bytes | | n bytes | +// | ... | kargn offset | karg1 data | karg2 data | ... | kargn data | +// +------------------------------------------------------------------------+ +// ^ ^ ^ +// karg1 offset karg2 offset kargn offset +// ... Metadata buffer... | ............ kernel arg buffer .................. + + +// A kernel argument +struct KernelArg {{ + // The actual kernel argument: tensor or scalar + union Data {{ + void* ptr; // Pointer to the tensor data + uint64_t scalar; // Scalar data + }} data; + size_t size; // The size of the kernel argument + + KernelArg(void *ptr, size_t s) : size(s) {{ + data.ptr = ptr; + }} + + KernelArg(uint64_t v, size_t s) : size(s) {{ + data.scalar = v; + }} +}}; + + +extern "C" {{ + // The kernel arguments includes: + // 1. The actual kernel argument in arg_decls + // 2. 
The group size: gridX, gridY, gridZ + // 3 The thread id in each direction: x, y, z + void {kernel_name}({arg_decls}, int, int, int, int, int, int); +}} + +// Global device vector +static std::vector g_tx81_devices; +static bool g_runtime_initialized = false; + +// Initialize Tx81 runtime +bool init_tx81_runtime() {{ + if (g_runtime_initialized) {{ + return true; // Already initialized + }} + + // Initialize the Tx81 runtime + if (TsmInitRuntime() != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 runtime"); + return false; + }} + + // Get device count + uint32_t device_num = 0; + if (TsmGetDeviceNum(device_num) != RET_SUCCESS || device_num == 0) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to get Tx81 device count or no devices found"); + TsmDeInitRuntime(); + return false; + }} + + // Set up devices - for simplicity, we're using a 1x1 configuration + uint32_t first_phy_id = 0; + uint32_t card_x = 1; + uint32_t card_y = 1; + + if (TsmSetDevice(first_phy_id, card_x, card_y, g_tx81_devices) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices"); + TsmDeInitRuntime(); + return false; + }} + + // Initialize all devices + for (auto* dev : g_tx81_devices) {{ + if (TsmInitDevice(dev) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 device"); + TsmDeInitRuntime(); + return false; + }} + }} + + g_runtime_initialized = true; + return true; +}} + +// Clean up Tx81 runtime resources +void cleanup_tx81_runtime() {{ + if (!g_runtime_initialized) {{ + return; + }} + + for (auto* dev : g_tx81_devices) {{ + // Reset and release each device + TsmResetDevice(dev); + TsmReleaseDevice(dev); + }} + + g_tx81_devices.clear(); + TsmDeInitRuntime(); + g_runtime_initialized = false; +}} + + +static void prepare_input(std::vector devices, uint32_t dev_index, + std::shared_ptr chip_info) {{ + for (uint32_t i = 0; i < chip_info->input_num; ++i) {{ + chip_info->input_dev_addr.push_back(0); + if (TsmDeviceMalloc(devices[dev_index], chip_info->input_dev_addr[i], + chip_info->input_size[i]) != RET_SUCCESS) {{ + printf("[Chip id %u] Input%d, DeviceMalloc failed!\\n", devices[dev_index]->chip_id, i); + TsmResetDevice(devices[dev_index]); + return; + }} + + if (TsmMemcpyH2D((TsmDevicePtr)chip_info->input_dev_addr[i], + (void*) chip_info->input_host_addr[i], + chip_info->input_size[i]) != RET_SUCCESS) {{ + printf("[Chip id %u] Input%d, MemcpyH2D failed!\\n", devices[dev_index]->chip_id, i); + TsmResetDevice(devices[dev_index]); + return; + }} + }} +}} + +static void prepare_output(std::vector devices, uint32_t dev_index, + std::shared_ptr chip_info) {{ + for (size_t i = 0; i < chip_info->output_num; ++i) {{ + chip_info->output_dev_addr.push_back(0); + printf("[Chip id %u] output[%lu] data(size: %lu)\\n", + devices[dev_index]->chip_id, i, chip_info->output_size[i]); + + if (TsmDeviceMalloc(devices[dev_index], chip_info->output_dev_addr[i], + chip_info->output_size[i]) != RET_SUCCESS) {{ + printf("[Chip id %u] output[%lu], DeviceMalloc failed!\\n", + devices[dev_index]->chip_id, i); + TsmResetDevice(devices[dev_index]); + return; + }} + }} +}} + +TSM_RETCODE kernel_result_process(std::vector devices, uint32_t dev_index, + std::shared_ptr hostboot, + std::shared_ptr chip_info, + TsmDevicePtr bootpm_dev, std::string case_dir) {{ + for (size_t i = 0; i < chip_info->output_num; ++i) {{ + // 动态shape, 需要处理真实的output size + if (TsmMemcpyD2H(hostboot->get_bootpmbuffer(), bootpm_dev, + hostboot->get_maxlen()) != RET_SUCCESS) {{ + 
return RET_ERROR; + }} + + auto out_tensor = hostboot->get_dev_output_tensor_after_run(i); + chip_info->output[i]->dim = out_tensor->dim; + std::memcpy(chip_info->output[i]->shape, out_tensor->shape, sizeof(out_tensor->shape)); + chip_info->output_size[i] = hrt_get_dtype_size((DTYPE)chip_info->output[i]->dtype); + for (uint32_t j = 0; j < out_tensor->dim; ++j) {{ + if (out_tensor->shape[j] > 0) {{ + chip_info->output_size[i] *= out_tensor->shape[j]; + }} + }} + + TsmHostPtr output_host_addr = (TsmHostPtr)malloc(chip_info->output_size[i]); + if (chip_info->output_size[i] > 0) {{ + if (TsmMemcpyD2H((void*)output_host_addr, chip_info->output_dev_addr[i], + chip_info->output_size[i]) != RET_SUCCESS) {{ + return RET_ERROR; + }} + }} + + printf("[Chip id %u] output_dev_addr=%ld\\n", devices[dev_index]->chip_id, + chip_info->output_dev_addr[i]); + + // TODO: Processing output +#if 0 + std::string file_path = case_dir + "/chip" + std::to_string(dev_index) + + "/agent/data/out" + std::to_string(i) + "_riscv.bin"; + saveDataToFile(file_path, output_host_addr, chip_info->output_size[i]); +#endif + + if (output_host_addr != 0) {{ + free((void *)output_host_addr); + }} + }} + return RET_SUCCESS; +}} + +TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) {{ + if (bootpm_dev != 0) {{ + printf("[Chip id %u] bootpm dev addr: 0x%lx \\n", chip_id, bootpm_dev); + if (TsmDeviceFree(bootpm_dev) != RET_SUCCESS) {{ + return RET_ERROR; + }} + bootpm_dev = 0; + }} + + return RET_SUCCESS; +}} + +static void setHostBoot(std::shared_ptr &chip_info, + std::shared_ptr &hostboot) {{ + if (chip_info == nullptr) {{ + printf("chip_info is null.\\n"); + return; + }} + + if (hostboot == nullptr) {{ + printf("hostboot is null.\\n"); + return; + }} + + for (size_t i = 0; i < chip_info->input_dev_addr.size(); ++i) {{ + hostboot->set_dev_input(i, chip_info->input_dev_addr[i], chip_info->input_size[i]); + hostboot->set_dev_input_tensor(i, chip_info->input[i]); + }} + + for (size_t i = 0; i < chip_info->output_dev_addr.size(); ++i) {{ + hostboot->set_dev_output(i, chip_info->output_dev_addr[i], chip_info->output_size[i]); + }} + + for (size_t i = 0; i < chip_info->param_num; ++i) {{ + hostboot->set_dev_param(i, chip_info->param_dev_addr[i], chip_info->param_size[i]); + }} + + return; +}} + + +static void _launch(int gridX, int gridY, int gridZ, std::vector &kargs) {{ + std::vector devices; + + if (gridX*gridY*gridZ <= 0) {{ + return; // No work to do + }} + + TsmModel *new_model = new TsmModel(); + + // Create a vector of models + std::vector kmodel_vec = {{new_model}}; + std::string option = "-O2"; + CompileOption compl_option = {{}}; + compl_option.comp_enable = 0; // Use prebuilt binary + compl_option.chip_x = 1; //单卡 + compl_option.chip_y = 1; + compl_option.check_enable = true; + compl_option.enable_kcore_bin = 1; + compl_option.enable_kcore_so = 1; + // FIXME: Hardcoded path + new_model->case_dir = "/tmp/kernel.so"; + + printf("====> Calling TsmCompileMultiGraph\\n"); +#if 0 + if (TsmCompileMultiGraph(devices, *new_model, option, compl_option) != RET_SUCCESS) {{ + for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {{ + if (TsmResetDevice(devices[dev_index]) != RET_SUCCESS) {{ + printf("[Chip id %u] tx_engine: tx_reset, failed!\\n", dev_index); + }} else {{ + printf("[Chip id %u] tx_engine: tx_reset, success!\\n", dev_index); + }} + }} + printf("TsmCompile failed.\\n"); + return; + }} +#endif + // Calculate the total size of kernel arguments buffer + uint64_t kernel_buffer_size = 0; + for 
(auto karg : kargs)
+    kernel_buffer_size += karg.size;
+
+  // Calculate the kernel argument buffer header size:
+  // 4 bytes header + n * 8-byte kernel argument offsets + 3 * sizeof(gridXYZ)
+  uint64_t kernel_meta_buf_size = sizeof(uint64_t*) * kargs.size() + 4 + 12;
+  kernel_buffer_size += kernel_meta_buf_size;
+
+  // We use input_num = 1 to expose the whole kernel arguments buffer as a
+  // single input
+  uint32_t input_num = 1;
+  uint32_t output_num = 0;
+  uint32_t param_num = 0;
+
+  // Create boot parameter
+  std::shared_ptr hostboot = std::make_shared(input_num, output_num, param_num);
+
+  // Create chip common info
+  std::shared_ptr chip_info = std::make_shared();
+  chip_info->input_num = input_num;
+  chip_info->output_num = output_num;
+  chip_info->param_num = param_num;
+  chip_info->imm_size = 0; // Cache size
+
+  // Prepare input/output sizes and addresses
+  chip_info->input_size.resize(input_num);
+  chip_info->input_host_addr.resize(input_num);
+  chip_info->input_dev_addr.resize(input_num);
+  chip_info->output_size.resize(output_num);
+  chip_info->output_host_addr.resize(output_num);
+  chip_info->output_dev_addr.resize(output_num);
+
+  // Prepare whole kernel buffer info
+  chip_info->input.push_back(std::make_shared());
+  chip_info->input[0]->dim = 1;
+  chip_info->input[0]->dtype = FMT_FP32; // Default to float
+  chip_info->input[0]->shape[0] = 1;     // Default shape
+  chip_info->input_size[0] = kernel_buffer_size;
+  chip_info->input_host_addr = std::vector{{(uint64_t) 0x0}};
+
+  // prepare data / load kernel / run / unload kernel / get output data / release memory
+  for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {{
+    // input prepare
+    prepare_input(devices, dev_index, chip_info);
+    // output prepare
+    prepare_output(devices, dev_index, chip_info);
+
+    uint32_t chip_id = devices[dev_index]->chip_id;
+    TsmSetMonitorInfo(devices[dev_index]);
+
+    // load kernel
+    char module_symbol[] = "__kernel_entry";
+    TsmLoadKernel(devices[dev_index], kmodel_vec, module_symbol);
+    printf("TsmLoadKernel finish!...\\n");
+
+    printf("[Chip id %u] Set boot-params...\\n", chip_id);
+    size_t dyn_mod_size = sizeof(DynMods) + sizeof(DynModule);
+    TsmDevicePtr dev_dyn_mods_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_dyn_mods_ptr, dyn_mod_size) != RET_SUCCESS)
+      return;
+
+    // Allocate the device memory for all kernel arguments
+    TsmDevicePtr dev_kernel_buffer;
+    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_buffer, kernel_buffer_size) != RET_SUCCESS)
+      return;
+
+    // Device address just past the metadata, where the argument data lives
+    TsmDevicePtr dev_karg_ptr = dev_kernel_buffer + kernel_meta_buf_size;
+
+    // Kernel arguments address
+    uint64_t arg_metadata[kargs.size()];
+
+    // Copy kernel arguments to device DDR (immediately after the metadata)
+    int i = 0;
+    uint64_t offset = 0;
+    for (auto karg : kargs) {{
+      // FIXME: for scalar arguments karg.data.ptr reinterprets the scalar
+      // value as a host address; KernelArg needs a tag to distinguish the
+      // pointer and scalar cases here.
+      if (TsmMemcpyH2D(dev_karg_ptr, karg.data.ptr, karg.size) != RET_SUCCESS)
+        return;
+
+      // Calculate the offset of each kernel arg's buffer
+      arg_metadata[i++] = offset;
+
+      // Shift the offset and pointer for the next kernel argument.
+      offset += karg.size;
+      dev_karg_ptr += karg.size;
+    }}
+
+    // Create the metadata buffer: 4 header words (count, gridX, gridY, gridZ)
+    // followed by the per-argument 8-byte offsets.
+    uint32_t* metadata = (uint32_t*) malloc(kernel_meta_buf_size);
+    metadata[0] = (int) kargs.size();
+    metadata[1] = gridX;
+    metadata[2] = gridY;
+    metadata[3] = gridZ;
+    memcpy(metadata + 4, arg_metadata, kernel_meta_buf_size - 16);
+
+    // Copy kernel metadata to device DDR
+    if (TsmMemcpyH2D(dev_kernel_buffer, metadata, kernel_meta_buf_size) != RET_SUCCESS)
+      return;
+    free(metadata);
+
+    setHostBoot(chip_info, hostboot);
+    set_multi_graph(kmodel_vec[0], hostboot, dev_dyn_mods_ptr, 0, dev_kernel_buffer);
+
+    TsmDevicePtr bootpm_dev;
+    if (TsmDeviceMalloc(devices[dev_index], bootpm_dev, hostboot->get_maxlen()) != RET_SUCCESS)
+      return;
+
+    if (TsmMemcpyH2D(bootpm_dev, hostboot->get_bootpmbuffer(), hostboot->get_maxlen()) != RET_SUCCESS)
+      return;
+
+    if (TsmRun(devices[dev_index], bootpm_dev) != RET_SUCCESS) {{
+      printf("TsmRun bootpm_dev failed.\\n");
+      return;
+    }}
+
+    TsmUnloadKernel(devices[dev_index], kmodel_vec);
+
+    // Process kernel output data
+    printf("[Chip id %u] Copy output from device...\\n", chip_id);
+    if (kernel_result_process(devices, dev_index, hostboot, chip_info, bootpm_dev, new_model->case_dir) != RET_SUCCESS) {{
+      printf("free dev memory failed.\\n");
+      return;
+    }}
+
+    if (freeMemPerStep(chip_id, bootpm_dev) != RET_SUCCESS) {{
+      printf("free dev memory failed.\\n");
+      return;
+    }}
+
+    if (TsmDeviceFree(dev_kernel_buffer) != RET_SUCCESS) {{
+      printf("free dev_kernel_buffer failed.\\n");
+      return;
+    }}
+
+    if (TsmDeviceFree(dev_dyn_mods_ptr) != RET_SUCCESS) {{
+      printf("free dev_dyn_mods_ptr failed.\\n");
+      return;
+    }}
+
+    printf("[dev_index %u] Set Terminal Info...\\n", dev_index);
+    if (TsmSetTerminate(devices[dev_index]) != RET_SUCCESS) {{
+      printf("TsmSetTerminate failed.\\n");
+      return;
+    }}
+  }}
+
+  // Clean up the model
+  delete new_model;
+}}
+
+// Structure to represent a device pointer
+typedef struct _DevicePtrInfo {{
+  void *dev_ptr;
+  bool valid;
+}} DevicePtrInfo;
+
+static size_t getTensorStorageSize(PyObject* tensor_obj) {{
+  const at::Tensor& tensor = THPVariable_Unpack(tensor_obj);
+  return tensor.storage().nbytes();
+}}
+
+// Extract the tensor's raw data pointer
+static void* extractTensor(PyObject* tensor_obj) {{
+  const at::Tensor& tensor = THPVariable_Unpack(tensor_obj);
+  torch::Tensor contiguous_tensor = tensor.contiguous();
+  // FIXME: if contiguous() makes a copy, the returned pointer dangles once
+  // contiguous_tensor goes out of scope.
+  return contiguous_tensor.data_ptr();
+}}
+
+// Python module launch function
+static PyObject* launch(PyObject* self, PyObject* args) {{
+  int gridX, gridY, gridZ;
+  PyObject *launch_enter_hook = NULL;
+  PyObject *launch_exit_hook = NULL;
+  PyObject *kernel_metadata = NULL;
+  PyObject *launch_metadata = NULL;
+  // FIXME: Extra 2 args:
+  PyObject *dummy1 = NULL;
+  PyObject *dummy2 = NULL;
+  // Define the actual kernel arguments
+  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
+
+  // Init kernel arguments from the python side
+  if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
+                        &kernel_metadata, &launch_metadata,
+                        &launch_enter_hook, &launch_exit_hook,
+                        &dummy1, &dummy2{args_list})) {{
+    return NULL;
+  }}
+
+#if 0 // FIXME: kernel_metadata is not correctly inited
+  // Extract metadata for consistency with other drivers
+  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
+  if (!PyArg_ParseTuple(kernel_metadata, "iiiiii", &num_warps, &num_ctas,
+                        &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
+    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
+    return NULL;
+  }}
+
+  // Call the enter hook if provided
+  if (launch_enter_hook != Py_None) {{
+    PyObject* hook_args = Py_BuildValue("(O)", launch_metadata);
+    PyObject* ret = PyObject_CallObject(launch_enter_hook, hook_args);
+    Py_DECREF(hook_args);
+    if (!ret)
+      return NULL;
+  }}
+#endif
+
+  // Construct the kernel argument list data structure
+  std::vector<KernelArg> kargs;
+  {' '.join([f"kargs.emplace_back(extractTensor(_arg{i}), getTensorStorageSize(_arg{i}));"
+             if ty[0]=="*" else f"kargs.emplace_back(_arg{i}, sizeof(_arg{i}));"
+             for i, ty in signature.items()])}
+
+  // Launch the kernel
+  _launch(gridX, gridY, gridZ, kargs);
+  if (PyErr_Occurred()) {{
+    return NULL;
+  }}
+
+  // Call the exit hook if provided
+  if (launch_exit_hook != Py_None) {{
+    PyObject* hook_args = Py_BuildValue("(O)", launch_metadata);
+    PyObject* ret = PyObject_CallObject(launch_exit_hook, hook_args);
+    Py_DECREF(hook_args);
+    if (!ret)
+      return NULL;
+  }}
+
+  // Return None to Python
+  Py_INCREF(Py_None);
+  return Py_None;
+}}
+
+// Python module method definitions
+static PyMethodDef ModuleMethods[] = {{
+  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
+  {{NULL, NULL, 0, NULL}} // sentinel
+}};
+
+// Python module definition
+static struct PyModuleDef ModuleDef = {{
+  PyModuleDef_HEAD_INIT,
+  \"__triton_launcher\",
+  NULL, // documentation
+  -1,   // size
+  ModuleMethods
+}};
+
+static PyMethodDef cleanup_method = {{
+  "cleanup_tx81_runtime",
+  (PyCFunction)cleanup_tx81_runtime,
+  METH_NOARGS,
+  "Cleanup Tx81 runtime resources"
+}};
+
+// Python module initialization function
+PyMODINIT_FUNC PyInit___triton_launcher(void) {{
+  PyObject *m = PyModule_Create(&ModuleDef);
+  if (m == NULL) {{
+    return NULL;
+  }}
+
+  PyModule_AddFunctions(m, ModuleMethods);
+
+#if 0
+  // Initialize Tx81 runtime during module import
+  if (!init_tx81_runtime()) {{
+    Py_DECREF(m);
+    return NULL;
+  }}
+
+  // Register an atexit handler to clean up the Tx81 runtime
+  PyObject* atexit_module = PyImport_ImportModule("atexit");
+  if (atexit_module) {{
+    PyObject* cleanup_func = PyCFunction_New(&cleanup_method, NULL);
+    if (cleanup_func) {{
+      PyObject* result = PyObject_CallMethod(atexit_module, "register", "O", cleanup_func);
+      Py_XDECREF(result);
+      Py_DECREF(cleanup_func);
+    }}
+    Py_DECREF(atexit_module);
+  }}
+#endif
+
+  return m;
+}}
+"""
+
+class CrossUtils(object):
+
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(CrossUtils, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        src = Path(os.path.join(dirname, "driver.cpp")).read_text()
+        mod = compile_native(src, "tx81_utils")
+        # NOTE: The triton compiler.py framework requires these 2 interfaces.
+        self.load_binary = mod.load_binary
+        self.get_device_properties = mod.get_device_properties
+
+# Launch the cross-compiled runtime program on the controller
+class CrossLauncher(object):
+
+    def __init__(self, src, metadata):
+        constants = src.constants if hasattr(src, "constants") else dict()
+        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
+        constants = {cst_key(key): value for key, value in constants.items()}
+        signature = {cst_key(key): value for key, value in src.signature.items()}
+
+        # Compile the kernel wrapper source code.
+        # NOTE: Replace make_kernel_wrapper below with the v2 version if you
+        # want to call the triton kernel through a single input buffer and a
+        # '__'-prefixed entry name.
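+        # A hypothetical v2 variant of the line below would be:
+        #   wrapper_src = make_kernel_wrapper_v2(constants, signature, src.fn.__name__)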
+ wrapper_src = make_kernel_wrapper(constants, signature, src.fn.__name__) + krn = compile_accelerator(wrapper_src, src.fn.__name__, "c") + + # Compiler runtime kernel launcher source code + launcher_src = make_launcher(constants, signature, src.fn.__name__) + mod = compile_native(launcher_src, "__triton_launcher") + self.launch = mod.launch + + + def __call__(self, *args, **kwargs): + # args: 0: gridX, 1: gridY, 2: gridZ, + # 3: kernel_metadata?, 4: launch_metadata?, + # 5: a tuple(0, 0, False, 1, 1, 1, 'add_kernel'), # this is probably kernel metadata + # 6: None, 7: None, 8: None, + # 9~N: Actual triton kernel args. + self.launch(*args, **kwargs) + + +class CrossDriver(GPUDriver): + + def __init__(self): + super().__init__() + self.utils = CrossUtils() + self.launcher_cls = CrossLauncher + # Needs to overwrite GPUDriver base methods + self.get_current_device = self.get_npu_device + self.set_current_device = self.set_npu_device + self.get_current_stream = self.get_npu_stream + + @staticmethod + def is_active(): + return True + + def get_npu_device(self): + return "cpu" + + def set_npu_device(self, device): + # CPU doesn't have a device to set + assert device == "cpu" + return + + def get_npu_stream(self, device): + return None + + def get_current_target(self): + capability = 1 + warp_size = 16 + return GPUTarget("cpu", capability, warp_size) diff --git a/third_party/tsingmicro/crt/CMakeLists.txt b/third_party/tsingmicro/crt/CMakeLists.txt new file mode 100644 index 000000000..9470c08db --- /dev/null +++ b/third_party/tsingmicro/crt/CMakeLists.txt @@ -0,0 +1,113 @@ +cmake_minimum_required(VERSION 3.18) + +# Set TARGET from environment variable +if(NOT DEFINED TARGET) + if(DEFINED ENV{CRT_TARGET}) + set(TARGET $ENV{CRT_TARGET}) + else() + message(FATAL_ERROR "CRT_TARGET environment variable is not defined") + endif() +endif() + +if(NOT DEFINED LIB_C_ROOT) + if(DEFINED ENV{LIB_C_ROOT}) + set(LIB_C_ROOT $ENV{LIB_C_ROOT}) + else() + message(FATAL_ERROR "LIB_C_ROOT environment variable is not defined") + endif() +endif() + +# Set LLVM_SYSPATH from environment variable +if(NOT DEFINED LLVM_SYSPATH) + if(DEFINED ENV{LLVM_SYSPATH}) + set(LLVM_SYSPATH $ENV{LLVM_SYSPATH}) + else() + message(FATAL_ERROR "LLVM_SYSPATH environment variable is not defined") + endif() +endif() + +# Project name and version +project(VendorRuntime LANGUAGES CXX C) + +# Define RISC-V target triple +set(RISCV_TRIPLE "riscv64-unknown-elf") +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_SYSTEM_PROCESSOR riscv) +set(CMAKE_C_COMPILER ${LLVM_SYSPATH}/bin/clang) +set(CMAKE_CXX_COMPILER ${LLVM_SYSPATH}/bin/clang++) + +# Define standard include directories +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include/${TARGET}) +include_directories(${LIB_C_ROOT}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) + +# Set build type default +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type (default Release)" FORCE) +endif() + +# Library name: vr stands for Vendor Runtime +set(VENDOR_RUNTIME_LIB vr) + +# Collect all source files from the vendor directory +file(GLOB_RECURSE VENDOR_SOURCES lib/${TARGET}/*.c) + +# Define RISC-V specific compile options +set(RISCV_COMPILE_OPTIONS + --target=${RISCV_TRIPLE} + -march=rv64gc + -mabi=lp64d + -mcmodel=medany +) + +# Add the library target +add_library(${VENDOR_RUNTIME_LIB} STATIC ${VENDOR_SOURCES}) + +# Apply RISC-V specific settings to our target +target_compile_options(${VENDOR_RUNTIME_LIB} PRIVATE 
${RISCV_COMPILE_OPTIONS}) +target_link_options(${VENDOR_RUNTIME_LIB} PRIVATE --target=${RISCV_TRIPLE}) + +# Set properties for the library +set_target_properties(${VENDOR_RUNTIME_LIB} PROPERTIES + POSITION_INDEPENDENT_CODE ON + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib +) + +# Setup compiler and environment for RISC-V compilation +if(CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + # Use the existing Clang installation with target triple + message(STATUS "Using Clang with RISC-V target triple") +else() + # Override compiler paths if using explicit RISC-V toolchain + message(STATUS "Setting explicit RISC-V compiler from LLVM_SYSPATH") + + foreach(source ${VENDOR_SOURCES}) + if(source MATCHES "\\.(c)$") + set_source_files_properties(${source} PROPERTIES + COMPILE_FLAGS "-xc --target=${RISCV_TRIPLE}" + LANGUAGE C) + elseif(source MATCHES "\\.(cpp)$") + set_source_files_properties(${source} PROPERTIES + COMPILE_FLAGS "-xc++ --target=${RISCV_TRIPLE}" + LANGUAGE CXX) + endif() + endforeach() + + # Set compiler launch commands for the target + add_custom_command(TARGET ${VENDOR_RUNTIME_LIB} PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E echo "Building ${VENDOR_RUNTIME_LIB} for RISC-V target" + ) +endif() + +# Install targets +install(TARGETS ${VENDOR_RUNTIME_LIB} + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin +) + +# Install headers (optional) +file(GLOB_RECURSE VENDOR_HEADERS Target/lib/${TARGET}/*.h) +install(FILES ${VENDOR_HEADERS} DESTINATION include/${TARGET}) +install(FILES Target/lib/${TARGET}/libkcorert.a DESTINATION lib/${TARGET}) \ No newline at end of file diff --git a/third_party/tsingmicro/crt/README.md b/third_party/tsingmicro/crt/README.md new file mode 100644 index 000000000..3750c3c7c --- /dev/null +++ b/third_party/tsingmicro/crt/README.md @@ -0,0 +1,2 @@ +This folder contains the low level API implementation for various ML +controller or accelerator. diff --git a/third_party/tsingmicro/crt/gcc_flash_smartl.ld b/third_party/tsingmicro/crt/gcc_flash_smartl.ld new file mode 100644 index 000000000..6786fb002 --- /dev/null +++ b/third_party/tsingmicro/crt/gcc_flash_smartl.ld @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2017-2024 Alibaba Group Holding Limited + */ + +/****************************************************************************** + * @file gcc_csky.ld + * @brief csky linker file + * @version V1.0 + * @date 02. June 2017 + ******************************************************************************/ +MEMORY +{ + ISRAM : ORIGIN = 0x00000000 , LENGTH = 0x20000 /* ISRAM 128KB*/ + DSRAM : ORIGIN = 0x20000000 , LENGTH = 0x80000 /* DSRAM 512KB*/ +} + +__min_heap_size = 0x200; +PROVIDE (__ram_end = 0x20020000); +PROVIDE (__heap_end = __ram_end); + +REGION_ALIAS("REGION_TEXT", ISRAM); +REGION_ALIAS("REGION_RODATA", ISRAM); +REGION_ALIAS("REGION_DATA", DSRAM); +REGION_ALIAS("REGION_BSS", DSRAM); + +ENTRY(Reset_Handler) +SECTIONS +{ + .text : { + . = ALIGN(0x8) ; + __stext = . ; + KEEP(*startup.o(*.text)) + KEEP(*startup.o(*.vectors)) + KEEP(*vectors.o(*.text)) + KEEP(*(.text.entry)) + *(.text*) + *(.gnu.warning) + *(.stub) + *(.gnu.linkonce.t*) + *(.glue_7t) + *(.glue_7) + *(.jcr) + KEEP (*(.init)) + KEEP (*(.fini)) + . = ALIGN (0x4) ; + PROVIDE(__ctbp = .); + *(.call_table_data) + *(.call_table_text) + . = ALIGN(0x10) ; + __etext = . ; + } > REGION_TEXT + .rodata : { + . = ALIGN(0x8) ; + __srodata = .; + *(.rdata) + *(.rdata*) + *(.rdata1) + *(.rdata.*) + *(.rodata*) + *(.srodata*) + . 
= ALIGN(0x8) ; + __init_array_start = .; + __ctors_start__ = .; + KEEP (*(SORT(.init_array.*))) + KEEP (*(.init_array)) + __init_array_end = .; + __ctors_end__ = .; + + __fini_array_start = .; + __dtors_start__ = .; + KEEP (*(SORT(.fini_array.*))) + KEEP (*(.fini_array)) + __fini_array_end = .; + __dtors_end__ = .; + . = ALIGN(0x8) ; + + __ctor_start__ = .; + KEEP (*(SORT(.ctors.*))) + KEEP (*(.ctors)) + __ctor_end__ = .; + KEEP (*(SORT(.dtors.*))) + KEEP (*(.dtors)) + __dtor_end__ = .; + . = ALIGN(0x8) ; +/*****************************************/ + /* section information for finsh shell */ + . = ALIGN(0x8); + __fsymtab_start = .; + KEEP(*(FSymTab)) + __fsymtab_end = .; + . = ALIGN(0x8); + __vsymtab_start = .; + KEEP(*(VSymTab)) + __vsymtab_end = .; + . = ALIGN(0x8); + + /* section information for initial. */ + __rt_init_start = .; + KEEP(*(SORT(.rti_fn*))) + __rt_init_end = .; + . = ALIGN(0x8) ; + + /* section information for at utest */ + __rt_utest_tc_tab_start = .; + KEEP(*(UtestTcTab)) + __rt_utest_tc_tab_end = .; + . = ALIGN(0x8); + + /* section information for at server */ + . = ALIGN(0x8); + __rtatcmdtab_start = .; + KEEP(*(RtAtCmdTab)) + __rtatcmdtab_end = .; + . = ALIGN(0x8); + + /* section information for modules */ + . = ALIGN(0x8); + __rtmsymtab_start = .; + KEEP(*(RTMSymTab)) + __rtmsymtab_end = .; + + /* section information for uPRC */ + . = ALIGN(0x8); + __uRPCSvcTab_start = .; + KEEP(*(uRPCSvcTab)) + __uRPCSvcTab_end = .; + + /* section information for var export */ + . = ALIGN(0x8); + __ve_table_start = .; + KEEP(*(SORT(*.VarExpTab.*))) + __ve_table_end = .; +/*****************************************/ +/************** added drivers **************/ + _cli_region_begin = .; + KEEP(*(CliRegion)) + . = ALIGN(0x8); + _cli_region_end = .; + + __core_driver_start__ = .; + KEEP(*(.core_driver_entry)) + . = ALIGN(0x8); + __core_driver_end__ = .; + + __bus_driver_start__ = .; + KEEP(*(*.bus_driver_entry)) + __bus_driver_end__ = .; + + __early_driver_start__ = .; + KEEP(*(*.early_driver_entry)) + __early_driver_end__ = .; + + __vfs_driver_start__ = .; + KEEP(*(*.vfs_driver_entry)) + __vfs_driver_end__ = .; + + __level0_driver_start__ = .; + KEEP(*(*.level0_driver_entry)) + __level0_driver_end__ = .; + + __level1_driver_start__ = .; + KEEP(*(*.level1_driver_entry)) + __level1_driver_end__ = .; + + __level2_driver_start__ = .; + KEEP(*(*.level2_driver_entry)) + __level2_driver_end__ = .; + + __level3_driver_start__ = .; + KEEP(*(*.level3_driver_entry)) + __level3_driver_end__ = .; + + __post_driver_start__ = .; + KEEP(*(*.post_driver_entry)) + __post_driver_end__ = .; +/************** end of drivers *********/ + . = ALIGN(0x8) ; + __erodata = .; + __rodata_end__ = .; + } > REGION_RODATA + .data : { + . = ALIGN(0x8) ; + __sdata = . ; + __data_start__ = . ; + data_start = . ; + *(.got.plt) + *(.got) + *(.gnu.linkonce.r*) + *(.data*) + *(.gnu.linkonce.d*) + *(.gcc_except_table*) + __start_init_call = .; + *(.initcall.init) + __stop_init_call = .; + __start_cmd = .; + *(.bootloaddata.cmd) + . = ALIGN(0x8) ; + __stop_cmd = .; + __global_pointer$ = .; + *(.sdata) + *(.sdata.*) + *(.sdata2.*) + *(.gnu.linkonce.s.*) + *(__libc_atexit) + *(__libc_subinit) + *(__libc_subfreeres) + *(.note.ABI-tag) + __edata = .; + __data_end__ = .; + . = ALIGN(0x8) ; + } > REGION_DATA AT > REGION_RODATA + ._ram_code : { + . = ALIGN(0x8) ; + __ram_code_start__ = .; + *(.ram.code*) + . = ALIGN(0x8) ; + __ram_code_end__ = .; + } > REGION_DATA AT > REGION_RODATA + .bss : { + . 
= ALIGN(0x8) ; + __sbss = ALIGN(0x8) ; + __bss_start__ = . ; + *(.dynsbss) + *(.sbss) + *(.sbss.*) + *(.scommon) + *(.dynbss) + *(.bss*) + *(COMMON) + . = ALIGN(0x8) ; + __ebss = . ; + __bss_end__ = .; + __end = . ; + end = . ; + } > REGION_BSS AT > REGION_BSS + ._user_heap (NOLOAD): { + . = ALIGN(0x8) ; + *(.stack*) + . = ALIGN(0x8) ; + __heap_start = .; + . += __min_heap_size; + . = ALIGN(0x8) ; + } > REGION_BSS AT > REGION_BSS +} diff --git a/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld b/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld new file mode 100644 index 000000000..70943de73 --- /dev/null +++ b/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2017-2024 Alibaba Group Holding Limited + */ + +MEMORY +{ + DRAM : ORIGIN = 0x50000000, LENGTH = 0x100000 /* on-chip DRAM 1*1MB */ +} + +__min_heap_size = 0x200; +PROVIDE (__ram_end = 0x50100000 - 0x8); +PROVIDE (__heap_end = __ram_end); + +REGION_ALIAS("REGION_TEXT", DRAM); +REGION_ALIAS("REGION_RODATA", DRAM); +REGION_ALIAS("REGION_DATA", DRAM); +REGION_ALIAS("REGION_BSS", DRAM); + +ENTRY(Reset_Handler) +SECTIONS +{ + .text : { + . = ALIGN(0x8) ; + __stext = . ; + KEEP(*startup.o(*.text)) + KEEP(*startup.o(*.vectors)) + KEEP(*vectors.o(*.text)) + KEEP(*(.text.entry)) + *(.text) + *(.text*) + *(.text.*) + *(.gnu.warning) + *(.stub) + *(.gnu.linkonce.t*) + *(.glue_7t) + *(.glue_7) + *(.jcr) + KEEP (*(.init)) + KEEP (*(.fini)) + . = ALIGN(0x8) ; + PROVIDE(__ctbp = .); + *(.call_table_data) + *(.call_table_text) + . = ALIGN(0x8) ; + __etext = . ; + } > REGION_TEXT + .gcc_except_table : ONLY_IF_RO { + *(.gcc_except_table .gcc_except_table.*) + } > REGION_TEXT + .rodata : { + . = ALIGN(0x8) ; + __srodata = .; + *(.rdata) + *(.rdata*) + *(.rdata1) + *(.rdata.*) + *(.rodata) + *(.rodata1) + *(.rodata*) + *(.rodata.*) + *(.rodata.str1.4) + *(.srodata*) + . = ALIGN(0x8) ; + + __init_array_start = .; + __ctors_start__ = .; + KEEP (*(SORT(.init_array.*))) + KEEP (*(.init_array)) + __init_array_end = .; + __ctors_end__ = .; + + __fini_array_start = .; + __dtors_start__ = .; + KEEP (*(SORT(.fini_array.*))) + KEEP (*(.fini_array)) + __fini_array_end = .; + __dtors_end__ = .; + + __ctor_start__ = .; + KEEP (*(SORT(.ctors.*))) + KEEP (*(.ctors)) + __ctor_end__ = .; + KEEP (*(SORT(.dtors.*))) + KEEP (*(.dtors)) + __dtor_end__ = .; + . = ALIGN(0x8) ; +/*****************************************/ + /* section information for finsh shell */ + . = ALIGN(0x8); + __fsymtab_start = .; + KEEP(*(FSymTab)) + __fsymtab_end = .; + . = ALIGN(0x8); + __vsymtab_start = .; + KEEP(*(VSymTab)) + __vsymtab_end = .; + . = ALIGN(0x8); + + /* section information for initial. */ + __rt_init_start = .; + KEEP(*(SORT(.rti_fn*))) + __rt_init_end = .; + . = ALIGN(0x8) ; + + /* section information for at utest */ + __rt_utest_tc_tab_start = .; + KEEP(*(UtestTcTab)) + __rt_utest_tc_tab_end = .; + . = ALIGN(0x8); + + /* section information for at server */ + . = ALIGN(0x8); + __rtatcmdtab_start = .; + KEEP(*(RtAtCmdTab)) + __rtatcmdtab_end = .; + . = ALIGN(0x8); + + /* section information for modules */ + . = ALIGN(0x8); + __rtmsymtab_start = .; + KEEP(*(RTMSymTab)) + __rtmsymtab_end = .; + + /* section information for uPRC */ + . = ALIGN(0x8); + __uRPCSvcTab_start = .; + KEEP(*(uRPCSvcTab)) + __uRPCSvcTab_end = .; + + /* section information for var export */ + . 
= ALIGN(0x8); + __ve_table_start = .; + KEEP(*(SORT(*.VarExpTab.*))) + __ve_table_end = .; +/*****************************************/ +/************** added drivers **************/ + _cli_region_begin = .; + KEEP(*(CliRegion)) + . = ALIGN(0x8) ; + _cli_region_end = .; + + __core_driver_start__ = .; + KEEP(*(.core_driver_entry)) + . = ALIGN(0x8) ; + __core_driver_end__ = .; + + __bus_driver_start__ = .; + KEEP(*(*.bus_driver_entry)) + __bus_driver_end__ = .; + + __early_driver_start__ = .; + KEEP(*(*.early_driver_entry)) + __early_driver_end__ = .; + + __vfs_driver_start__ = .; + KEEP(*(*.vfs_driver_entry)) + __vfs_driver_end__ = .; + + __level0_driver_start__ = .; + KEEP(*(*.level0_driver_entry)) + __level0_driver_end__ = .; + + __level1_driver_start__ = .; + KEEP(*(*.level1_driver_entry)) + __level1_driver_end__ = .; + + __level2_driver_start__ = .; + KEEP(*(*.level2_driver_entry)) + __level2_driver_end__ = .; + + __level3_driver_start__ = .; + KEEP(*(*.level3_driver_entry)) + __level3_driver_end__ = .; + + __post_driver_start__ = .; + KEEP(*(*.post_driver_entry)) + __post_driver_end__ = .; +/************** end of drivers *********/ + . = ALIGN(0x8) ; + __erodata = .; + __rodata_end__ = .; + } > REGION_RODATA + .data : { + . = ALIGN(0x8) ; + __sdata = . ; + __data_start__ = . ; + data_start = . ; + *(.got.plt) + *(.got) + *(.gnu.linkonce.r*) + *(.data) + *(.data*) + *(.data1) + *(.data.*) + *(.gnu.linkonce.d*) + *(.data1) + *(.gcc_except_table) + *(.gcc_except_table*) + __start_init_call = .; + *(.initcall.init) + __stop_init_call = .; + __start_cmd = .; + *(.bootloaddata.cmd) + . = ALIGN(0x8) ; + __stop_cmd = .; + __global_pointer$ = .; + *(.sdata) + *(.sdata.*) + *(.sdata2.*) + *(.gnu.linkonce.s.*) + *(__libc_atexit) + *(__libc_subinit) + *(__libc_subfreeres) + *(.note.ABI-tag) + __edata = .; + __data_end__ = .; + . = ALIGN(0x8) ; + } > REGION_DATA + .gcc_except_table : ONLY_IF_RW { + *(.gcc_except_table .gcc_except_table.*) + __edata = .; + __data_end__ = .; + } > REGION_DATA + .bss : { + . = ALIGN(0x8) ; + __sbss = ALIGN(0x8) ; + __bss_start__ = . ; + *(.dynsbss) + *(.sbss) + *(.sbss.*) + *(.scommon) + *(.dynbss) + *(.bss) + *(.bss.*) + *(COMMON) + . = ALIGN(0x8) ; + __ebss = . ; + __bss_end__ = .; + __end = . ; + end = . ; + } > REGION_BSS + ._user_heap (NOLOAD): { + . = ALIGN(0x8) ; + *(.stack*) + . = ALIGN(0x8) ; + __heap_start = .; + . += __min_heap_size; + . = ALIGN(0x8) ; + } > REGION_BSS +} diff --git a/third_party/tsingmicro/crt/gcc_tx8_smarth.ld b/third_party/tsingmicro/crt/gcc_tx8_smarth.ld new file mode 100644 index 000000000..eb2aacb2a --- /dev/null +++ b/third_party/tsingmicro/crt/gcc_tx8_smarth.ld @@ -0,0 +1,279 @@ +/* + * Copyright (C) 2017-2024 Alibaba Group Holding Limited + */ + +/****************************************************************************** + * @file gcc_csky.ld + * @brief csky linker file + * @version V1.0 + * @date 02. June 2017 + ******************************************************************************/ +MEMORY +{ + mem0 (rwx) : ORIGIN = 0x00000000, LENGTH = (20*1024*1024) +} + +REGION_ALIAS("r", mem0); +REGION_ALIAS("w", mem0); +REGION_ALIAS("x", mem0); + +ENTRY(Reset_Handler) +SECTIONS +{ + +.text.startup 0x0:{ + . = ALIGN(0x8) ; + KEEP(*startup.o(*.text.startup)) + *(.text.startup) + } > x + + .text : { + . = ALIGN(0x8) ; + __ram_code_start__ = .; + __stext = . 
; + KEEP(*startup.o(*.text)) + KEEP(*startup.o(*.vectors)) + KEEP(*vectors.o(*.text)) + KEEP(*(.text.entry)) + *(.text) + *(.vectors) + *(.text*) + *(.text.*) + *(.gnu.warning) + *(.stub) + *(.gnu.linkonce.t*) + *(.glue_7t) + *(.glue_7) + *(.jcr) + KEEP (*(.init)) + KEEP (*(.fini)) + . = ALIGN(0x8) ; + PROVIDE(__ctbp = .); + *(.call_table_data) + *(.call_table_text) + . = ALIGN(0x8) ; + __etext = . ; + __ram_code_end__ = .; + } > x + + .gcc_except_table : ONLY_IF_RO { + *(.gcc_except_table .gcc_except_table.*) + } > x + + .rodata : { + . = ALIGN(0x8) ; + __srodata = .; + *(.rdata) + *(.rdata*) + *(.rdata1) + *(.rdata.*) + *(.rodata) + *(.rodata1) + *(.rodata*) + *(.rodata.*) + *(.rodata.str1.4) + *(.srodata*) + . = ALIGN(0x8) ; + + __init_array_start = .; + __ctors_start__ = .; + KEEP (*(SORT(.init_array.*))) + KEEP (*(.init_array)) + __init_array_end = .; + __ctors_end__ = .; + + __fini_array_start = .; + __dtors_start__ = .; + KEEP (*(SORT(.fini_array.*))) + KEEP (*(.fini_array)) + __fini_array_end = .; + __dtors_end__ = .; + + __ctor_start__ = .; + KEEP (*(SORT(.ctors.*))) + KEEP (*(.ctors)) + __ctor_end__ = .; + KEEP (*(SORT(.dtors.*))) + KEEP (*(.dtors)) + __dtor_end__ = .; + . = ALIGN(0x8) ; +/*****************************************/ + /* section information for finsh shell */ + . = ALIGN(0x8); + __fsymtab_start = .; + KEEP(*(FSymTab)) + __fsymtab_end = .; + . = ALIGN(0x8); + __vsymtab_start = .; + KEEP(*(VSymTab)) + __vsymtab_end = .; + . = ALIGN(0x8); + + /* section information for initial. */ + __rt_init_start = .; + KEEP(*(SORT(.rti_fn*))) + __rt_init_end = .; + . = ALIGN(0x8) ; + + /* section information for at utest */ + __rt_utest_tc_tab_start = .; + KEEP(*(UtestTcTab)) + __rt_utest_tc_tab_end = .; + . = ALIGN(0x8); + + /* section information for at server */ + . = ALIGN(0x8); + __rtatcmdtab_start = .; + KEEP(*(RtAtCmdTab)) + __rtatcmdtab_end = .; + . = ALIGN(0x8); + + /* section information for modules */ + . = ALIGN(0x8); + __rtmsymtab_start = .; + KEEP(*(RTMSymTab)) + __rtmsymtab_end = .; + + /* section information for uPRC */ + . = ALIGN(0x8); + __uRPCSvcTab_start = .; + KEEP(*(uRPCSvcTab)) + __uRPCSvcTab_end = .; + + /* section information for var export */ + . = ALIGN(0x8); + __ve_table_start = .; + KEEP(*(SORT(*.VarExpTab.*))) + __ve_table_end = .; +/*****************************************/ +/************** added drivers **************/ + _cli_region_begin = .; + KEEP(*(CliRegion)) + . = ALIGN(0x8) ; + _cli_region_end = .; + + __core_driver_start__ = .; + KEEP(*(.core_driver_entry)) + . = ALIGN(0x8) ; + __core_driver_end__ = .; + + __bus_driver_start__ = .; + KEEP(*(*.bus_driver_entry)) + __bus_driver_end__ = .; + + __early_driver_start__ = .; + KEEP(*(*.early_driver_entry)) + __early_driver_end__ = .; + + __vfs_driver_start__ = .; + KEEP(*(*.vfs_driver_entry)) + __vfs_driver_end__ = .; + + __level0_driver_start__ = .; + KEEP(*(*.level0_driver_entry)) + __level0_driver_end__ = .; + + __level1_driver_start__ = .; + KEEP(*(*.level1_driver_entry)) + __level1_driver_end__ = .; + + __level2_driver_start__ = .; + KEEP(*(*.level2_driver_entry)) + __level2_driver_end__ = .; + + __level3_driver_start__ = .; + KEEP(*(*.level3_driver_entry)) + __level3_driver_end__ = .; + + __post_driver_start__ = .; + KEEP(*(*.post_driver_entry)) + __post_driver_end__ = .; +/************** end of drivers *********/ + . = ALIGN(0x8) ; + __erodata = .; + __rodata_end__ = .; + } > r + + .data : { + . = ALIGN(0x8) ; + __sdata = . ; + __data_start__ = . ; + data_start = . 
; + *(.got.plt) + *(.got) + *(.gnu.linkonce.r*) + *(.data) + *(.data*) + *(.data1) + *(.data.*) + *(.gnu.linkonce.d*) + *(.data1) + *(.gcc_except_table) + *(.gcc_except_table*) + __start_init_call = .; + *(.initcall.init) + __stop_init_call = .; + __start_cmd = .; + *(.bootloaddata.cmd) + . = ALIGN(0x8) ; + __stop_cmd = .; + __global_pointer$ = .; + *(.sdata) + *(.sdata.*) + *(.sdata2.*) + *(.gnu.linkonce.s.*) + *(__libc_atexit) + *(__libc_subinit) + *(__libc_subfreeres) + *(.note.ABI-tag) + __edata = .; + __data_end__ = .; + . = ALIGN(0x8) ; + } > w AT> r + + .gcc_except_table : ONLY_IF_RW { + *(.gcc_except_table .gcc_except_table.*) + __edata = .; + __data_end__ = .; + } > w AT> r + + .rela.dyn : { + . = ALIGN(0x8) ; + __rel_dyn_start = .; + *(.rela*) + __rel_dyn_end = .; + } + .dynsym : { + . = ALIGN(0x8) ; + __dyn_sym_start = .; + *(.dynsym) + __dyn_sym_end = .; + } + .bss : { + . = ALIGN(0x8) ; + __sbss = ALIGN(0x8) ; + __bss_start__ = . ; + *(.dynsbss) + *(.sbss) + *(.sbss.*) + *(.scommon) + *(.dynbss) + *(.bss) + *(.bss.*) + *(COMMON) + . = ALIGN(0x8) ; + __ebss = . ; + __bss_end__ = .; + __end = . ; + end = . ; + } > w + ._user_heap (NOLOAD): { + . = ALIGN(0x8) ; + *(.stack*) + . = ALIGN(0x8) ; + __heap_start = ABSOLUTE(.); + . = ORIGIN(w) + LENGTH(w); + __heap_end = ABSOLUTE(.); + + } > w +} diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h b/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h new file mode 100644 index 000000000..adf93a8e6 --- /dev/null +++ b/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h @@ -0,0 +1,61 @@ +#ifndef INSTR_ADAPTER_H +#define INSTR_ADAPTER_H +#include +#include +#include +#include + +//#include "common_base.h" +#include "instr_def.h" +#include "instr_adapter_plat.h" + +#ifndef USING_RISCV +#define __CHECK_INSTR__ +#endif +//#define __PLAT_FREERTOS__ +// #define RECORD_INSTR_INVALID +#define SPM_LOWER_BOUND 0 +#define SPM_UPPER_BOUND 0x2EFFFF +#define DDR_LOWER_BOUND 0x280000000 +#define IS_WITHIN_SPM_BOUND(value) (((value) >= SPM_LOWER_BOUND) && ((value) <= SPM_UPPER_BOUND)) +#define IS_WITHIN_DDR_BOUND(value) ((value) >= DDR_LOWER_BOUND) +// times field (bits 0-7) +#define TIMES_INVALID_OFFET 0 +// last_invalid_barrier_id field (bits 8-35) +#define LAST_INVALID_BARRIER 8 +// first_invalid_barrier_id field (bits 36-63) +#define FIRST_INVALID_BARRIER 36 + +typedef struct InstrInvalidInfo { + volatile uint64_t ne_error_info; + volatile uint64_t ct_error_info; + volatile uint64_t td_error_info; + volatile uint64_t rdma_error_info; + volatile uint64_t wdma_error_info; +} InstrInvalidInfo; + +/* + # 0-shape(nhwc) # 1-wshape(Kx,Ky,f,c) # 2-bias # 3-stride(Kx,Ky,Sx,Sy) + # 4-pad(top,bottom,left,right) # 5- dilation(0,0,dilation[0],dilation[1]) +*/ +/*=================================TDMA=================================*/ + +/*=================================RDMA WDMA=================================*/ + +/*=================================Scale=================================*/ + +/*=================================run=================================*/ +uint32_t __execute_ne(TsmNeInstr *instr); +uint32_t __execute_ct(TsmArithInstr *instr); +uint32_t __execute_td(TsmDataMoveInstr *instr); +uint32_t __execute_rdma(TsmRdmaInstr *instr); +uint32_t __execute_wdma(TsmWdmaInstr *instr); +void __execute_sc(SC_Param *instr); +uint64_t TsmExecute(void *instr); + + +/*=================================debug=================================*/ +void set_device_ddr_base(uint64_t base); +uint64_t get_device_ddr_base(); + +#endif /*INSTR_ADAPTER_H*/
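As an illustration of how the guards and bit-field offsets above compose, here is a minimal host-side sketch (not part of the CRT): the pack_invalid_info helper and all sample values are hypothetical, and it assumes the Tx81 headers resolve on the include path. The layout follows the comments above: times in bits 0-7, last_invalid_barrier_id in bits 8-35, first_invalid_barrier_id in bits 36-63.

#include <stdint.h>
#include <stdio.h>
#include "instr_adapter.h"

/* Hypothetical helper: pack one invalid-instruction record into a
 * 64-bit word using the shift offsets defined above. Masking the
 * inputs to their field widths is left to the caller in this sketch. */
static uint64_t pack_invalid_info(uint64_t times, uint64_t last_id, uint64_t first_id)
{
    return (times << TIMES_INVALID_OFFET) |
           (last_id << LAST_INVALID_BARRIER) |
           (first_id << FIRST_INVALID_BARRIER);
}

int main(void)
{
    printf("%d\n", IS_WITHIN_SPM_BOUND(0x1000));         /* 1: inside SPM   */
    printf("%d\n", IS_WITHIN_SPM_BOUND(0x2F0000));       /* 0: beyond SPM   */
    printf("%d\n", IS_WITHIN_DDR_BOUND(0x280000000ULL)); /* 1: start of DDR */
    printf("0x%016llx\n", (unsigned long long)pack_invalid_info(3, 7, 5));
    return 0;
}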
diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h b/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h new file mode 100644 index 000000000..79f8b0165 --- /dev/null +++ b/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h @@ -0,0 +1,569 @@ +#ifndef INSTR_ADAPTER_PLAT_H +#define INSTR_ADAPTER_PLAT_H + +// Define the platform-specific pieces below according to your device type + +// ==================== if you run in Tx8-simulator ===================================================== + +#include + + +//#include "oplib_depend_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct Data_Shape { + uint16_t n; + uint16_t h; + uint16_t w; + uint16_t c; +} Data_Shape; + +typedef struct St_Elem_Shape { + uint32_t elem_count; + uint32_t unit_elem_count; + uint32_t full_elem_count; + uint32_t full_unit_elem_count; +} St_Elem_Shape; + +typedef struct St_StrideIteration { + uint32_t stride0; + uint32_t iteration0; + uint32_t stride1; + uint32_t iteration1; + uint32_t stride2; + uint32_t iteration2; +} St_StrideIteration; + +/*=================================C class=================================*/ +typedef struct TsmConv { + void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape, Data_Format fmt); + void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape, Data_Format fmt); + void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr); + void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape, Data_Format fmt); + void (*SetOpType)(TsmNeInstr *instr, uint8_t type); + void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //- negative axis + void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //+ positive axis + void (*SetSparse)(TsmNeInstr *instr, uint8_t sparse_en, uint64_t sparse_addr); + void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, Data_Format fmt); + void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); + void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); + void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, uint32_t Sx, uint32_t Sy); + void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); + void (*EnableRelu)(TsmNeInstr *instr); + void (*EnableLeakyRelu)(TsmNeInstr *instr); + void (*DisableRelu)(TsmNeInstr *instr); + void (*DisableLeakyRelu)(TsmNeInstr *instr); + void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, uint8_t zp_cur); + /* data */ +} TsmConv;
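To show how this function-pointer "class" style is meant to be driven, a hedged sketch follows: it assumes TsmNeInstr is a complete type from instr_def.h and that a zeroed descriptor is a valid starting point; every address, shape and format below is a placeholder, and conv_sketch itself is hypothetical.

#include <string.h>
#include "instr_adapter.h"

/* Sketch: fill a TsmNeInstr through the TsmConv method table, then
 * submit it with TsmExecute() from instr_adapter.h. TsmNewConv() and
 * TsmDeleteConv() are declared later in this header. */
void conv_sketch(uint64_t x_addr, uint64_t w_addr, uint64_t y_addr, Data_Format fmt)
{
    TsmConv *conv = TsmNewConv();
    TsmNeInstr instr;
    memset(&instr, 0, sizeof(instr));

    Data_Shape x = { .n = 1, .h = 32, .w = 32, .c = 16 };  /* nhwc      */
    Data_Shape w = { .n = 3, .h = 3,  .w = 16, .c = 16 };  /* Kx,Ky,f,c */
    Data_Shape y = { .n = 1, .h = 32, .w = 32, .c = 16 };

    conv->AddInput(&instr, x_addr, x, fmt);
    conv->AddWeight(&instr, w_addr, w, fmt);
    conv->AddOutput(&instr, y_addr, y, fmt);
    conv->SetKernelStrides(&instr, 3, 3, 1, 1);  /* Kx, Ky, Sx, Sy          */
    conv->SetPads(&instr, 1, 1, 1, 1);           /* top, bottom, left, right */
    conv->SetDilations(&instr, 1, 1);
    conv->EnableRelu(&instr);

    TsmExecute(&instr);
    TsmDeleteConv(conv);
}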
+ +typedef struct TsmDepthwiseConv { + void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape, Data_Format fmt); + void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape, Data_Format fmt); + void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr); + void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape, Data_Format fmt); + void (*SetOpType)(TsmNeInstr *instr, uint8_t type); + void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //- negative axis + void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //+ positive axis + void (*SetSparse)(TsmNeInstr *instr, uint8_t sparse_en, uint64_t sparse_addr); + void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); + void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); + void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, uint32_t Sx, uint32_t Sy); + void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); + void (*EnableRelu)(TsmNeInstr *instr); + void (*EnableLeakyRelu)(TsmNeInstr *instr); + void (*DisableRelu)(TsmNeInstr *instr); + void (*DisableLeakyRelu)(TsmNeInstr *instr); + void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, uint8_t zp_cur); + /* data */ +} TsmDepthwiseConv; +typedef struct TsmGemm { + void (*AddInput)(TsmNeInstr *instr, uint64_t L_addr, uint64_t R_addr, Data_Format in_fmt); + void (*ConfigMKN)(TsmNeInstr *instr, uint32_t M, uint32_t K, uint32_t N); + void (*ConfigBatch)(TsmNeInstr *instr, uint32_t Left_batch, uint32_t Right_batch); + void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Format Out_fmt); + void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, Data_Format fmt); + void (*SetTransflag)(TsmNeInstr *instr, uint8_t L_trans, uint8_t R_trans); + void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_left, uint8_t zp_right); + void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t addr); + void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t addr); + void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t addr); + void (*EnableRelu)(TsmNeInstr *instr); + void (*EnableLeakyRelu)(TsmNeInstr *instr); + void (*DisableRelu)(TsmNeInstr *instr); + void (*DisableLeakyRelu)(TsmNeInstr *instr); + + /* data */ +} TsmGemm; +typedef struct TsmRdma { + void (*AddSrcDst)(TsmRdmaInstr *instr, uint64_t src, uint64_t dst, Data_Format fmt); + void (*ConfigStrideIteration)(TsmRdmaInstr *instr, uint32_t elem_count, uint32_t stride0, uint32_t iteration0, + uint32_t stride1, uint32_t iteration1, uint32_t stride2, uint32_t iteration2); + void (*Rdma1d)(TsmRdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count, + uint32_t format); // only stride0 and iteration0 (the innermost loop); copies once +} TsmRdma; + +typedef struct TsmWdma { + void (*AddSrcDst)(TsmWdmaInstr *instr, uint64_t src, uint64_t dst, Data_Format fmt); + void (*ConfigStrideIteration)(TsmWdmaInstr *instr, uint32_t elem_count, uint32_t stride0, uint32_t iteration0, + uint32_t stride1, uint32_t iteration1, uint32_t stride2, uint32_t iteration2); + void (*Wdma1d)(TsmWdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count, + uint32_t format); // only stride0 and iteration0 (the innermost loop); copies once +} TsmWdma;
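The DMA tables pair a flat 1-D form with a three-level stride/iteration form; a sketch under the same assumptions as above (complete instruction types, placeholder addresses and counts, hypothetical rdma_sketch helper):

#include <string.h>
#include "instr_adapter.h"

/* Sketch: one flat RDMA copy, then a strided one. Per the comment
 * above, Rdma1d uses only the innermost stride0/iteration0 loop and
 * copies once; the strided form nests three stride/iteration levels. */
void rdma_sketch(uint64_t ddr_src, uint64_t spm_dst, Data_Format fmt)
{
    TsmRdma *rdma = TsmNewRdma();
    TsmRdmaInstr instr;
    memset(&instr, 0, sizeof(instr));

    rdma->Rdma1d(&instr, ddr_src, spm_dst, /*elem_count=*/1024, /*format=*/(uint32_t)fmt);
    TsmExecute(&instr);

    /* Gather 4 bursts of 256 elements, stepping 512 in the innermost loop. */
    memset(&instr, 0, sizeof(instr));
    rdma->AddSrcDst(&instr, ddr_src, spm_dst, fmt);
    rdma->ConfigStrideIteration(&instr, /*elem_count=*/256,
                                /*stride0=*/512, /*iteration0=*/4,
                                /*stride1=*/0, /*iteration1=*/1,
                                /*stride2=*/0, /*iteration2=*/1);
    TsmExecute(&instr);

    TsmDeleteRdma(rdma);
}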
+ + + +/*=================================CGRA=================================*/ +typedef struct TsmArith { + void (*AbsVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*RecipVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*SquareVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*SqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*RsqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*NegVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*MaxVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, + uint32_t elem_count, RND_MODE reserved, Data_Format fmt); + void (*MaxVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, RND_MODE reserved, Data_Format fmt); + void (*MaxVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt); + void (*MaxVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE reserved, Data_Format fmt); + void (*MinVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, RND_MODE reserved, Data_Format fmt); + void (*MinVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, + uint32_t elem_count, RND_MODE reserved, Data_Format fmt); + void (*MinVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt); + void (*MinVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE reserved, Data_Format fmt); + void (*AddVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + RND_MODE rnd_mode, Data_Format fmt); + void (*AddVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, + uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*AddVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*AddVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); + void (*SubVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + RND_MODE rnd_mode, Data_Format fmt); + void (*SubVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, + uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*SubVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*SubVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); + void (*MulVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + RND_MODE rnd_mode, Data_Format fmt); + void (*MulVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, + uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*MulVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*MulVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); + void (*DivVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t 
dst_addr, uint32_t elem_count, + RND_MODE rnd_mode, Data_Format fmt); + void (*DivVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, + uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*DivVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*DivVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); +} TsmArith; + + +typedef struct TsmRelation { + void (*EqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*EqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*EqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*BoolEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*EqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*BoolEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*UnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolUnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*UnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolUnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*UnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*BoolUnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*UnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*BoolUnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*GreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, 
uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolGreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*GreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolGreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*GreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*BoolGreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*GreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*BoolGreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*GreaterVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolGreaterVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*GreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolGreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*GreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*BoolGreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*GreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*BoolGreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*LessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolLessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*LessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolLessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*LessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void 
(*BoolLessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*LessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*BoolLessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*LessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolLessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*LessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*BoolLessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*LessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*BoolLessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*LessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*BoolLessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); +} TsmRelation; + +typedef struct TsmLogic { + void (*NotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*AndVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*OrVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*XorVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*AndVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*OrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*XorVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); + void (*AndVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*OrVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void 
(*XorVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*BoolNotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count); + void (*BoolAndV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count); + void (*BoolOrV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count); + void (*BoolXorV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count); + void (*BoolAndVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count); + void (*BoolOrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count); + void (*BoolXorVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count); + void (*BoolAndVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, uint32_t full_unit_elem_num); + void (*BoolOrVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, uint32_t full_unit_elem_num); + void (*BoolXorVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, uint32_t full_unit_elem_num); +} TsmLogic; + +typedef struct TsmTranscendental { + void (*Log2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Ln)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Pow2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Exp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Explp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Sin)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Cos)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); +} TsmTranscendental; + +typedef struct TsmActivation { + void (*Tanh)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Sigmoid)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Relu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Satrelu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Leakyrelu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Softplus)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); +} TsmActivation; + +typedef struct TsmReduce { + void (*ReduceSum)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, + 
Data_Format fmt); + void (*ReduceAvg)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, + Data_Format fmt); + void (*ReduceMax)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, + Data_Format fmt); + void (*ReduceMin)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, + Data_Format fmt); +} TsmReduce; + +typedef struct TsmPool { + void (*MaxPool)(TsmPoolInstr *instr, uint64_t src0, Data_Shape src_shape, uint64_t dst, Data_Shape pad, + Data_Shape swr_shape, Data_Format fmt); + void (*AvgPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, Data_Shape pad, + Data_Shape swr_shape, Data_Format fmt); + void (*SumPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, Data_Shape pad, + Data_Shape swr_shape, Data_Format fmt); + void (*MinPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, Data_Shape pad, + Data_Shape swr_shape, Data_Format fmt); + void (*IndexdMinPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_arg, + uint64_t dst_idx, Data_Shape pad, Data_Shape swr_shape, Data_Format fmt); + void (*IndexdMaxPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_arg, + uint64_t dst_idx, Data_Shape pad, Data_Shape swr_shape, Data_Format fmt); +} TsmPool; + +typedef struct TsmUnPool { + void (*Unpool)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, Data_Format fmt); + void (*UnpoolAvg)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, Data_Format fmt); + void (*UnpoolIdx)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, Data_Format fmt); +} TsmUnPool; + +typedef struct TsmMaskDataMove { + void (*MaskMove)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, uint32_t mask, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*MaskGather)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, uint32_t index, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*MaskGather_bV)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, uint32_t bitindex, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); +} TsmMaskDataMove; + + +typedef struct TsmConvert { + void (*INT8_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); // Data_Format fmt is INT8 + void (*INT8_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); + void (*INT8_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); + void (*INT8_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); + void (*INT16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); // INT16 + void (*INT16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + + void (*INT32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t 
dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + + void (*BF16_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + void (*BF16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*BF16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*BF16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + void (*BF16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + void (*BF16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + + void (*FP16_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); // rnd_mode 0~4 + void (*FP16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + void (*FP16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + + void (*FP32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); // rnd_mode 0~4 + void (*FP32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + + void (*TF32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*TF32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*TF32_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*TF32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + void (*TF32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode);// rnd_mode 0~4 + void (*TF32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); +} TsmConvert; + +typedef struct TsmPeripheral { + void (*Count)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt, uint64_t *wb_data0, uint64_t *wb_data1); + void 
(*Memset)(TsmDataMoveInstr *instr, uint64_t dst_addr, uint32_t value, uint32_t elem_count, + St_StrideIteration *si, Data_Format fmt); // si strides are byte sizes, but elem_count is an element count + void (*Bit2Fp)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*ArgMax)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint32_t elem_count, Data_Format fmt); + void (*ArgMin)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint32_t elem_count, Data_Format fmt); + void (*Bilinear)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst0_addr, Data_Shape src_shape, + Data_Shape dst_shape, int32_t scale_w, int32_t scale_h, Data_Format fmt); + void (*Lut16)(TsmPeripheralInstr *instr, uint64_t src1_addr, uint64_t dst0_addr, uint64_t lut16_addr, + uint32_t src_elem_count, uint32_t lut_elem_count); + void (*Lut32)(TsmPeripheralInstr *instr, uint64_t src1_addr, uint64_t dst0_addr, uint64_t lut32_addr, + uint32_t src_elem_count, uint32_t lut_elem_count); + void (*RandGen)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, + uint64_t dst1_addr, uint64_t dst2_addr, uint32_t src_elem_num, Data_Format fmt); + void (*Factorize)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint64_t dst1_addr, uint64_t dst2_addr, + uint32_t src_elem_num); + void (*ElemMask)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint32_t scale, uint64_t dst_addr, uint32_t src_elem_num, Data_Format fmt, + uint32_t prob, RND_MODE rnd_mode); +} TsmPeripheral; + +typedef struct TsmDataMove { + void (*Mirror)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Transpose)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Rotate90)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Rotate180)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Rotate270)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Nchw2nhwc)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Nhwc2nchw)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Concat)(TsmMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape0, uint64_t src1_addr, + Data_Shape src_shape1, uint64_t dst_addr, Data_Shape dst_shape, uint32_t dims, Data_Format fmt); + void (*Pad)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Shape pad, Data_Format fmt); + void (*Img2col)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, uint64_t src_elem_num, uint64_t dst_elem_num, Data_Shape swr, Data_Shape pdr, + Data_Format fmt); + void (*TensorNom)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*GatherScatter)(TsmDataMoveInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t size, + St_StrideIteration *src_si, St_StrideIteration *dst_si); +} TsmDataMove; +
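Because Memset mixes units (byte strides in St_StrideIteration, element counts elsewhere), a short sketch may help; the every-other-row pattern, the helper name and all sizes are hypothetical:

#include <string.h>
#include "instr_adapter.h"

/* Sketch: zero every other row of a 2-D tile. St_StrideIteration
 * strides are byte offsets, while elem_count counts elements, so the
 * row pitch is scaled by the element size. */
void memset_alternate_rows(uint64_t dst, uint32_t row_elems, uint32_t rows,
                           uint32_t elem_bytes, Data_Format fmt)
{
    TsmPeripheral *peri = TsmNewPeripheral();
    TsmDataMoveInstr instr;
    memset(&instr, 0, sizeof(instr));

    St_StrideIteration si = {
        .stride0 = 2u * row_elems * elem_bytes,  /* byte step: write one row, skip one */
        .iteration0 = rows / 2u,                 /* number of rows written             */
        .stride1 = 0, .iteration1 = 1,
        .stride2 = 0, .iteration2 = 1,
    };
    peri->Memset(&instr, dst, /*value=*/0, /*elem_count=*/row_elems, &si, fmt);
    TsmExecute(&instr);
    TsmDeletePeripheral(peri);
}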
+TsmConv *TsmNewConv(); +TsmDepthwiseConv *TsmNewDepthwiseConv(); +TsmGemm *TsmNewGemm(); +TsmRdma *TsmNewRdma(); +TsmWdma *TsmNewWdma(); +TsmArith *TsmNewArith(); +TsmRelation *TsmNewRelation(); +TsmLogic *TsmNewLogic(); +TsmTranscendental *TsmNewTranscendental(); +TsmActivation *TsmNewActivation(); +TsmReduce *TsmNewReduce(); +TsmPool *TsmNewPool(); +TsmUnPool *TsmNewUnPool(); +TsmMaskDataMove *TsmNewMaskDataMove(); +TsmConvert *TsmNewConvert(); +TsmPeripheral *TsmNewPeripheral(); +TsmDataMove *TsmNewDataMove(); + +void TsmDeleteConv(TsmConv *obj); +void TsmDeleteDepthwiseConv(TsmDepthwiseConv *obj); +void TsmDeleteGemm(TsmGemm *obj); +void TsmDeleteRdma(TsmRdma *obj); +void TsmDeleteWdma(TsmWdma *obj); +void TsmDeleteArith(TsmArith *obj); +void TsmDeleteRelation(TsmRelation *obj); +void TsmDeleteLogic(TsmLogic *obj); +void TsmDeleteTranscendental(TsmTranscendental *obj); +void TsmDeleteActivation(TsmActivation *obj); +void TsmDeleteReduce(TsmReduce *obj); +void TsmDeletePool(TsmPool *obj); +void TsmDeleteUnPool(TsmUnPool *obj); +void TsmDeleteMaskDataMove(TsmMaskDataMove *obj); +void TsmDeleteConvert(TsmConvert *obj); +void TsmDeletePeripheral(TsmPeripheral *obj); +void TsmDeleteDataMove(TsmDataMove *obj); +/*=================================STREAM=================================*/ +typedef struct TsmStream { + uint32_t (*OnlineStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); + uint32_t (*OfflineStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); + uint32_t (*WaitStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); + uint32_t (*ReqStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); + uint32_t (*PushStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); + uint32_t (*PopStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); + uint8_t (*wait_finish)(); +} TsmStream; +TsmStream *TsmNewStream(); +void TsmDeleteStream(TsmStream *obj); +/*=================================CSR=====================================*/ +uint8_t TsmWaitfinish(); +uint8_t TsmGetCsrTaskstatus(); +uint8_t TsmGetCsrIbcounter(); +uint8_t TsmGetCsrTaskstatus_bywork(size_t workerid); +uint8_t TsmWaitfinish_bywork(size_t workerid); +/*=================================CSR END=================================*/ +#ifdef __cplusplus +} +#endif + +// ==================== if you use Tx8-Oplib ====================================================== +// #define LOG_PRINT(...) +// #define LOG_ERR(fmt, args...) +// #define TSM_FREE free +// #define TSM_MALLOC malloc +// extern void setreg(int index, uint64_t value); +// extern uint64_t getreg(int index); + +// ==================== if you run in SOC-FreeRTOS/zebu ================================================ +// #include "rce_log.h" +// #include "csi_kernel.h" +// #include "rce_pal.h" +// #define LOG_PRINT(fmt, args...) vdk_printf(fmt, ##args) +// #define LOG_ERR(fmt, args...) 
vdk_printf(fmt, ##args) +// #define TSM_FREE(target) csi_kernel_free(2, target, NULL) +// #define TSM_MALLOC(size) csi_kernel_malloc(2, size, NULL) +// #define NCC_ADDR 0x01000000 +// #define setreg(ADDR, VALUE) +// do { +// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X, value=0x%lX(%lu).\n", ADDR, VALUE, VALUE); +// *((volatile uint64_t *)(ADDR + NCC_ADDR)) = VALUE; +// } while (0) + +// ==================== if you run in kernel-rt (use tile-sim) ========================================== +// #define LOG_PRINT(fmt, args...) printf(fmt, ##args) +// #define LOG_ERR(fmt, args...) printf(fmt, ##args) +// #define TSM_FREE free +// #define TSM_MALLOC malloc +// #include "rce_pal_port.h" +// void setreg(int index, uint64_t value) +// { +// LOG_PRINT("setreg param: GR: index=0x%X, value=0x%lX(%lu).\n", index, value, value); +// rce_tx_pal_setreg(index, value); +// } + +// ==================== if you run in kernel-rt (use riscv) =============================================== +// #define LOG_PRINT(...) +// #define LOG_ERR(fmt, args...) +// #define TSM_FREE free +// #define TSM_MALLOC malloc +// #define NCC_ADDR 0x01000000 +// #define setreg(ADDR, VALUE) +// do { +// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X, value=0x%lX(%lu).\n", ADDR, VALUE, VALUE); +// *((volatile uint64_t *)(ADDR + NCC_ADDR)) = VALUE; +// } while (0) + +// ==================== if you do not need logging ========================================================== +//#define LOG_PRINT(...) +//#define LOG_ERR(fmt, args...) +#endif
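The commented blocks above enumerate the definitions a platform port must supply. As one concrete (but hypothetical) instance, a tile-sim port that simply fills in the corresponding block might read:

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include "rce_pal_port.h"   /* provides rce_tx_pal_setreg, per the comment above */

#define LOG_PRINT(fmt, args...) printf(fmt, ##args)
#define LOG_ERR(fmt, args...)   printf(fmt, ##args)
#define TSM_FREE   free
#define TSM_MALLOC malloc

/* Forward each general-register write to the tile simulator PAL. */
void setreg(int index, uint64_t value)
{
    LOG_PRINT("setreg param: GR: index=0x%X, value=0x%lX(%lu).\n",
              index, value, value);
    rce_tx_pal_setreg(index, value);
}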
diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_def.h b/third_party/tsingmicro/crt/include/Tx81/instr_def.h new file mode 100644 index 000000000..96c8eda33 --- /dev/null +++ b/third_party/tsingmicro/crt/include/Tx81/instr_def.h @@ -0,0 +1,924 @@ +#ifndef _RCE_INSTR_DEF_H_ +#define _RCE_INSTR_DEF_H_ +#include + +#define UN_USED 0 + +// CT +#define GR_CT_CONTROL_ADDR 0x0000 +#define GR_CT_SRC0_ADDR 0x0008 * 2 +#define GR_CT_SRC1_ADDR 0x0010 * 2 +#define GR_CT_DST0_ADDR 0x0018 * 2 +#define GR_CT_DST1_ADDR 0x0020 * 2 +#define GR_CT_DST2_ADDR 0x0028 * 2 +#define GR_CT_DIMS_ADDR 0x0030 * 2 +#define GR_CT_SRC0_TFR_ADDR 0x0038 * 2 +#define GR_CT_DST_TFR_ADDR 0x0040 * 2 +#define GR_CT_PDR_ADDR 0x0048 * 2 +#define GR_CT_SWR_ADDR 0x0050 * 2 +#define GR_CT_ELEM_COUNT_ADDR 0x0058 * 2 +#define GR_CT_UNIT_ELEM_COUNT_ADDR 0x0060 * 2 +#define GR_CT_INT8_SCALE_VAL0_ADDR 0x0068 * 2 +#define GR_CT_INT8_SCALE_VECTOR_ADDR 0x0068 * 2 +#define GR_CT_INT8_SCALE_VAL1_ADDR 0x0070 * 2 +#define GR_CT_INT8_QUANT_ADDR 0x0078 * 2 +#define GR_CT_INT8_BN_ZP_ADDR 0x0080 * 2 +#define GR_CT_FULL_ELEM_COUNT_ADDR 0x0088 * 2 +#define GR_CT_FULL_UNIT_ELEM_COUNT_ADDR 0x0090 * 2 +#define GR_CT_WB_DATA0_ADDR 0x0098 * 2 +#define GR_CT_WB_DATA1_ADDR 0x00A0 * 2 +#define GR_CT_SRC0_END_ADDR 0x00A8 * 2 +#define GR_CT_SRC1_END_ADDR 0x00B0 * 2 +#define GR_CT_DST0_END_ADDR 0x00B8 * 2 +#define GR_CT_DST1_END_ADDR 0x00C0 * 2 +#define GR_CT_DST2_END_ADDR 0x00C8 * 2 + +// NE +#define GR_NE_CONTROL_ADDR 0x0100 * 2 +#define GR_NE_SRC_A_ADDR 0x0108 * 2 +#define GR_NE_SRC_W_ADDR 0x0110 * 2 +#define GR_NE_PSUM_ADDR 0x0118 * 2 +#define GR_NE_BIAS_ADDR 0x0120 * 2 +#define GR_NE_SCALE_P_ADDR 0x0128 * 2 +#define GR_NE_SCALE_N_ADDR 0x0130 * 2 +#define GR_NE_OUT_ADDR 0x0138 * 2 +#define GR_NE_SRC0_TFR_ADDR 0x0140 * 2 +#define GR_NE_SRC1_OUT_TFR_ADDR 0x0148 * 2 +#define GR_NE_PDR_ADDR 0x0150 * 2 +#define GR_NE_UNPDR_ADDR 0x0158 * 2 +#define GR_NE_SWR_ADDR 0x0160 * 2 +#define GR_NE_DILATION_ADDR 0x0168 * 2 +#define GR_NE_GEMM_LB_ADDR 0x0170 * 2 +#define GR_NE_GEMM_RB_ADDR 0x0178 * 2 +#define GR_NE_GEMM_N_ADDR 0x0180 * 2 +#define GR_NE_GEMM_M_ADDR 0x0188 * 2 +#define GR_NE_GEMM_K_ADDR 0x0190 * 2 +#define GR_NE_GEMM_L_TRS_ADDR 0x0198 * 2 +#define GR_NE_GEMM_R_TRS_ADDR 0x01A0 * 2 +#define GR_NE_QUANT_ADDR 0x01A8 * 2 +#define GR_NE_SPARSE_INDEX_ADDR 0x01B0 * 2 +#define GR_NE_SRCA_END 0x370 +#define GR_NE_SRCW_END 0x380 +#define GR_NE_PSUM_END 0x390 +#define GR_NE_BIAS_END 0x3A0 +#define GR_NE_SCALE_P_END 0x3B0 +#define GR_NE_SCALE_N_END 0x3C0 +#define GR_NE_OUT_END 0x3D0 +#define GR_NE_SPARSE_INDEX_END 0x3E0 + +// LSU RDMA +#define GR_RD_CONTROL_ADDR 0x400 +#define GR_RD_SRC_ADDR 0x410 +#define GR_RD_DST_ADDR 0x420 +#define GR_RD_STRIDE_ITERA0_ADDR 0x430 +#define GR_RD_STRIDE_ITERA1_ADDR 0x440 +#define GR_RD_STRIDE_ITERA2_ADDR 0x450 +#define GR_RD_ELEM_COUNT_ADDR 0x460 +#define GR_RD_FORMAT_ADDR 0x470 +#define GR_RD_SRC_END 0x480 +#define GR_RD_DST_END 0x490 +// LSU WDMA +#define GR_WD_CONTROL_ADDR 0x4A0 +#define GR_WD_SRC_ADDR 0x4B0 +#define GR_WD_DST_ADDR 0x4C0 +#define GR_WD_STRIDE_ITERA0_ADDR 0x4D0 +#define GR_WD_STRIDE_ITERA1_ADDR 0x4E0 +#define GR_WD_STRIDE_ITERA2_ADDR 0x4F0 +#define GR_WD_ELEM_COUNT_ADDR 0x500 +#define GR_WD_FORMAT_ADDR 0x510 +#define GR_WD_SRC_END 0x520 +#define GR_WD_DST_END 0x530 + +// TDMA +#define GR_TD_CONTROL_ADDR 0x02A0 * 2 +#define GR_TD_SRC0_ADDR 0x02A8 * 2 +#define GR_TD_SRC1_ADDR 0x02B0 * 2 +#define GR_TD_DST_ADDR 0x02B8 * 2 +#define GR_TD_DIMS_ADDR 0x02C0 * 2 +#define GR_TD_SRC0_TFR_ADDR 0x02C8 * 2 +#define GR_TD_DST_TFR_ADDR 0x02D0 * 2 +#define GR_TD_PDR_ADDR 0x02D8 * 2 +#define GR_TD_SWR_ADDR 0x02E0 * 2 +#define GR_TD_ELEM_COUNT_ADDR 0x02E8 * 2 +#define GR_TD_SRC_STRIDE_ITERA0_ADDR 0x02F0 * 2 +#define GR_TD_SRC_STRIDE_ITERA1_ADDR 0x02F8 * 2 +#define GR_TD_SRC_STRIDE_ITERA2_ADDR 0x0300 * 2 +#define GR_TD_DST_STRIDE_ITERA0_ADDR 0x0308 * 2 +#define GR_TD_DST_STRIDE_ITERA1_ADDR 0x0310 * 2 +#define GR_TD_DST_STRIDE_ITERA2_ADDR 0x0318 * 2 +#define GR_TD_SRC0_END 0x640 +#define GR_TD_SRC1_END 0x650 +#define GR_TD_DST_END 0x660 +// SCALAR +#define GR_SCALAR_CONTROL_ADDR 0x6A0 +#define GR_SCALAR_SRC_ADDR 0x6B0 +#define GR_SCALAR_DST_ADDR 0x6C0 +// CSR +#define GR_CSR_CONTROL_ADDR 0x740 +#define GR_CSR_EXCEPTION_ADDR 0x750 +#define GR_CSR_PRIORITY_ADDR 0x760 +#define GR_CSR_EXCEPTION_MASK_ADDR 0x770 +#define GR_CSR_SERIAL_MODE_ADDR 0x780 + +// CSR end +// DTE start +#define GR_DTE_SRC_ADDR_LO 0x0 +#define GR_DTE_SRC_ADDR_HI 0x4 +#define GR_DTE_DST_ADDR_LO_0 0x8 +#define GR_DTE_DST_ADDR_HI_0 0xC +#define GR_DTE_USER_ID_0 0x10 +#define GR_DTE_MODE 0x14 +#define GR_DTE_LENGTH 0x18 +#define GR_DTE_DEST_NUM 0x1C +#define GR_DTE_STRIDE0 0x20 +#define GR_DTE_ITERATION0 0x24 +#define GR_DTE_STRIDE1 0x28 +#define GR_DTE_ITERATION1 0x2C +#define GR_DTE_STRIDE2 0x30 +#define GR_DTE_ITERATION2 0x34 +#define GR_DTE_CMD_VALID 0x38 +#define GR_DTE_DMA_STATUS 0x40 +#define GR_DTE_DST_ADDR_LO_1 0x50 +#define GR_DTE_DST_ADDR_HI_1 0x54 +#define GR_DTE_DST_ADDR_LO_2 0x58 +#define GR_DTE_DST_ADDR_HI_2 0x5C +#define GR_DTE_DST_ADDR_LO_3 0x60 +#define GR_DTE_DST_ADDR_HI_3 0x64 +#define GR_DTE_DST_ADDR_LO_4 0x68 +#define GR_DTE_DST_ADDR_HI_4 0x6C +#define GR_DTE_DST_ADDR_LO_5 0x70 +#define GR_DTE_DST_ADDR_HI_5 0x74 +#define GR_DTE_DST_ADDR_LO_6 0x78 +#define GR_DTE_DST_ADDR_HI_6 0x7C +#define GR_DTE_DST_ADDR_LO_7 0x80 +#define GR_DTE_DST_ADDR_HI_7 0x84 +#define GR_DTE_DST_ADDR_LO_8 0x88 +#define GR_DTE_DST_ADDR_HI_8 0x8C +#define GR_DTE_DST_ADDR_LO_9 0x90 +#define GR_DTE_DST_ADDR_HI_9 0x94 +#define GR_DTE_DST_ADDR_LO_10 0x98 +#define 
+#define GR_DTE_DST_ADDR_HI_10 0x9C
+#define GR_DTE_DST_ADDR_LO_11 0xA0
+#define GR_DTE_DST_ADDR_HI_11 0xA4
+#define GR_DTE_DST_ADDR_LO_12 0xA8
+#define GR_DTE_DST_ADDR_HI_12 0xAC
+#define GR_DTE_DST_ADDR_LO_13 0xB0
+#define GR_DTE_DST_ADDR_HI_13 0xB4
+#define GR_DTE_DST_ADDR_LO_14 0xB8
+#define GR_DTE_DST_ADDR_HI_14 0xBC
+#define GR_DTE_DST_ADDR_LO_15 0xC0
+#define GR_DTE_DST_ADDR_HI_15 0xC4
+#define GR_DTE_DST_ADDR_LO_16 0xC8
+#define GR_DTE_DST_ADDR_HI_16 0xCC
+#define GR_DTE_DST_ADDR_LO_17 0xD0
+#define GR_DTE_DST_ADDR_HI_17 0xD4
+#define GR_DTE_DST_ADDR_LO_18 0xD8
+#define GR_DTE_DST_ADDR_HI_18 0xDC
+#define GR_DTE_DST_ADDR_LO_19 0xE0
+#define GR_DTE_DST_ADDR_HI_19 0xE4
+#define GR_DTE_DST_ADDR_LO_20 0xE8
+#define GR_DTE_DST_ADDR_HI_20 0xEC
+#define GR_DTE_DST_ADDR_LO_21 0xF0
+#define GR_DTE_DST_ADDR_HI_21 0xF4
+#define GR_DTE_DST_ADDR_LO_22 0xF8
+#define GR_DTE_DST_ADDR_HI_22 0xFC
+#define GR_DTE_DST_ADDR_LO_23 0x100
+#define GR_DTE_DST_ADDR_HI_23 0x104
+#define GR_DTE_DST_ADDR_LO_24 0x108
+#define GR_DTE_DST_ADDR_HI_24 0x10C
+#define GR_DTE_DST_ADDR_LO_25 0x110
+#define GR_DTE_DST_ADDR_HI_25 0x114
+#define GR_DTE_DST_ADDR_LO_26 0x118
+#define GR_DTE_DST_ADDR_HI_26 0x11C
+#define GR_DTE_DST_ADDR_LO_27 0x120
+#define GR_DTE_DST_ADDR_HI_27 0x124
+#define GR_DTE_DST_ADDR_LO_28 0x128
+#define GR_DTE_DST_ADDR_HI_28 0x12C
+#define GR_DTE_DST_ADDR_LO_29 0x130
+#define GR_DTE_DST_ADDR_HI_29 0x134
+#define GR_DTE_DST_ADDR_LO_30 0x138
+#define GR_DTE_DST_ADDR_HI_30 0x13C
+#define GR_DTE_DST_ADDR_LO_31 0x140
+#define GR_DTE_DST_ADDR_HI_31 0x144
+
+#define GR_DTE_USER_ID_1 0x148
+#define GR_DTE_USER_ID_2 0x14C
+#define GR_DTE_USER_ID_3 0x150
+#define GR_DTE_USER_ID_4 0x154
+#define GR_DTE_USER_ID_5 0x158
+#define GR_DTE_USER_ID_6 0x15C
+#define GR_DTE_USER_ID_7 0x160
+#define GR_DTE_USER_ID_8 0x164
+#define GR_DTE_USER_ID_9 0x168
+#define GR_DTE_USER_ID_10 0x16C
+#define GR_DTE_USER_ID_11 0x170
+#define GR_DTE_USER_ID_12 0x174
+#define GR_DTE_USER_ID_13 0x178
+#define GR_DTE_USER_ID_14 0x17C
+#define GR_DTE_USER_ID_15 0x180
+#define GR_DTE_USER_ID_16 0x184
+#define GR_DTE_USER_ID_17 0x188
+#define GR_DTE_USER_ID_18 0x18C
+#define GR_DTE_USER_ID_19 0x190
+#define GR_DTE_USER_ID_20 0x194
+#define GR_DTE_USER_ID_21 0x198
+#define GR_DTE_USER_ID_22 0x19C
+#define GR_DTE_USER_ID_23 0x1A0
+#define GR_DTE_USER_ID_24 0x1A4
+#define GR_DTE_USER_ID_25 0x1A8
+#define GR_DTE_USER_ID_26 0x1AC
+#define GR_DTE_USER_ID_27 0x1B0
+#define GR_DTE_USER_ID_28 0x1B4
+#define GR_DTE_USER_ID_29 0x1B8
+#define GR_DTE_USER_ID_30 0x1BC
+#define GR_DTE_USER_ID_31 0x1C0
+
+#define GR_DTE_MAX_AXI_NUM 0x1D0
+#define GR_DTE_MEM_BURSTLEN 0x1D4
+#define GR_DTE_MEM_BACKPRESSURE 0x1D8
+#define GR_DTE_MEM_READ_TURBO 0x1DC
+// DTE end
+
+// SCONFIG begin
+#define GR_SCONFIG_GPR0 0x600
+
+// SCONFIG end
+
+// NCC PMU begin
+#define GR_PMU_EN 0x0
+#define GR_PMU_CLR 0x4
+#define GR_PMU_STATISTICS_WINDOW 0x8
+#define GR_PMU_CT_INST_NUMS 0x10
+#define GR_PMU_NE_INST_NUMS 0x14
+#define GR_PMU_RDMA_INST_NUMS 0x18
+#define GR_PMU_WDMA_INST_NUMS 0x1C
+#define GR_PMU_TDMA_INST_NUMS 0x20
+#define GR_PMU_SCALAR_INST_NUMS 0x24
+#define GR_PMU_CT_BLOCKING_TIME 0x28
+#define GR_PMU_NE_BLOCKING_TIME 0x2C
+#define GR_PMU_RDMA_BLOCKING_TIME 0x30
+#define GR_PMU_WDMA_BLOCKING_TIME 0x34
+#define GR_PMU_TDMA_BLOCKING_TIME 0x38
+#define GR_PMU_SCALAR_BLOCKING_TIME 0x3c
+
+#define GR_PMU_FU_EXE_TIME 0x13c
+#define GR_PMU_CT_EXE_TIME 0x144
+#define GR_PMU_NE_EXE_TIME 0x14c
+#define GR_PMU_RDMA_EXE_TIME 0x154
+#define GR_PMU_WDMA_EXE_TIME 0x15c
+#define GR_PMU_TDMA_EXE_TIME 0x164
+#define GR_PMU_SCALAR_EXE_TIME 0x16c
+// NCC PMU end
+
+// DTE PMU begin
+#define DTE_PMU_EN 0x800
+#define DTE_PMU_CLR 0x804
+
+#define DTE_PMU_CH0_L_EXE_TIME 0x858
+#define DTE_PMU_CH0_H_EXE_TIME 0x85C
+#define DTE_PMU_CH1_L_EXE_TIME 0x860
+#define DTE_PMU_CH1_H_EXE_TIME 0x864
+// DTE PMU end
+
+typedef enum OP_INSTR_TYPE {
+  I_CGRA,
+  I_NEUR,
+  I_RDMA,
+  I_WDMA,
+  I_TDMA,
+  I_SCALAR,
+  I_DTE,
+  I_CSR,
+} OP_INSTR_TYPE;
+// instr_type = I_CGRA | I_WORKER1
+typedef enum OP_INSTR_WORKER {
+  I_WORKER0 = 0x0000,
+  I_WORKER1 = 0x0100,
+  I_WORKER2 = 0x0200,
+} OP_INSTR_WORKER;
+
+typedef enum RND_MODE {
+  RND_NEAREST_EVEN,
+  RND_ZERO,
+  RND_POS_INF,
+  RND_NEG_INF,
+  RND_STOCHASTIC
+} RND_MODE;
+
+typedef struct Ncc_CT_GR_Ctl_Regs {
+  uint8_t cmd_valid;   // self clear
+  uint8_t rnd_mode;    // 0: round to nearest even, 1: round to zero, 2: round to positive infinity,
+                       // 3: round to negative infinity, 4: stochastic round
+  uint8_t src0_format; // for the CGRATensor_PeriOp_V_V_bit2fp instruction this field is used as dst_format
+  uint8_t opcode;      // see the CGRATensor instruction OPcode.v for details
+} Ncc_CT_GR_Ctl_Regs;
+
+typedef struct Ncc_CT_GR_Param_Regs {
+  uint32_t src0; // SPM address
+  uint32_t src1;
+  uint32_t dst0;
+  uint32_t dst1;
+  uint32_t dst2;            // SPM address
+  uint64_t src0_tfr;        // nhwc
+  uint64_t dst_tfr;         // nhwc
+  uint64_t pdr;             // TOP, BOTTOM, LEFT, RIGHT (number of padded rows/columns on each side)
+  uint64_t swr;             // kernel Kx (size in x), Ky, Sx (stride in x), Sy
+  uint64_t elem_count;      // number of elements of the vector operation
+  uint64_t unit_elem_count; // number of elements of the short vector within the vector operation (max 64)
+  uint64_t int8_scale_val0; // bilinear interpolation scale factor in x (input_w/output_w)
+  uint64_t int8_scale_val1; // bilinear interpolation scale factor in y (input_h/output_h)
+  uint64_t int8_quant;      // abandoned
+  uint32_t int8_bn_bias;    // abandoned
+  uint32_t full_elem_count;      // sum of several src_elem_num values
+  uint32_t full_unit_elem_count; // sum of several src_unit_elem_num values
+  uint64_t wb_data0; // pointer to the return value. [32] DATA_VALID, [31:0] data.
+                     // When the function has a single return value it is written to this register.
+  uint64_t wb_data1; // pointer to the return value. [32] DATA_VALID, [31:0] data.
+                     // When the function has two return values the second one is written to this
+                     // register; with a single return value this register is unused.
+  uint32_t src0_end; // SPM address (end address of src0); xxx_end = src/dst + storage range of the operand in SPM
+  uint32_t src1_end;
+  uint32_t dst0_end;
+  uint32_t dst1_end;
+  uint32_t dst2_end;
+  uint8_t dims; // 000:C 001:W 010:H 011:N 100:HW 101:HWC
+} Ncc_CT_GR_Param_Regs;
+
+typedef struct CT_Param {
+  uint32_t inter_type;
+  Ncc_CT_GR_Ctl_Regs ctrl;
+  Ncc_CT_GR_Param_Regs param;
+} CT_Param;
+
+#define TsmArithInstr CT_Param
+#define TsmPoolInstr CT_Param
+#define TsmMoveInstr CT_Param
+#define TsmUnPoolInstr CT_Param
+#define TsmMaskDataMoveInstr CT_Param
+#define TsmConvertInstr CT_Param
+#define TsmPeripheralInstr CT_Param
+#define TsmRelationInstr CT_Param
+#define TsmLogicInstr CT_Param
+#define TsmTranscendentalInstr CT_Param
+#define TsmReduceInstr CT_Param
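+
+// Illustrative sketch (not part of the original Tx81 API): how a caller might
+// populate a CT descriptor for a 1024-element vector add before dispatching it.
+// The opcode constant comes from OP_FUNC_CGRA further below; whether the enum
+// value maps directly to ctrl.opcode is an assumption (see OPcode.v), and
+// src0_spm/src1_spm/dst_spm are hypothetical SPM offsets. TsmExecute() is
+// declared in instr_adapter.h.
+//
+//   CT_Param p = {0};
+//   p.inter_type = I_CGRA | I_WORKER0;
+//   p.ctrl.cmd_valid = 1;
+//   p.ctrl.rnd_mode = RND_NEAREST_EVEN;
+//   p.ctrl.opcode = OP_FUNC_CGRATensor_ArithOp_V_VV_add;  // assumed encoding
+//   p.param.src0 = src0_spm;
+//   p.param.src1 = src1_spm;
+//   p.param.dst0 = dst_spm;
+//   p.param.elem_count = 1024;
+//   p.param.unit_elem_count = 64;  // short-vector width, max 64
+//   TsmExecute(&p);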
+
+typedef struct Ncc_NE_GR_Ctl_Regs {
+  uint8_t sparse_en;
+  uint8_t cmd_valid;
+  uint8_t inpsum_format;
+  uint8_t output_format;
+  uint8_t input_format;
+  uint8_t inpsum_en;
+  uint8_t lrelu_en;      // either relu or lrelu
+  uint8_t relu_en;       // when relu_en/lrelu_en/bias_en/scale_en are all 0 the output is psum
+  uint8_t scale_en;
+  uint8_t bias_en;
+  uint8_t dilation_conv; // valid for conv and backward conv
+  uint8_t type;          // 0:conv 1:depthwise conv 2:backward conv 3:gemm
+} Ncc_NE_GR_Ctl_Regs;
+
+typedef struct Ncc_NE_GR_Param_Regs {
+  uint32_t src_a;   // SPM address (activation / left matrix)
+  uint32_t src_w;   // SPM address (weight / right matrix)
+  uint32_t psum;    // SPM address (input psum)
+  uint32_t bias;    // SPM address (bias)
+  uint32_t scale_p; // SPM address (positive-axis scale)
+  uint32_t scale_n; // SPM address (negative-axis scale)
+  uint32_t out;     // SPM address (output psum)
+  uint64_t tfr_0;   // src0 nhwc, [15:0] tensor batch/h/w (range 1~4096); tensor channel count (range 1~16384)
+  uint64_t tfr_1;   // conv: out nhwc, same layout as tfr_0
+  uint64_t pdr;     // pad [15:0] top bottom left right, number of padded rows/columns on each side (range 0~1023)
+  uint64_t unpdr;   // unpad [15:0] top bottom left right
+  uint64_t swr;     // [15:0] Kx (range 1~255) Ky (range 1~255) Sx (range 1~1023) Sy (range 1~1023)
+  uint64_t dilation; // [15:0] dilation size in x (range 1-1023), [15:0] dilation size in y (range 1-1023)
+
+  uint16_t gemm_lb; // [15:0] left-matrix batch (range 1~4096)
+  uint16_t gemm_rb; // [15:0] right-matrix batch (range 1~4096)
+  uint16_t gemm_n;  // matrix size parameters of the matrix operation
+  uint16_t gemm_m;  // mk*kn---->mn
+  uint16_t gemm_k;  // (range 1~16384)
+  uint8_t gemm_l_trs; // transpose the left matrix
+  uint8_t gemm_r_trs; // transpose the right matrix
+  /*
+    Quant formula----A_int8: Left input, B_int8: Right input
+    Left input 8bit to 9bit:  A_int9 = A_int8 - ZP_A_int8
+    Right input 8bit to 9bit: B_int9 = B_int8 - ZP_B_int8
+    do conv  : O_int32 = Sum_{A_int9 * B_int9}
+    do scale : O_int16 = Clip_int16(O_int32 >> q1)
+    do scale : O_int9  = Clip_int9((O_int16 * S_int16) >> q2)
+    out 9bit to 8bit : O_int8 = O_int9 + ZP_O_int8
+  */
+  uint8_t quant_zp_cur;   // output zero point (0-255). [39:32]
+  uint8_t quant_reserved; // (0-255). [31:24] conv:unused gemm:right_zp
+  uint8_t quant_zp_pre;   // input zero point (0-255), [23:16] conv:act_zp gemm:left_zp (range 0-255)
+  uint8_t quant_q1;       // q1, (range 0-31), [15:8]
+  uint8_t quant_q0;       // q2, (range 0-31), [7:0]
+
+  uint32_t sparse_index; // SPM address (sparsity index)
+  uint32_t srca_end;     // xxx_end = src/dst + storage range of the operand in SPM
+  uint32_t srcw_end;
+  uint32_t psum_end;
+  uint32_t bias_end;
+  uint32_t scale_p_end;
+  uint32_t scale_n_end;
+  uint32_t out_end;
+  uint32_t sparse_end;
+} Ncc_NE_GR_Param_Regs;
+
+typedef struct TsmNeInstr {
+  uint32_t inter_type;
+  Ncc_NE_GR_Ctl_Regs ctrl;
+  Ncc_NE_GR_Param_Regs param;
+} TsmNeInstr;
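+
+// Worked example of the quantization chain documented above (illustrative
+// numbers only, not a hardware trace). With quant_q1 = 5, S_int16 = 64,
+// quant_q0 (q2) = 7 and ZP_O_int8 = 20:
+//
+//   O_int32 = 3200
+//   O_int16 = Clip_int16(3200 >> 5)      = 100
+//   O_int9  = Clip_int9((100 * 64) >> 7) = 50
+//   O_int8  = 50 + 20                    = 70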
+
+// RDMA / WDMA
+typedef struct Ncc_DMA_GR_Ctl_Regs {
+  uint8_t cmd_valid;
+} Ncc_DMA_GR_Ctl_Regs;
+
+typedef struct Ncc_DMA_GR_Param_Regs {
+  uint64_t dst; // DDR address
+  uint64_t src; // SPM address
+  /*
+    for(i = 0; i < itera2; i++)
+      for(j = 0; j < itera1; j++)
+        for(k = 0; k < itera0; k++)
+          for(l = 0; l < elem_count; l++)
+            dst[l + elem_count * k + elem_count * src_itera0 * j + elem_count * src_itera0 * src_itera1 * i] = \
+              src[l + k * src_stride0 + j * src_stride1 + i * src_stride2];
+  */
+  uint32_t stride0;    // address stride
+  uint32_t iteration0; // number of data blocks
+  uint32_t stride1;
+  uint32_t iteration1;
+  uint32_t stride2;
+  uint32_t iteration2;
+  uint32_t elem_count; // number of elements moved at a time in the innermost dimension
+  uint8_t format;      // data type
+  uint64_t src_end;    // src_end = src + data storage length in DDR
+  uint64_t dst_end;    // dst_end = dst + data storage length in SPM
+} Ncc_DMA_GR_Param_Regs;
+
+typedef struct DMA_Param {
+  uint32_t inter_type;
+  Ncc_DMA_GR_Ctl_Regs ctrl;
+  Ncc_DMA_GR_Param_Regs param;
+} DMA_Param;
+
+#define TsmRdmaInstr DMA_Param
+#define TsmWdmaInstr DMA_Param
+
+typedef struct Ncc_TDMA_GR_Ctl_Regs {
+  uint8_t cmd_valid;   // [12]
+  uint8_t src0_format; // [11:8]
+  uint8_t opcode;      // [7:0]
+} Ncc_TDMA_GR_Ctl_Regs;
+
+typedef struct Ncc_TDMA_GR_Param_Regs {
+  uint32_t src0;
+  uint32_t src1;
+  uint32_t dst;
+  uint64_t src0_tfr;   // nhwc c:15~0
+  uint64_t dst_tfr;    // nhwc
+  uint64_t pdr;        // top bottom left right
+  uint64_t swr;        // kx ky sx sy
+  uint32_t elem_count; // element count of the vector operation; for the memset and
+                       // gatherscatter instructions it is the byte count
+  /*
+    for(i = 0; i < iteration; i++)
+      ... (strided transport; same pattern as the RDMA/WDMA loop above)
+  */
+} Ncc_TDMA_GR_Param_Regs;
+
+typedef struct Ncc_DTE_GR_Param_Regs {
+  uint64_t src;     // source address (GR_DTE_SRC_ADDR_LO/HI)
+  uint64_t dst;     // destination address (GR_DTE_DST_ADDR_LO/HI_0)
+  uint32_t user_id;
+  uint32_t mode;    // [0:0] mode: 0->gather/scatter, unicast, 1-> scatter, broadcast, 3-> shuffle(3D gather).
+                    // [8:8] sg_flag: 0->scatter, 1->gather, [16:16] dim_flag: only unicast mode(=0),
+                    // 0->1D transport, 1->2D transport
+  uint32_t length;  // count of data bytes
+  uint8_t dest_num; // if mode[0:0] is 0 its value is 1; otherwise its value is between 1 and 31.
+  uint32_t stride0; // if mode[0:0] is 0 the stride can be set; unit is bytes.
+  uint32_t iteration0; // 0: means 1 section; 1: means 2 sections, and so on.
+  uint32_t stride1;
+  uint32_t iteration1;
+  uint32_t stride2;
+  uint32_t iteration2;
+  uint16_t max_axi_num; // [7:0] axi_write_outstanding, [15:8] axi_read_outstanding
+  uint8_t cmd_valid;    // 1: activate dma, 0: no action.
+  // uint8_t dma_status; // [0:0] 0->unfinished, 1->finished. [8:8] 0/1, record the error of AXI bus or other DMA
+  // transmission.
+  uint16_t
+      mem_burstlen; // [7:0] mem_burst_len_write, default value: 0x10; [15:8] mem_burst_len_read, default value: 0x10
+  uint8_t mem_backpressure; // 0x1
+  uint8_t mem_read_turbo;   // [1:0], 0~2, default value: 0, only block0 valid.
+} Ncc_DTE_GR_Param_Regs; + +typedef enum OP_FUNC_CGRA { + // Arithmetic Operators + OP_FUNC_CGRATensor_ArithOp_V_V_abs = 0, + OP_FUNC_CGRATensor_ArithOp_V_V_recip = 1, + OP_FUNC_CGRATensor_ArithOp_V_V_square = 2, + OP_FUNC_CGRATensor_ArithOp_V_V_sqrt = 3, + OP_FUNC_CGRATensor_ArithOp_V_V_rsqrt = 4, + OP_FUNC_CGRATensor_ArithOp_V_V_neg = 5, + OP_FUNC_CGRATensor_ArithOp_V_VV_max = 6, + OP_FUNC_CGRATensor_ArithOp_V_VS_max = 7, + OP_FUNC_CGRATensor_ArithOp_V_VuV_max = 8, + OP_FUNC_CGRATensor_ArithOp_V_VuV_max_loop = 9, + OP_FUNC_CGRATensor_ArithOp_V_VV_min = 10, + OP_FUNC_CGRATensor_ArithOp_V_VS_min = 11, + OP_FUNC_CGRATensor_ArithOp_V_VuV_min = 12, + OP_FUNC_CGRATensor_ArithOp_V_VuV_min_loop = 13, + OP_FUNC_CGRATensor_ArithOp_V_VV_add = 14, + OP_FUNC_CGRATensor_ArithOp_V_VS_add = 15, + OP_FUNC_CGRATensor_ArithOp_V_VuV_add = 16, + OP_FUNC_CGRATensor_ArithOp_V_VuV_add_loop = 17, + OP_FUNC_CGRATensor_ArithOp_V_VV_sub = 18, + OP_FUNC_CGRATensor_ArithOp_V_VS_sub = 19, + OP_FUNC_CGRATensor_ArithOp_V_VuV_sub = 20, + OP_FUNC_CGRATensor_ArithOp_V_VuV_sub_loop = 21, + OP_FUNC_CGRATensor_ArithOp_V_VV_mul = 22, + OP_FUNC_CGRATensor_ArithOp_V_VS_mul = 23, + OP_FUNC_CGRATensor_ArithOp_V_VuV_mul = 24, + OP_FUNC_CGRATensor_ArithOp_V_VuV_mul_loop = 25, + OP_FUNC_CGRATensor_ArithOp_V_VV_div = 26, + OP_FUNC_CGRATensor_ArithOp_V_VS_div = 27, + OP_FUNC_CGRATensor_ArithOp_V_VuV_div = 28, + OP_FUNC_CGRATensor_ArithOp_V_VuV_div_loop = 29, + + // Relational Operators + OP_FUNC_CGRATensor_RelaOp_V_VV_eq = 30, + OP_FUNC_CGRATensor_RelaOp_bV_VV_eq = 31, + OP_FUNC_CGRATensor_RelaOp_V_VS_eq = 32, + OP_FUNC_CGRATensor_RelaOp_bV_VS_eq = 33, + OP_FUNC_CGRATensor_RelaOp_V_VuV_eq = 34, + OP_FUNC_CGRATensor_RelaOp_V_VuV_eq_loop = 35, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq = 36, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq_loop = 37, + + OP_FUNC_CGRATensor_RelaOp_V_VV_ne = 38, + OP_FUNC_CGRATensor_RelaOp_bV_VV_ne = 39, + OP_FUNC_CGRATensor_RelaOp_V_VS_ne = 40, + OP_FUNC_CGRATensor_RelaOp_bV_VS_ne = 41, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ne = 42, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ne_loop = 43, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne = 44, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne_loop = 45, + + OP_FUNC_CGRATensor_RelaOp_V_VV_ge = 46, + OP_FUNC_CGRATensor_RelaOp_bV_VV_ge = 47, + OP_FUNC_CGRATensor_RelaOp_V_VS_ge = 48, + OP_FUNC_CGRATensor_RelaOp_bV_VS_ge = 49, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ge = 50, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ge_loop = 51, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge = 52, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge_loop = 53, + + OP_FUNC_CGRATensor_RelaOp_V_VV_gt = 54, + OP_FUNC_CGRATensor_RelaOp_bV_VV_gt = 55, + OP_FUNC_CGRATensor_RelaOp_V_VS_gt = 56, + OP_FUNC_CGRATensor_RelaOp_bV_VS_gt = 57, + OP_FUNC_CGRATensor_RelaOp_V_VuV_gt = 58, + OP_FUNC_CGRATensor_RelaOp_V_VuV_gt_loop = 59, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt = 60, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt_loop = 61, + + OP_FUNC_CGRATensor_RelaOp_V_VV_le = 62, + OP_FUNC_CGRATensor_RelaOp_bV_VV_le = 63, + OP_FUNC_CGRATensor_RelaOp_V_VS_le = 64, + OP_FUNC_CGRATensor_RelaOp_bV_VS_le = 65, + OP_FUNC_CGRATensor_RelaOp_V_VuV_le = 66, + OP_FUNC_CGRATensor_RelaOp_V_VuV_le_loop = 67, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_le = 68, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_le_loop = 69, + + OP_FUNC_CGRATensor_RelaOp_V_VV_lt = 70, + OP_FUNC_CGRATensor_RelaOp_bV_VV_lt = 71, + OP_FUNC_CGRATensor_RelaOp_V_VS_lt = 72, + OP_FUNC_CGRATensor_RelaOp_bV_VS_lt = 73, + OP_FUNC_CGRATensor_RelaOp_V_VuV_lt = 74, + OP_FUNC_CGRATensor_RelaOp_V_VuV_lt_loop = 75, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt = 76, + 
OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt_loop = 77,
+
+  OP_FUNC_CGRATensor_LogicOp_V_V_not = 78,
+  OP_FUNC_CGRATensor_LogicOp_V_VV_and = 79,
+  OP_FUNC_CGRATensor_LogicOp_V_VV_or = 80,
+  OP_FUNC_CGRATensor_LogicOp_V_VV_xor = 81,
+  OP_FUNC_CGRATensor_LogicOp_V_VuV_and = 82,
+  OP_FUNC_CGRATensor_LogicOp_V_VuV_or = 83,
+  OP_FUNC_CGRATensor_LogicOp_V_VuV_xor = 84,
+  OP_FUNC_CGRATensor_LogicOp_V_VuV_and_loop = 85,
+  OP_FUNC_CGRATensor_LogicOp_V_VuV_or_loop = 86,
+  OP_FUNC_CGRATensor_LogicOp_V_VuV_xor_loop = 87,
+
+  OP_FUNC_CGRATensor_LogicOp_bV_bV_not = 88,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVbV_and = 89,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVbV_or = 90,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVbV_xor = 91,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and = 92,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or = 93,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor = 94,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and_loop = 95,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or_loop = 96,
+  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor_loop = 97,
+
+  // Transcendental Operator
+  OP_FUNC_CGRATensor_TransOp_V_V_log2 = 98,
+  OP_FUNC_CGRATensor_TransOp_V_V_ln = 99,
+  OP_FUNC_CGRATensor_TransOp_V_V_pow2 = 100,
+  OP_FUNC_CGRATensor_TransOp_V_V_exp = 101,
+  OP_FUNC_CGRATensor_TransOp_V_V_exp_lp = 102,
+  OP_FUNC_CGRATensor_TransOp_V_V_sin = 103,
+  OP_FUNC_CGRATensor_TransOp_V_V_cos = 104,
+
+  // Activation Operator
+  OP_FUNC_CGRATensor_ActOp_V_V_tanh = 105,
+  OP_FUNC_CGRATensor_ActOp_V_V_sigmoid = 106,
+  OP_FUNC_CGRATensor_ActOp_V_V_relu = 107,
+  OP_FUNC_CGRATensor_ActOp_V_V_satrelu = 108,
+  OP_FUNC_CGRATensor_ActOp_V_V_leakyrelu = 109,
+  OP_FUNC_CGRATensor_ActOp_V_V_softplus = 110,
+
+  // Reduce Operator
+  OP_FUNC_CGRATensor_ReduceOp_T_T_sum = 111,
+  OP_FUNC_CGRATensor_ReduceOp_T_T_avg = 112,
+  OP_FUNC_CGRATensor_ReduceOp_T_T_max = 113,
+  OP_FUNC_CGRATensor_ReduceOp_T_T_min = 114,
+
+  // Pool Operator
+  OP_FUNC_CGRATensor_PoolOp_T_T_avg = 115,
+  OP_FUNC_CGRATensor_PoolOp_T_T_sum = 116,
+  OP_FUNC_CGRATensor_PoolOp_T_T_max = 117,
+  OP_FUNC_CGRATensor_PoolOp_T_T_indexedmax = 118,
+  OP_FUNC_CGRATensor_PoolOp_T_T_min = 119,
+  OP_FUNC_CGRATensor_PoolOp_T_T_indexedmin = 120,
+
+  // DataMove
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool = 121,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool_avg = 122,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_maskunpool = 123,
+  // reshape
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_mirror = 124,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_transpose = 125,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate90 = 126,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate180 = 127,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate270 = 128,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_nchw2nhwc = 129,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_nhwc2nchw = 130,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_concat = 131,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_pad = 132,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_channelnorm = 133,
+  // datamove
+  OP_FUNC_CGRATensor_DataMoveOp_V_V_maskmove = 134,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_gatherscatter = 135,
+  OP_FUNC_CGRATensor_DataMoveOp_V_V_maskgather = 136,
+  OP_FUNC_CGRATensor_DataMoveOp_V_bV_maskgather = 137,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_img2col = 138,
+
+  // Convert Operator
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp16 = 139,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_bf16 = 140,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp32 = 141,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_tf32 = 142,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp16 = 143,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_bf16 = 144,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp32 = 145,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_tf32 = 146,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp16 = 147,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_bf16 = 148,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp32 = 149,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_tf32 = 150,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int8 = 151,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int16 = 152,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int32 = 153,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp16 = 154,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp32 = 155,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_tf32 = 156,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int8 = 157,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int16 = 158,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int32 = 159,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_bf16 = 160,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_fp32 = 161,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_tf32 = 162,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int8 = 163,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int16 = 164,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int32 = 165,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_fp16 = 166,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_bf16 = 167,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_tf32 = 168,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int8 = 169,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int16 = 170,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int32 = 171,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp16 = 172,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_bf16 = 173,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp32 = 174,
+
+  // Peripheral Operator
+  OP_FUNC_CGRATensor_PeriOp_S_V_count = 175,
+  OP_FUNC_CGRATensor_PeriOp_S_bV_bitcount = 176,
+  OP_FUNC_CGRATensor_PeriOp_V_V_argmax = 177,
+  OP_FUNC_CGRATensor_PeriOp_V_V_argmin = 178,
+  OP_FUNC_CGRATensor_PeriOp_T_memset = 179,
+  OP_FUNC_CGRATensor_PeriOp_V_V_fp32_factorize = 180,
+  OP_FUNC_CGRATensor_PeriOp_V_V_bit2fp = 181,
+  OP_FUNC_CGRATensor_PeriOp_T_T_bilinear = 182,
+  OP_FUNC_CGRATensor_PeriOp_V_V_lut16 = 183,
+  OP_FUNC_CGRATensor_PeriOp_V_V_lut32 = 184,
+  OP_FUNC_CGRATensor_PeriOp_V_rand_gen = 185,
+  OP_FUNC_CGRATensor_PeriOp_V_V_elem_mask = 186,
+} OP_FUNC_CGRA;
+
+typedef enum CGRA_INSTR_TYPE {
+  CGRA_INSTR_TYPE0,
+  CGRA_INSTR_TYPE1,
+  CGRA_INSTR_TYPE2,
+  CGRA_INSTR_TYPE3,
+} CGRA_INSTR_TYPE;
+
+typedef struct Op_fu_head {
+  uint8_t fu;
+  uint8_t opcode;
+} Op_fu_head;
+
+typedef struct FU_gemm_head {
+  uint8_t fu;
+  uint8_t gemm;
+} FU_gemm_head;
+
+typedef struct opfunc_cgra_info {
+  char name[64];  // CGRATensor_ArithOp_V_V_abs
+  int32_t opcode; // 8'b0000_0000
+  int32_t type;   // CGRA_Tensor_type0
+} opfunc_cgra_info;
+
+// Neural
+typedef enum Data_Format {
+  Fmt_INT8,
+  Fmt_INT16,
+  Fmt_FP16,
+  Fmt_BF16,
+  Fmt_INT32,
+  Fmt_FP32,
+  Fmt_TF32,
+  Fmt_BOOL, // 1/8 BYTE
+  Fmt_UINT8,
+  Fmt_UINT16,
+  Fmt_UINT32,
+  Fmt_INT64,
+  Fmt_UINT64,
+  Fmt_UNUSED,
+} Data_Format;
+
+typedef enum Tensor_Fmt {
+  T_GemmM = 0, /*M K*/
+  T_ConvA = 1, /*H W C*/
+  T_ConvW = 2, /*Kx Ky F C*/
+  T_Vec = 3,
+  T_ConvNA = 4,
+  T_ConvNW = 5,
+} Tensor_Fmt;
+
+/*
+  A tensor SumReduce operation supports the following dimensions:
+  reduce along C,  result HW (C=1),       dim=0
+  reduce along W,  result H (W=1) C,      dim=1
+  reduce along H,  result (H=1) W C,      dim=2
+  reduce along HW, result (H=1) (W=1) C,  dim=4
+*/
+typedef enum Reduce_Dim {
+  Reduce_C = 0,
+  Reduce_W = 1,
+  Reduce_H = 2,
+  Reduce_HW = 4,
+} Reduce_Dim;
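+
+// Reference semantics (illustrative C sketch, not the hardware data path): a
+// W-axis sum reduction (Reduce_W) over an HWC-laid-out fp32 tensor of shape
+// H x W x C produces a tensor of shape H x 1 x C.
+//
+//   for (int h = 0; h < H; h++)
+//     for (int c = 0; c < C; c++) {
+//       float acc = 0.f;
+//       for (int w = 0; w < W; w++)
+//         acc += src[(h * W + w) * C + c];
+//       dst[h * C + c] = acc;
+//     }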
+
+typedef struct NCC_CSR {
+  uint64_t ib_status; // [7:0] IB_COUNTER: number of instructions left in the instruction buffer,
+                      // [8] TASK_DONE, 1: task finished, 0: task still running, [63:9] Reserved
+  uint64_t exception; // [7:0]SCALAR_EXCEPTION, [15:8]CT_EXCEPTION, [23:16]NE_EXCEPTION, [31:24]RDMA_EXCEPTION,
+                      // [39:32]WDMA_EXCEPTION, [47:40]TDMA_EXCEPTION, [63:48]Reserved
+  uint64_t priority;  // [7:0]PRIORITY, priority of the current worker, [63:8]Reserved
+  uint64_t exception_mask; // [47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE, [49]EXCEPTION_CLEAR, [63:49]Reserved
+  uint64_t serial_mode;    // [0]SERIAL_MODE, 1: serial mode, 0: parallel mode, [63:1]Reserved
+} NCC_CSR;
+
+typedef struct EXCEP_SERI {
+  uint64_t exception_mask; // [47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE, [49]EXCEPTION_CLEAR, [63:49]Reserved
+  uint64_t serial_mode;    // [0]SERIAL_MODE, 1: serial mode, 0: parallel mode, [63:1]Reserved
+} EXCEP_SERI;
+#endif
diff --git a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h
new file mode 100644
index 000000000..c1bf4bb2f
--- /dev/null
+++ b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h
@@ -0,0 +1,474 @@
+/*
+ * Copyright (C) 2024 Tsing Micro Intelligent Technology Co.,Ltd. All rights
+ * reserved.
+ *
+ * This file is the property of Tsing Micro Intelligent Technology Co.,Ltd. This
+ * file may only be distributed to: (i) a Tsing Micro party having a legitimate
+ * business need for the information contained herein, or (ii) a non-Tsing Micro
+ * party having a legitimate business need for the information contained herein.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ */
+
+ #ifndef __HOST_RUNTIME_COM_H__
+ #define __HOST_RUNTIME_COM_H__
+
+ #include <pthread.h>
+ #include <stdint.h>
+ #include <stdio.h>
+ #include <stdlib.h>
+ #include <string.h>
+ #include <memory>
+ #include <string>
+ #include <vector>
+
+ #ifndef MAX_SHAPE_DIM
+ #define MAX_SHAPE_DIM 6
+ #endif
+
+ #ifndef MAX_MODEL_NUM
+ #define MAX_MODEL_NUM 32
+ #endif
+
+ typedef uint64_t TsmDevicePtr;
+ typedef uint64_t TsmHostPtr;
+
+ #define CHIP_MAX_NUM 32
+ #define TILE_MAX_NUM 16
+ #define CACHE_ALIGN_4k 4096
+
+ typedef void *(*THREAD_PROC_FUNC)(void *);
+
+ enum TSM_RETCODE {
+   RET_SUCCESS,
+   RET_ERROR,
+   RET_PARAM1_ERROR,
+   RET_PARAM2_ERROR,
+   RET_PARAM3_ERROR,
+   RET_DEVICE_OFFLINE,
+   RET_DEVICE_NOMEM,
+   RET_DEVICE_IN_IDLE,
+   RET_DEVICE_IN_ATTACH,
+   RET_DEVICE_ATTACH_SUCCESS,
+   RET_DEVICE_ATTACH_READY,
+   RET_DEVICE_LOSE_CONNECT,
+   RET_ENV_CLEAN_UP,
+ };
+
+ typedef enum HostLogLevel { LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL, LOG_MAX } HostLogLevel;
+
+ typedef enum TsmModuleType {
+   TSM_RUNTIME,
+   TSM_XLA,     // frontend
+   TSM_TXNN,    // inference engine
+   TSM_ENGTEST, // on-board test suite
+   TSM_HOSTSIM, // simulator test suite
+   TSM_CMODEL,  // simulator API
+   TSM_RT_TEST, // runtime component test suite
+ } TsmModuleType;
+
+ typedef enum TsmProfAction { TSM_PROF_START, TSM_PROF_STOP } TsmProfAction;
+ constexpr uint16_t PROF_TYPE_NCC = 0x1;
+ constexpr uint16_t PROF_TYPE_SPM = 0x2;
+ constexpr uint16_t PROF_TYPE_DTE = 0x4;
+
+ typedef enum DTYPE {
+   FMT_INT8,
+   FMT_INT16,
+   FMT_FP16,
+   FMT_BF16,
+   FMT_INT32,
+   FMT_FP32,
+   FMT_TF32,
+   FMT_BOOL, // 1/8 BYTE
+   FMT_UINT8,
+   FMT_UINT16,
+   FMT_UINT32,
+   FMT_INT64,
+   FMT_UINT64,
+   FMT_UNUSED,
+ } DTYPE;
+
+ uint8_t hrt_get_dtype_size(DTYPE dtype);
+
+ enum DynDataType {
+   PKT_FINAL_TYPE = 0,
+   CFG_PMU_TYPE,
+   KCORE_CFG_TYPE,
+   EXPORT_SPM_TYPE,
+   DISABLE_CALC_TYPE,
+   PROF_CFG_TYPE,
+   DYNLIB_LOAD,
+   DYNLIB_RUN,
+   DYNLIB_UNLOAD,
+   MEMCPY_D2D,
+   P2P_SEND,
+   P2P_RECV,
+   DATA_TYPE_MAX,
+ };
+
+ typedef struct DynTLV_Terminate {
+   uint32_t type; // DynDataType
+   uint32_t len;
+   uint64_t is_final;
+ } DynTLV_Terminate;
+
+ typedef struct DynTLV {
+   uint32_t type; // DynDataType
+   uint32_t len;
+ } DynTLV;
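+
+ // Illustrative sketch (assumed usage, not a documented API): dynamic data is
+ // a sequence of TLV records; a reader walks the type/len pairs until it hits
+ // a PKT_FINAL_TYPE terminator. handle() is a hypothetical consumer, and
+ // whether len counts only the payload is an assumption.
+ //
+ //   const uint8_t *p = buf;
+ //   for (;;) {
+ //     const DynTLV *tlv = (const DynTLV *)p;
+ //     if (tlv->type == PKT_FINAL_TYPE)
+ //       break;
+ //     handle(tlv->type, p + sizeof(DynTLV), tlv->len);
+ //     p += sizeof(DynTLV) + tlv->len;
+ //   }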
+
+ typedef struct Cfg_Pmu_Info {
+   uint32_t tile_bitmap[16];
+   uint32_t mac_use_rate;
+   uint32_t chip_id;
+   uint32_t cycles;
+   uint64_t in_ddr;
+   uint64_t param_ddr;
+   uint64_t out_ddr;
+   uint32_t reserved;
+ } Cfg_Pmu_Info;
+
+ typedef struct DynTLV_Cfgpmu {
+   uint32_t type; // DynDataType
+   uint32_t len;
+   Cfg_Pmu_Info cfg_pmu;
+ } DynTLV_Cfgpmu;
+
+ typedef struct DynTLV_KcoreCfg {
+   uint32_t type;
+   uint32_t len;
+   uint64_t snap_addr[TILE_MAX_NUM];
+   uint64_t console_addr[TILE_MAX_NUM];
+   uint64_t spm_dump_addr[TILE_MAX_NUM];
+   uint64_t spm_dump_size;
+   uint32_t log_level;
+   uint32_t enable_monitor;
+ } DynTLV_KcoreCfg;
+
+ typedef struct DynTLV_KcoreCalc {
+   uint32_t type;
+   uint32_t len;
+   uint32_t disable_kcore_calc;
+ } DynTLV_KcoreCalc;
+
+ typedef struct DynTLV_ProfCfg {
+   uint32_t type;
+   uint32_t len;
+   uint64_t addrs[TILE_MAX_NUM];
+   uint32_t size;
+   uint16_t enable;
+   uint16_t prof_type;
+ } DynTLV_ProfCfg;
+
+ // #define TILE_NUM 16
+ typedef struct DynModule {
+   char module_name[128];
+   char module_symbol[128]; // typedef void (*entry_func_t)(void *);
+   uint32_t module_size[TILE_MAX_NUM];
+   uint64_t module_addr[TILE_MAX_NUM]; // device address
+ } DynModule;
+
+ typedef struct DynMods {
+   uint16_t module_num;
+   struct DynModule modules[0];
+ } DynMods; // structure shared with the host; only its base address is passed over
+
+ typedef struct DynTLV_DynMods {
+   uint32_t type; // DynDataType
+   uint32_t len;
+   uint64_t ext_addr;
+   uint64_t dyn_mods_addr; // points to a DynMods
+ } DynTLV_DynMods;
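+
+ // Illustrative sketch (assumed usage): DynMods ends in a flexible array
+ // member, so the host sizes one allocation for module_num entries. The
+ // module count here is hypothetical.
+ //
+ //   uint16_t n = 4;
+ //   DynMods *mods = (DynMods *)malloc(sizeof(DynMods) + n * sizeof(DynModule));
+ //   mods->module_num = n;
+ //   /* fill mods->modules[0..n-1], then pass the base address to the device */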
+
+ typedef struct TileDteCfg {
+   uint16_t status;             // whether this tile takes part in the transfer
+   uint16_t remote_tile_id;     // tile_id of the peer tile
+   uint32_t element_count;      // cache-line size moved per transfer, 4k by default
+   uint32_t stride;             // stride
+   uint32_t left_element_count; // length left to move after the whole cache lines are done
+   uint64_t iteration;          // number of cache-line transfers
+   uint64_t src_addr;           // source address of the cache-line transfer - physical
+   uint64_t dst_addr;           // destination address of the cache-line transfer - physical
+   uint64_t left_src_addr;      // source address of the remainder transfer - physical
+   uint64_t left_dst_addr;      // destination address of the remainder transfer - physical
+ } TileDteCfg;
+ typedef struct DynTLV_DteCfg {
+   uint32_t type;
+   uint32_t len;
+   TileDteCfg tile_dte_cfg[TILE_MAX_NUM];
+   uint64_t barrier_addr;
+   uint32_t row_card_num;
+   uint32_t reserved;
+ } DynTLV_DteCfg;
+
+ enum Tensor_Type {
+   INPUT_DATA,
+   OUTPUT_DATA,
+   PARAM_DATA,
+   CHACHE_DATA,
+   DEV_DDR_DATA,
+ };
+
+ typedef struct tensor_info {
+   int32_t inplace;
+   uint32_t dim;
+   uint32_t dtype;
+   uint32_t layout;
+   uint32_t shape[MAX_SHAPE_DIM];
+ } tensor_info_t;
+
+ typedef struct Json_common_info_t {
+   uint32_t input_num;
+   uint32_t output_num;
+   uint32_t param_num;
+   uint32_t tile_num;
+
+   std::string case_name;
+   std::string card_name;
+
+   std::vector<std::shared_ptr<tensor_info_t>> input;
+   std::vector<std::shared_ptr<tensor_info_t>> output;
+
+   std::vector<std::string> input_file;
+   std::vector<std::string> output_file;
+   std::vector<std::string> param_file;
+
+   std::vector<uint64_t> input_size;
+   std::vector<uint64_t> output_size;
+   std::vector<uint64_t> param_size;
+   uint64_t imm_size;
+
+ } Json_common_info_t;
+
+ typedef struct chip_common_info {
+   uint32_t input_num;
+   uint32_t output_num;
+   uint32_t param_num;
+   uint32_t tile_num;
+   uint32_t tile_x;
+   uint32_t tile_y;
+   std::vector<std::shared_ptr<tensor_info_t>> input;
+   std::vector<std::shared_ptr<tensor_info_t>> output;
+
+   // char card_name[100];
+   std::string card_name;
+   std::vector<std::string> input_file;
+   std::vector<std::string> output_file;
+   std::vector<std::string> output_ref_file;
+   std::vector<std::string> param_file;
+
+   std::vector<uint64_t> input_size;
+   std::vector<uint64_t> output_size;
+   std::vector<uint64_t> param_size;
+
+   std::vector<uint64_t> input_host_addr;
+   std::vector<uint64_t> input_dev_addr;
+   std::vector<uint64_t> output_host_addr;
+   std::vector<uint64_t> output_dev_addr;
+   std::vector<uint64_t> param_host_addr;
+   std::vector<uint64_t> param_dev_addr;
+
+   uint64_t imm_size;
+ } chip_common_info_t;
+
+ typedef struct json_common_info_multi_card {
+   uint32_t chip_num;
+   uint32_t chip_x;
+   uint32_t chip_y;
+   std::string case_name;
+   uint32_t loop_num;
+   std::vector<std::shared_ptr<chip_common_info_t>> chip_infos;
+ } json_common_info_multi_card_t;
+
+ typedef struct CompileOption {
+   bool comp_enable = false;
+   std::string rtt_tool_path;
+   std::string compile_path;
+   bool check_enable = false;
+   uint32_t chip_x;
+   uint32_t chip_y;
+   bool enable_kcore_bin;
+   bool enable_kcore_so;
+ } CompileOption;
+
+ // Boot Param Table
+ typedef struct BootParamHead {
+   uint32_t MaxLen; // BootParamHead + n * BootParamDyninfo, n = inputnum + outputnum + paramnum
+   uint32_t LdmemLen;
+   uint32_t InputNum;
+   uint32_t OutputNum;
+   uint32_t ParamNum;
+   uint32_t reserved;
+   uint64_t CacheMemLen;
+   uint64_t CacheMemAddr; // device
+   uint32_t Datalen;
+   uint32_t reserved1;
+   uint64_t DataAddr; // device
+ } BootParamHead;
+
+ typedef struct BootParamDyninfo {
+   uint64_t addr; // device
+   uint64_t size;
+   uint32_t dtype;
+   uint32_t dim;
+   uint32_t shape[6]; // #define MAX_SHAPE_DIM 6 // n, h, w, c, x, x
+ } BootParamDyninfo;
+
+ class HrtBootParam {
+  public:
+   HrtBootParam(uint32_t i_num, uint32_t o_num, uint32_t p_num) : i_num(i_num), o_num(o_num), p_num(p_num) {
+     uint32_t bufsize = (sizeof(BootParamHead) + (i_num + o_num + 1) * sizeof(BootParamDyninfo));
+     buffer = (void *)malloc(bufsize);
+     memset(buffer, 0, bufsize);
+     BootParamHead *head = static_cast<BootParamHead *>(buffer);
+     head->MaxLen = bufsize;
+     head->LdmemLen = 0x200000;
+     head->InputNum = i_num;
+     head->OutputNum = o_num;
+     head->ParamNum = p_num;
+   }
+   ~HrtBootParam() {
+     if (buffer != nullptr) {
+       free(buffer);
+     }
+   }
+   std::vector<BootParamDyninfo> dyninfo;
+   uint32_t get_maxlen();
+   void *get_bootpmbuffer();
+   BootParamHead *get_headptr();
+   BootParamDyninfo *get_inputptr(uint32_t index);
+   BootParamDyninfo *get_outputptr(uint32_t index);
+   BootParamDyninfo *get_paramptr(uint32_t index);
+   void set_dev_cache(uint64_t dev_addr, uint64_t size);
+   void set_dev_cache_mem_addr(uint64_t dev_addr, uint64_t size);
+   void set_dev_dyndata(uint64_t dev_addr, uint32_t size);
+   void set_dev_dyndata_mem_addr(uint64_t dev_addr, uint32_t size);
+   void set_dev_input(uint32_t idx, uint64_t dev_addr, uint64_t size);
+   void set_dev_input_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
+   void set_dev_input_tensor(uint32_t idx, std::shared_ptr<TsmTensorData> tensor);
+   void set_dev_output(uint32_t idx, uint64_t dev_addr, uint64_t size);
+   void set_dev_output_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
+   void set_dev_param(uint32_t idx, uint64_t dev_addr, uint64_t size);
+   void set_dev_param_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
+   std::shared_ptr<TsmTensorData> get_dev_output_tensor_after_run(uint32_t idx);
+
+  private:
+   uint32_t i_num;
+   uint32_t o_num;
+   uint32_t p_num;
+   void *buffer;
+ };
+ /* boot parameters end */
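+
+ // Illustrative sketch (assumed flow): building a boot-parameter table for a
+ // kernel with one input and one output. dev_in/dev_out/bootpm_dev are
+ // hypothetical device allocations from TsmDeviceMalloc(); TsmMemcpyH2D() and
+ // TsmRun() are declared in hrt_interface.h.
+ //
+ //   HrtBootParam bp(/*i_num=*/1, /*o_num=*/1, /*p_num=*/0);
+ //   bp.set_dev_input(0, dev_in, in_size);
+ //   bp.set_dev_output(0, dev_out, out_size);
+ //   TsmMemcpyH2D(bootpm_dev, bp.get_bootpmbuffer(), bp.get_maxlen());
+ //   TsmRun(dev, bootpm_dev);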
+
+ /* Object produced by the compiler that stores the ELF and param addresses */
+ class HostParamElem {
+  public:
+   HostParamElem() : dataPtr(nullptr), size(0) {}
+   ~HostParamElem();
+   // Simulator: load a binary from a file
+   HostParamElem(const std::string &filepath);
+
+   uint8_t *loadBinaryFile(const std::string filepath, uint64_t &fsize);
+   uint8_t *dataPtr; // host
+   uint64_t size;    // bytes
+ };
+
+ class ChipModelInfo {
+  public:
+   ChipModelInfo();
+   ChipModelInfo(uint32_t id);
+   ~ChipModelInfo();
+
+   uint32_t getChipId() { return chip_id; }
+   // support multi chip
+   std::vector<std::shared_ptr<HostParamElem>> elfs; // compiled ELF files
+   std::vector<std::shared_ptr<HostParamElem>> bins; // compiled bin files
+   std::vector<std::shared_ptr<HostParamElem>> params;
+
+  private:
+   uint32_t chip_id;
+ };
+
+ /*
+  * Model object produced by the compiler; at launch time the elf/bin pointers
+  * are passed to the SoC interface. (A PCIe transfer over non-contiguous
+  * memory is split into several transfers, so the SoC is left to assemble a
+  * contiguous region.)
+  */
+ class TsmModel {
+  public:
+   TsmModel(); // org_model
+   ~TsmModel();
+   TsmModel(const std::string &filepath);
+
+   std::vector<std::shared_ptr<ChipModelInfo>> chip_infos;
+   THREAD_PROC_FUNC proc_func;
+   std::string case_name;
+   std::string case_dir;
+   std::shared_ptr<HostParamElem> so_list[MAX_MODEL_NUM][TILE_MAX_NUM]; // compiled .so files
+   std::string module_name;
+   struct txmodel *model[MAX_MODEL_NUM];
+ };
+
+ typedef struct TsmDevice {
+   char res_path[128];
+   uint32_t chip_id;
+   uint32_t tile_num = 16;
+   void *soc_device;
+ } TsmDevice_t;
+
+ class TsmTensorData {
+  public:
+   TsmTensorData() : host_addr(0), device_addr(0), length(0) {}
+   ~TsmTensorData(){};
+
+   TsmHostPtr host_addr;
+   TsmDevicePtr device_addr;
+   uint64_t length;
+   uint32_t data_type;
+   Tensor_Type tensor_type;
+ };
+
+ typedef void *tsmStream_t;
+ typedef void *tsmEvent_t;
+ typedef struct txcclComm *txcclComm_t;
+ typedef enum {
+   txcclDataDefault = 0
+ } txcclDataType_t; // reserved, to be discussed
+
+ enum device_status {
+   FULLGOOD = 0,
+   PARTIALGOOD = 1,
+ };
+
+ constexpr uint32_t PARTIALGOOD_NUM = 8;
+ constexpr uint32_t FULLGOOD_NUM = 16;
+
+ struct CardComputeInfo {
+   uint32_t card_id;
+   enum device_status device_status;
+   uint32_t all_tile_num;
+   double all_tile_compute;
+   uint32_t left_tile_num;
+   double left_tile_compute;
+ };
+
+ struct TsmDeviceInfo {
+   uint32_t card_num;
+   uint32_t card_x;
+   uint32_t card_y;
+   CardComputeInfo card_compute_info[CHIP_MAX_NUM];
+ };
+
+ int32_t readDataFromFile(uint8_t *buffer, std::string file, uint32_t size);
+ uint8_t *read_file_data(std::string file, uint64_t &size);
+
+ std::shared_ptr<json_common_info_multi_card_t> get_multi_card_common_info_from_file(std::string file);
+ std::string get_docker_verison();
+ TSM_RETCODE set_multi_graph(TsmModel *&kmodel, std::shared_ptr<HrtBootParam> &hostboot,
+                             const TsmDevicePtr &dev_dyn_mods_ptr, const TsmDevicePtr &dev_tlv_ptr, TsmDevicePtr ext_ptr);
+ #endif /* __HOST_RUNTIME_COM_H__ */
diff --git a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h
new file mode 100644
index 000000000..6c0fdd07d
--- /dev/null
+++ b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2024 Tsing Micro Intelligent Technology Co.,Ltd. All rights
+ * reserved.
+ *
+ * This file is the property of Tsing Micro Intelligent Technology Co.,Ltd. This
+ * file may only be distributed to: (i) a Tsing Micro party having a legitimate
+ * business need for the information contained herein, or (ii) a non-Tsing Micro
+ * party having a legitimate business need for the information contained herein.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ */
+
+#ifndef __HOST_RUNTIME_INTERFACE_H__
+#define __HOST_RUNTIME_INTERFACE_H__
+
+#include <string>
+#include <vector>
+
+#include "hrt_common.h"
+
+/*
+ * The interfaces below depend on the Runtime instance lifecycle: they may only
+ * be called after TsmInitRuntime and before TsmDeInitRuntime.
+ */
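+
+// Illustrative call sequence (assumed flow; error handling omitted; host_src
+// is a hypothetical caller-owned buffer):
+//
+//   TsmInitRuntime();
+//   std::vector<TsmDevice *> devs;
+//   TsmSetDevice(/*first_phy_id=*/0, /*card_x=*/1, /*card_y=*/1, devs);
+//   TsmDevicePtr buf = 0;
+//   TsmDeviceMalloc(devs[0], buf, 4096);
+//   TsmMemcpyH2D(buf, host_src, 4096);
+//   ...
+//   TsmDeviceFree(buf);
+//   TsmDeInitRuntime();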
+TSM_RETCODE TsmInitRuntime(void);
+TSM_RETCODE TsmDeInitRuntime(void);
+TSM_RETCODE TsmDeInitRuntimeLegacy(void);
+TSM_RETCODE TsmSetDevice(uint32_t first_phy_id, uint32_t card_x, uint32_t card_y, std::vector<TsmDevice *> &devs);
+TSM_RETCODE TsmSetDeviceOld(uint32_t chip_id, TsmDevice *dev); /* transitional interface provided for MLIR; other components must not call it */
+TSM_RETCODE TsmDeviceMalloc(TsmDevice *dev, TsmDevicePtr &ptr, uint64_t size);
+TSM_RETCODE TsmDeviceMemset(TsmDevicePtr &ptr, uint32_t ch, uint64_t size);
+TSM_RETCODE TsmDeviceFree(TsmDevicePtr ptr);
+TSM_RETCODE TsmDeviceSynchronize(TsmDevice *dev);
+TSM_RETCODE TsmInitDevice(TsmDevice *dev);
+TSM_RETCODE TsmCompile(std::vector<TsmDevice *> devs, TsmModel &kmodel, std::string option, CompileOption compl_op);
+TSM_RETCODE TsmCompileMultiGraph(std::vector<TsmDevice *> devs, TsmModel &kmodel, std::string option,
+                                 CompileOption compl_op);
+TSM_RETCODE TsmLaunch(TsmDevice *dev, TsmModel &kmodel);
+TSM_RETCODE TsmLoadKernel(TsmDevice *dev, std::vector<TsmModel *> &kmodel_vec, char *module_symbol);
+TSM_RETCODE TsmUnloadKernel(TsmDevice *dev, std::vector<TsmModel *> &kmodel_vec);
+TSM_RETCODE TsmRun(TsmDevice *dev, TsmDevicePtr bootpm_dev);
+TSM_RETCODE TsmAsyncRun(tsmStream_t stream, TsmDevice *dev, TsmDevicePtr bootpm_dev);
+TSM_RETCODE TsmSetTerminate(TsmDevice *dev, tsmStream_t stream = nullptr);
+TSM_RETCODE TsmGetDeviceInfo(TsmDeviceInfo *info);
+TSM_RETCODE TsmTerminate(TsmDevice *dev, TsmDevicePtr bootpm_dev);
+TSM_RETCODE TsmMemcpyH2D(TsmDevicePtr dst, const void *src, uint64_t byte_count);
+TSM_RETCODE TsmMemcpyD2H(const void *dst, TsmDevicePtr src, uint64_t byte_count);
+TSM_RETCODE TsmMemcpyOffsetH2D(TsmDevicePtr dst, const void *src, uint64_t offset, uint64_t byte_count);
+TSM_RETCODE TsmMemcpyOffsetD2H(const void *dst, TsmDevicePtr src, uint64_t offset, uint64_t byte_count);
+TSM_RETCODE TsmMemcpyD2D(const void *dst, TsmDevice *dst_dev, const void *src, TsmDevice *src_dev, uint64_t byte_count);
+TSM_RETCODE TsmSend(const void *sendbuff, size_t count, txcclDataType_t datatype, TsmDevice *dev, int peer, txcclComm_t comm, tsmStream_t stream);
+TSM_RETCODE TsmRecv(void *recvbuff, size_t count, txcclDataType_t datatype, TsmDevice *dev, int peer, txcclComm_t comm, tsmStream_t stream);
+TSM_RETCODE TsmResetDevice(TsmDevice *dev);
+TSM_RETCODE TsmReleaseDevice(TsmDevice *dev);
+TSM_RETCODE TsmMemGetInfo(TsmDevicePtr ptr, uint32_t &card_id, uint64_t &addr, uint64_t &size);
+TSM_RETCODE TsmEventCreate(tsmEvent_t *pEvent);
+TSM_RETCODE TsmEventDestroy(tsmEvent_t event);
+TSM_RETCODE TsmEventRecord(tsmEvent_t event, tsmStream_t stream);
+TSM_RETCODE TsmEventWait(tsmEvent_t event, tsmStream_t stream);
+TSM_RETCODE TsmStreamCreate(tsmStream_t *pStream, TsmDevice *dev);
+TSM_RETCODE TsmStreamSynchronize(tsmStream_t stream);
+TSM_RETCODE TsmStreamDestroy(tsmStream_t stream);
+TSM_RETCODE TsmDeviceSerialize(const TsmDevice *const &dev, void *&buffer, size_t &size);
+TSM_RETCODE TsmDeviceDeSerialize(TsmDevice *&dev, const void *const &buffer);
+TSM_RETCODE TsmSetMonitorInfo(TsmDevice *dev);
+TSM_RETCODE TsmProcessProfData(TsmDevice *dev, TsmProfAction prof_action, uint16_t prof_type);
+TSM_RETCODE TsmHostH2D(TsmDevice *dev, uint64_t input_host_addr, uint64_t input_size, int32_t index);
+TSM_RETCODE TsmHostFlush(TsmDevice *dev, uint64_t boot_param_ptr, uint8_t *host_buffer, size_t size);
+TSM_RETCODE TsmSetRankSize(uint32_t x_size, uint32_t y_size);
+TSM_RETCODE TsmSetRankId(uint32_t x, uint32_t y);
+TSM_RETCODE TsmGetPhyRankId(uint32_t *x, uint32_t *y);
+
+/*
+ * The interfaces below are stateless; they do not depend on a Runtime instance
+ * and can be used on their own.
+ */
+TSM_RETCODE TsmGetDeviceNum(uint32_t &dev_num);
+
+/*
+ * To keep the host log format consistent, the Runtime provides a unified log
+ * interface, which each component uses as follows:
+ * #define rt_log(level, format, ...) tsm_log(__FILE__, __func__, __LINE__, TSM_RUNTIME, level, format, ##__VA_ARGS__)
+ *
+ * void func() {
+ *   rt_log(LOG_DEBUG, "....\n");
+ *   rt_log(LOG_INFO, "....\n");
+ *   rt_log(LOG_WARNING, "....\n");
+ *   rt_log(LOG_ERROR, "....\n");
+ * }
+ * The default log level is INFO; set HOST_LOG_LEVEL to change it (normally
+ * only INFO and DEBUG are used).
+ * Note: rt_log is a name each component customizes - do not reuse one.
+ * TSM_RUNTIME is the module ID; each module finds its own macro in
+ * hrt_common.h, and a module without one can ask the runtime team to add it.
+ */
+void tsm_log(const char *file_name, const char *func_name, uint32_t line_number, TsmModuleType module_type,
+             HostLogLevel level, const char *format, ...);
+#endif
\ No newline at end of file
diff --git a/third_party/tsingmicro/crt/include/Tx81/tx81.h b/third_party/tsingmicro/crt/include/Tx81/tx81.h
new file mode 100644
index 000000000..6176349d2
--- /dev/null
+++ b/third_party/tsingmicro/crt/include/Tx81/tx81.h
@@ -0,0 +1,22 @@
+//===----------------------- tx81.h ---------------------------*- C -*-----===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+#ifndef CRT_TARGET_TX81_H
+#define CRT_TARGET_TX81_H
+
+#include "instr_adapter.h"
+#include "instr_def.h"
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+enum MemorySpace : int32_t {
+  UNKNOWN = 0,
+  SPM = 1,
+  DDR = 2,
+};
+
+#endif // CRT_TARGET_TX81_H
diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmax.c b/third_party/tsingmicro/crt/lib/Tx81/argmax.c
new file mode 100644
index 000000000..412e9fa24
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/argmax.c
@@ -0,0 +1,26 @@
+//===------------------------ argmax.c ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::ArgMax see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __ArgMax(uint64_t *src, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->ArgMax(&inst, (uint64_t)src, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmin.c b/third_party/tsingmicro/crt/lib/Tx81/argmin.c
new file mode 100644
index 000000000..e79223b5a
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/argmin.c
@@ -0,0 +1,26 @@
+//===------------------------ argmin.c ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::ArgMin see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __ArgMin(uint64_t *src, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->ArgMin(&inst, (uint64_t)src, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/arith.c b/third_party/tsingmicro/crt/lib/Tx81/arith.c
new file mode 100644
index 000000000..4fc7120ee
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/arith.c
@@ -0,0 +1,165 @@
+//===------------------------ arith.c ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::ArithOp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __AddVV(uint64_t *src0, uint64_t *src1, uint64_t *dst,
+             uint32_t elem_count, RND_MODE round, uint16_t fmt) {
+  // Create command buffer.
+  TsmArith *cmd = TsmNewArith();
+  TsmArithInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->AddVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count,
+             round, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteArith(cmd);
+}
+
+void __SubVV(uint64_t *src0, uint64_t *src1, uint64_t *dst,
+             uint32_t elem_count, RND_MODE round, uint16_t fmt) {
+  // Create command buffer.
+  TsmArith *cmd = TsmNewArith();
+  TsmArithInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->SubVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count,
+             round, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteArith(cmd);
+}
+
+void __MulVV(uint64_t *src0, uint64_t *src1, uint64_t *dst,
+             uint32_t elem_count, RND_MODE round, uint16_t fmt) {
+  // Create command buffer.
+  TsmArith *cmd = TsmNewArith();
+  TsmArithInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->MulVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count,
+             round, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteArith(cmd);
+}
+
+void __DivVV(uint64_t *src0, uint64_t *src1, uint64_t *dst,
+             uint32_t elem_count, RND_MODE round, uint16_t fmt) {
+  // Create command buffer.
+  TsmArith *cmd = TsmNewArith();
+  TsmArithInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->DivVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count,
+             round, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteArith(cmd);
+}
+
+void __AddVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count,
+             RND_MODE round, uint16_t fmt) {
+  // Create command buffer.
+  TsmArith *cmd = TsmNewArith();
+  TsmArithInstr inst = {I_CGRA,
+                        {
+                            0,
+                        },
+                        {
+                            0,
+                        }};
+
+  cmd->AddVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round,
+             (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+ TsmDeleteArith(cmd); +} + +void __SubVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->SubVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __MulVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->MulVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __DivVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->DivVS(&inst, (uint64_t)src0, src1, (uint64_t)dst, elem_count, round, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c new file mode 100644 index 000000000..68df7b31e --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c @@ -0,0 +1,25 @@ +//===------------------------ bf16_fp16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_FP16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->BF16_FP16(&inst, (uint64_t) src, (uint64_t) dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c new file mode 100644 index 000000000..c3063cdcf --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c @@ -0,0 +1,25 @@ +//===------------------------ bf16_fp32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_FP32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->BF16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c new file mode 100644 index 000000000..a42e3614d --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c @@ -0,0 +1,26 @@ +//===------------------------ bf16_int16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_INT16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->BF16_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c new file mode 100644 index 000000000..4af892a73 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c @@ -0,0 +1,26 @@ +//===------------------------ bf16_int32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_INT32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->BF16_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c new file mode 100644 index 000000000..4286f22fa --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c @@ -0,0 +1,25 @@ +//===------------------------ bf16_int8.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::BF16_INT8 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __BF16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->BF16_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c
new file mode 100644
index 000000000..9220132cc
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c
@@ -0,0 +1,25 @@
+//===------------------------ bf16_tf32.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::BF16_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __BF16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->BF16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/bilinear.c b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c
new file mode 100644
index 000000000..23ff84532
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c
@@ -0,0 +1,32 @@
+//===------------------------ bilinear.c ----------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Bilinear see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Bilinear(uint64_t *src, uint64_t *dst, uint16_t src_n, uint16_t src_h,
+                uint16_t src_w, uint16_t src_c, uint16_t dst_n, uint16_t dst_h,
+                uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c};
+  cmd->Bilinear(&inst, (uint64_t)src, (uint64_t)dst, shape1, shape2,
+                (src_w - 1) / (dst_w - 1), (src_h - 1) / (dst_h - 1),
+                (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c
new file mode 100644
index 000000000..7b818b1df
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c
@@ -0,0 +1,26 @@
+//===------------------------ bit2fp.c ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Bit2Fp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Bit2Fp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Bit2Fp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/common.c b/third_party/tsingmicro/crt/lib/Tx81/common.c
new file mode 100644
index 000000000..680e5a367
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/common.c
@@ -0,0 +1,25 @@
+//===----------------------- common.c -------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement common helper functions in this file.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+// WORKAROUND for undefined symbols in libkcorert.a
+int main(int argc, char **argv) {
+  return 0;
+}
+
+int get_app_version() {
+  return 1;
+}
+
+int nvram_get_val() {
+  return 1;
+}
\ No newline at end of file
diff --git a/third_party/tsingmicro/crt/lib/Tx81/concat.c b/third_party/tsingmicro/crt/lib/Tx81/concat.c
new file mode 100644
index 000000000..a7cc0da55
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/concat.c
@@ -0,0 +1,34 @@
+//===------------------------ concat.c ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Concat see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Concat(uint64_t *src1, uint16_t src1_n, uint16_t src1_h, uint16_t src1_w,
+              uint16_t src1_c, uint64_t *src2, uint16_t src2_n, uint16_t src2_h,
+              uint16_t src2_w, uint16_t src2_c, uint64_t *dst, uint16_t dst_n,
+              uint16_t dst_h, uint16_t dst_w, uint16_t dst_c, uint32_t dim,
+              uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = {src1_n, src1_h, src1_w, src1_c};
+  Data_Shape shape2 = {src2_n, src2_h, src2_w, src2_c};
+  Data_Shape shape3 = {dst_n, dst_h, dst_w, dst_c};
+  cmd->Concat(&inst, (uint64_t)src1, shape1, (uint64_t)src2, shape2,
+              (uint64_t)dst, shape3, dim, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/conv.c b/third_party/tsingmicro/crt/lib/Tx81/conv.c
new file mode 100644
index 000000000..3f05b2ee8
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/conv.c
@@ -0,0 +1,61 @@
+//===------------------------ conv.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::TsmConv, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+// The arguments list is aligned with TsmConv in Tx81Ops.td
+void __Conv(int64_t opType, int64_t* srcAct, int64_t* srcActDims, int64_t* weight,
+            int64_t* weightDims, bool enBias, int64_t* bias, bool enNegScale,
+            int64_t* negScale, bool enPosScale, int64_t* posScale, bool enSparse,
+            int64_t* sparse, bool enPsum, int64_t* psum, int64_t* pads,
+            int64_t* unpads, int64_t* strides, int64_t* dilations,
+            bool enLeakyRelu, int64_t srcActFmt, int64_t weightFmt, int64_t dstFmt,
+            int64_t* dst, int64_t* dstDims)
+{
+  // Create convolution command buffer.
+  TsmConv *conv = TsmNewConv();
+  TsmNeInstr inst = {I_NEUR, {0,}, {0,}};
+
+  // Convert to nhwc format
+  Data_Shape shape = {(uint16_t)srcActDims[0], (uint16_t)srcActDims[1],
+                      (uint16_t)srcActDims[2], (uint16_t)srcActDims[3]};
+
+  Data_Shape wshape = {(uint16_t)weightDims[0], (uint16_t)weightDims[1],
+                       (uint16_t)weightDims[2], (uint16_t)weightDims[3]};
+
+  Data_Shape dstShape = {(uint16_t)dstDims[0], (uint16_t)dstDims[1],
+                         (uint16_t)dstDims[2], (uint16_t)dstDims[3]};
+
+  conv->AddInput(&inst, (int64_t) srcAct, shape, (Data_Format)srcActFmt);
+  conv->AddWeight(&inst, (uint64_t) weight, wshape, (Data_Format)weightFmt);
+  conv->AddBias(&inst, enBias, (uint64_t) bias);
+  conv->AddOutput(&inst, (uint64_t) dst, dstShape, (Data_Format)dstFmt);
+  conv->SetOpType(&inst, opType);
+  conv->SetNegativeAxisScale(&inst, enNegScale, (uint64_t) negScale);
+  conv->SetPositiveAxisScale(&inst, enPosScale, (uint64_t) posScale);
+  conv->SetSparse(&inst, enSparse, (uint64_t) sparse);
+  // FIXME: Should we have psum format instead?
+  conv->SetPsum(&inst, enPsum, (uint64_t) psum, (Data_Format) dstFmt);
+  conv->SetPads(&inst, pads[0], pads[1], pads[2], pads[3]);
+  conv->SetUnPads(&inst, unpads[0], unpads[1], unpads[2], unpads[3]);
+  conv->SetKernelStrides(&inst, strides[0], strides[1], strides[2], strides[3]);
+  conv->SetDilations(&inst, dilations[0], dilations[1]);
+  if (enLeakyRelu)
+    conv->EnableLeakyRelu(&inst);
+  else
+    conv->EnableRelu(&inst); // NOTE: plain ReLU is enabled whenever leaky ReLU is off
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConv(conv);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/cos.c b/third_party/tsingmicro/crt/lib/Tx81/cos.c
new file mode 100644
index 000000000..23c965a2e
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/cos.c
@@ -0,0 +1,26 @@
+//===------------------------ cos.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Cos see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Cos(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmTranscendental *cmd = TsmNewTranscendental();
+  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Cos(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format) fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteTranscendental(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/count.c b/third_party/tsingmicro/crt/lib/Tx81/count.c
new file mode 100644
index 000000000..9582070f7
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/count.c
@@ -0,0 +1,28 @@
+//===------------------------ count.c -------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Count see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Count(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt,
+             uint64_t *p_wb_data0, uint64_t *p_wb_data1) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Count(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt, p_wb_data0,
+             p_wb_data1);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/exp.c b/third_party/tsingmicro/crt/lib/Tx81/exp.c
new file mode 100644
index 000000000..472df41b0
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/exp.c
@@ -0,0 +1,26 @@
+//===------------------------ exp.c ---------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Exp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Exp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmTranscendental *cmd = TsmNewTranscendental();
+  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Exp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteTranscendental(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/explp.c b/third_party/tsingmicro/crt/lib/Tx81/explp.c
new file mode 100644
index 000000000..c88d2ae68
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/explp.c
@@ -0,0 +1,26 @@
+//===------------------------ explp.c -------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Explp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Explp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmTranscendental *cmd = TsmNewTranscendental();
+  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Explp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteTranscendental(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c
new file mode 100644
index 000000000..a5ba62869
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c
@@ -0,0 +1,26 @@
+//===------------------------ fp16_bf16.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_BF16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP16_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c
new file mode 100644
index 000000000..2c6732f0f
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c
@@ -0,0 +1,25 @@
+//===------------------------ fp16_fp32.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_FP32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c
new file mode 100644
index 000000000..5a161b4e9
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c
@@ -0,0 +1,26 @@
+//===------------------------ fp16_int16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_INT16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP16_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c
new file mode 100644
index 000000000..cd64ac37a
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c
@@ -0,0 +1,26 @@
+//===------------------------ fp16_int32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_INT32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP16_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c
new file mode 100644
index 000000000..4d1356aa2
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c
@@ -0,0 +1,26 @@
+//===------------------------ fp16_int8.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_INT8 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP16_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c
new file mode 100644
index 000000000..d2d3163ad
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c
@@ -0,0 +1,25 @@
+//===------------------------ fp16_tf32.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP16_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c
new file mode 100644
index 000000000..0cdf0c995
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c
@@ -0,0 +1,26 @@
+//===------------------------ fp32_bf16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP32_BF16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c
new file mode 100644
index 000000000..ffd647e35
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c
@@ -0,0 +1,26 @@
+//===------------------------ fp32_fp16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP32_FP16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c
new file mode 100644
index 000000000..95201fa00
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c
@@ -0,0 +1,26 @@
+//===------------------------ fp32_int16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP32_INT16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP32_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c
new file mode 100644
index 000000000..bdaa4a9ad
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c
@@ -0,0 +1,26 @@
+//===------------------------ fp32_int32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP32_INT32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP32_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP32_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c
new file mode 100644
index 000000000..b82017ae0
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c
@@ -0,0 +1,26 @@
+//===------------------------ fp32_int8.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP32_INT8 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP32_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c
new file mode 100644
index 000000000..8eec01b48
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c
@@ -0,0 +1,26 @@
+//===------------------------ fp32_tf32.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::FP32_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __FP32_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->FP32_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c
new file mode 100644
index 000000000..ef4e72c95
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c
@@ -0,0 +1,33 @@
+//===------------------------ gatherscatter.c -----------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::GatherScatter see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t size, uint32_t src_s0,
+                     uint32_t src_i0, uint32_t src_s1, uint32_t src_i1,
+                     uint32_t src_s2, uint32_t src_i2, uint32_t dst_s0,
+                     uint32_t dst_i0, uint32_t dst_s1, uint32_t dst_i1,
+                     uint32_t dst_s2, uint32_t dst_i2) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  St_StrideIteration src_si = {src_s0, src_i0, src_s1, src_i1, src_s2, src_i2};
+  St_StrideIteration dst_si = {dst_s0, dst_i0, dst_s1, dst_i1, dst_s2, dst_i2};
+
+  cmd->GatherScatter(&inst, (uint64_t)src, (uint64_t)dst, size, &src_si, &dst_si);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/gemm.c b/third_party/tsingmicro/crt/lib/Tx81/gemm.c
new file mode 100644
index 000000000..09b8c18ff
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/gemm.c
@@ -0,0 +1,50 @@
+//===------------------------ gemm.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::TsmGemm, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+// The arguments list is aligned with TsmGemm in Tx81Ops.td
+void __Gemm(int64_t* srcA, int64_t *srcB, int64_t * srcBias, int64_t *zeros,
+            int64_t *dims, bool enPsum, int64_t *psum, bool enTransA, bool enTransB,
+            int64_t batchSizeA, int64_t batchSizeB, bool enLeakyRelu, bool enBias,
+            bool enNegScale, int64_t *negScale, bool enPosScale, int64_t *posScale,
+            int64_t srcFmt, int64_t dstFmt, int64_t* dst)
+{
+  // Create gemm command buffer.
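+  // Call-sequence sketch (assumed from the builder API below): inputs, MKN
+  // dims and the output are required; psum, transpose flags, batching, bias
+  // and per-axis scales are optional knobs.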
+  TsmGemm *gemm = TsmNewGemm();
+  TsmNeInstr inst = {I_NEUR, {0,}, {0,}};
+
+  gemm->AddInput(&inst, (uint64_t) srcA, (uint64_t) srcB, (Data_Format) srcFmt);
+  gemm->ConfigMKN(&inst, (uint32_t) dims[0], (uint32_t) dims[1],
+                  (uint32_t) dims[2]);
+  gemm->AddOutput(&inst, (uint64_t) dst, (Data_Format) dstFmt);
+  gemm->SetPsum(&inst, enPsum, (uint64_t) psum, (Data_Format) dstFmt);
+  gemm->SetTransflag(&inst, (uint8_t) enTransA, (uint8_t) enTransB);
+  // TODO:
+  // gemm->SetQuant();
+  gemm->ConfigBatch(&inst, (uint32_t) batchSizeA, (uint32_t) batchSizeB);
+  gemm->AddBias(&inst, enBias, (uint64_t) srcBias);
+  gemm->SetNegativeAxisScale(&inst, enNegScale, (uint64_t) negScale);
+  gemm->SetPositiveAxisScale(&inst, enPosScale, (uint64_t) posScale);
+  if (enLeakyRelu)
+    gemm->EnableLeakyRelu(&inst);
+  else
+    gemm->EnableRelu(&inst); // NOTE: plain ReLU is enabled whenever leaky ReLU is off
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteGemm(gemm);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/img2col.c b/third_party/tsingmicro/crt/lib/Tx81/img2col.c
new file mode 100644
index 000000000..3d2bba633
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/img2col.c
@@ -0,0 +1,36 @@
+//===------------------------ img2col.c -----------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Img2col see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Img2col(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
+               uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+               uint16_t dst_w, uint16_t dst_c, uint64_t src_elem_num,
+               uint64_t dst_elem_num, uint16_t swr_n, uint16_t swr_h,
+               uint16_t swr_w, uint16_t swr_c, uint16_t pdr_n, uint16_t pdr_h,
+               uint16_t pdr_w, uint16_t pdr_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c};
+  Data_Shape shape3 = {swr_n, swr_h, swr_w, swr_c};
+  Data_Shape shape4 = {pdr_n, pdr_h, pdr_w, pdr_c};
+  cmd->Img2col(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
+               src_elem_num, dst_elem_num, shape3, shape4, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c
new file mode 100644
index 000000000..f06d8bec9
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c
@@ -0,0 +1,26 @@
+//===------------------------ int16_bf16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT16_BF16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT16_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c
new file mode 100644
index 000000000..0486a4834
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c
@@ -0,0 +1,25 @@
+//===------------------------ int16_fp16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT16_FP16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT16_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c
new file mode 100644
index 000000000..670f3a9f5
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c
@@ -0,0 +1,26 @@
+//===------------------------ int16_fp32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT16_FP32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c
new file mode 100644
index 000000000..61022964b
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c
@@ -0,0 +1,26 @@
+//===------------------------ int16_tf32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT16_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c
new file mode 100644
index 000000000..261140c0f
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c
@@ -0,0 +1,26 @@
+//===------------------------ int32_bf16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT32_BF16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c
new file mode 100644
index 000000000..30169b1ed
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c
@@ -0,0 +1,26 @@
+//===------------------------ int32_fp16.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT32_FP16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c
new file mode 100644
index 000000000..b56cb3821
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c
@@ -0,0 +1,26 @@
+//===------------------------ int32_fp32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT32_FP32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT32_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT32_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c
new file mode 100644
index 000000000..f0dc2c69a
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c
@@ -0,0 +1,26 @@
+//===------------------------ int32_tf32.c --------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT32_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT32_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                  RND_MODE round) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT32_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c
new file mode 100644
index 000000000..5f2253093
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c
@@ -0,0 +1,25 @@
+//===------------------------ int8_bf16.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT8_BF16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT8_BF16(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT8_BF16(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c
new file mode 100644
index 000000000..9166fa050
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c
@@ -0,0 +1,25 @@
+//===------------------------ int8_fp16.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT8_FP16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT8_FP16(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT8_FP16(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c
new file mode 100644
index 000000000..853ddd14a
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c
@@ -0,0 +1,25 @@
+//===------------------------ int8_fp32.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT8_FP32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT8_FP32(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT8_FP32(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c
new file mode 100644
index 000000000..7fe060ab4
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c
@@ -0,0 +1,25 @@
+//===------------------------ int8_tf32.c ---------------------------------===//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::INT8_TF32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __INT8_TF32(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) {
+  // Create command buffer.
+  TsmConvert *cmd = TsmNewConvert();
+  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->INT8_TF32(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c
new file mode 100644
index 000000000..2e59054b2
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c
@@ -0,0 +1,27 @@
+//===------------------------ leakyrelu.c ---------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Leakyrelu see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Leakyrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count,
+                 uint16_t fmt) {
+  // Create command buffer.
+  TsmActivation *cmd = TsmNewActivation();
+  TsmActivationInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Leakyrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteActivation(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/ln.c b/third_party/tsingmicro/crt/lib/Tx81/ln.c
new file mode 100644
index 000000000..41c528316
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/ln.c
@@ -0,0 +1,26 @@
+//===------------------------ ln.c ----------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Ln see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Ln(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmTranscendental *cmd = TsmNewTranscendental();
+  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Ln(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteTranscendental(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/log2.c b/third_party/tsingmicro/crt/lib/Tx81/log2.c
new file mode 100644
index 000000000..b012fdc11
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/log2.c
@@ -0,0 +1,26 @@
+//===------------------------ log2.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Log2 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Log2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmTranscendental *cmd = TsmNewTranscendental();
+  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Log2(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteTranscendental(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut16.c b/third_party/tsingmicro/crt/lib/Tx81/lut16.c
new file mode 100644
index 000000000..eea3a71d8
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/lut16.c
@@ -0,0 +1,27 @@
+//===------------------------ lut16.c -------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Lut16 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Lut16(uint64_t *src, uint64_t *dst, uint64_t *lut16,
+             uint32_t src_elem_count, uint32_t lut_elem_count) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Lut16(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut16, src_elem_count, lut_elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut32.c b/third_party/tsingmicro/crt/lib/Tx81/lut32.c
new file mode 100644
index 000000000..c3ca38f23
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/lut32.c
@@ -0,0 +1,27 @@
+//===------------------------ lut32.c -------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Lut32 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Lut32(uint64_t *src, uint64_t *dst, uint64_t *lut32,
+             uint32_t src_elem_count, uint32_t lut_elem_count) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Lut32(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut32, src_elem_count, lut_elem_count);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/mask_move.c b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c
new file mode 100644
index 000000000..cbeef4726
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c
@@ -0,0 +1,25 @@
+//===------------------------ mask_move.c ---------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::MaskMoveOp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __MaskMove(uint64_t *src, uint64_t *target, uint32_t elem_count,
+                uint64_t * mask, int32_t fmt) {
+  TsmMaskDataMove *move = TsmNewMaskDataMove();
+  TsmMaskDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  move->MaskMove(&inst, (uint64_t)src, (uint64_t)mask, (uint64_t)target,
+                 elem_count, (Data_Format)fmt);
+
+  TsmExecute(&inst);
+
+  TsmDeleteMaskDataMove(move);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/memset.c b/third_party/tsingmicro/crt/lib/Tx81/memset.c
new file mode 100644
index 000000000..7c412e496
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/memset.c
@@ -0,0 +1,29 @@
+//===------------------------ memset.c ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Memset see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Memset(uint64_t *dst, uint32_t value, uint32_t elem_count, uint32_t s0,
+              uint32_t i0, uint32_t s1, uint32_t i1, uint32_t s2, uint32_t i2,
+              uint16_t fmt) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  St_StrideIteration si = {s0, i0, s1, i1, s2, i2};
+  cmd->Memset(&inst, (uint64_t)dst, value, elem_count, &si, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/mirror.c b/third_party/tsingmicro/crt/lib/Tx81/mirror.c
new file mode 100644
index 000000000..db113e5fc
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/mirror.c
@@ -0,0 +1,31 @@
+//===------------------------ mirror.c ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Mirror see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Mirror(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
+              uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+              uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
+  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
+  cmd->Mirror(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
+              (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c
new file mode 100644
index 000000000..cfca054f7
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c
@@ -0,0 +1,31 @@
+//===------------------------ nchw2nhwc.c ---------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Nchw2nhwc see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Nchw2nhwc(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
+                 uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+                 uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
+  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
+  cmd->Nchw2nhwc(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
+                 (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c
new file mode 100644
index 000000000..e9ccdcf0e
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c
@@ -0,0 +1,31 @@
+//===------------------------ nhwc2nchw.c ---------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Nhwc2nchw see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Nhwc2nchw(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
+                 uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+                 uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
+  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
+  cmd->Nhwc2nchw(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
+                 (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/pad.c b/third_party/tsingmicro/crt/lib/Tx81/pad.c
new file mode 100644
index 000000000..e8d3caf72
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/pad.c
@@ -0,0 +1,33 @@
+//===------------------------ pad.c ---------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Pad see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Pad(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
+           uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+           uint16_t dst_w, uint16_t dst_c, uint16_t pad_n, uint16_t pad_h,
+           uint16_t pad_w, uint16_t pad_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
+  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
+  Data_Shape shape3 = { pad_n, pad_h, pad_w, pad_c };
+  cmd->Pad(&inst, (uint64_t) src, shape1, (uint64_t) dst,
+           shape2, shape3, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/pow2.c b/third_party/tsingmicro/crt/lib/Tx81/pow2.c
new file mode 100644
index 000000000..060edf08c
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/pow2.c
@@ -0,0 +1,26 @@
+//===------------------------ pow2.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Pow2 see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Pow2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmTranscendental *cmd = TsmNewTranscendental();
+  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Pow2(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteTranscendental(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/randgen.c b/third_party/tsingmicro/crt/lib/Tx81/randgen.c
new file mode 100644
index 000000000..d390a2e6f
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/randgen.c
@@ -0,0 +1,28 @@
+//===------------------------ randgen.c -----------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::RandGen see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __RandGen(uint64_t *src0, uint64_t *src1, uint64_t *dst0, uint64_t *dst1,
+               uint64_t *dst2, uint32_t src_elem_num, uint16_t fmt) {
+  // Create command buffer.
+  TsmPeripheral *cmd = TsmNewPeripheral();
+  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->RandGen(&inst, *src0, *src1, *dst0, *dst1, *dst2, src_elem_num,
+               (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeletePeripheral(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/rdma.c b/third_party/tsingmicro/crt/lib/Tx81/rdma.c
new file mode 100644
index 000000000..a72f964c2
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/rdma.c
@@ -0,0 +1,41 @@
+//===------------------------ rdma.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Rdma, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+// The arguments list is aligned with TsmRdma in Tx81Ops.td
+void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w,
+            int shape_c, int stride_n, int stride_h, int stride_w,
+            uint32_t fmt) {
+  // Create rdma command buffer.
+  TsmRdma *rdma = TsmNewRdma();
+  TsmRdmaInstr inst = {I_RDMA,
+                       {
+                           0,
+                       },
+                       {
+                           0,
+                       }};
+
+  rdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt);
+
+  rdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h,
+                              shape_h, stride_n, shape_n);
+
+  // rdma->Rdma1d(&inst, (uint64_t)src, (uint64_t)dst, shape_c,
+  //              (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteRdma(rdma);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/reduce.c b/third_party/tsingmicro/crt/lib/Tx81/reduce.c
new file mode 100644
index 000000000..ebaeceb8e
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/reduce.c
@@ -0,0 +1,105 @@
+//===---------------------- reduce.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::TsmReduce, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+// The arguments list is aligned with TsmReduce in Tx81Ops.td
+void __ReduceSum(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n,
+                 uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) {
+  // Create reduce command buffer.
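+  // Note: the four reduce variants below share identical plumbing; `dim` is
+  // forwarded untouched and presumably selects the NHWC axis per Tx81Ops.td.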
+  TsmReduce *cmd = TsmNewReduce();
+  TsmReduceInstr inst = {I_CGRA,
+                         {
+                             0,
+                         },
+                         {
+                             0,
+                         }};
+  // TODO
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  cmd->ReduceSum(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1,
+                 (Data_Format)fmt);
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteReduce(cmd);
+}
+
+void __ReduceAvg(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n,
+                 uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) {
+  // Create reduce command buffer.
+  TsmReduce *cmd = TsmNewReduce();
+  TsmReduceInstr inst = {I_CGRA,
+                         {
+                             0,
+                         },
+                         {
+                             0,
+                         }};
+  // TODO
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  cmd->ReduceAvg(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1,
+                 (Data_Format)fmt);
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteReduce(cmd);
+}
+
+void __ReduceMax(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n,
+                 uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) {
+  // Create reduce command buffer.
+  TsmReduce *cmd = TsmNewReduce();
+  TsmReduceInstr inst = {I_CGRA,
+                         {
+                             0,
+                         },
+                         {
+                             0,
+                         }};
+
+  // TODO
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  cmd->ReduceMax(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1,
+                 (Data_Format)fmt);
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteReduce(cmd);
+}
+
+void __ReduceMin(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n,
+                 uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) {
+  // Create reduce command buffer.
+  TsmReduce *cmd = TsmNewReduce();
+  TsmReduceInstr inst = {I_CGRA,
+                         {
+                             0,
+                         },
+                         {
+                             0,
+                         }};
+
+  // TODO
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  cmd->ReduceMin(&inst, (uint64_t)src, (uint64_t)dst, dim, shape1,
+                 (Data_Format)fmt);
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteReduce(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/relu.c b/third_party/tsingmicro/crt/lib/Tx81/relu.c
new file mode 100644
index 000000000..90bddef53
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/relu.c
@@ -0,0 +1,26 @@
+//===------------------------ relu.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Relu see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Relu(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmActivation *cmd = TsmNewActivation();
+  TsmActivationInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Relu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteActivation(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate180.c b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c
new file mode 100644
index 000000000..b5d78ae0b
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c
@@ -0,0 +1,31 @@
+//===------------------------ rotate180.c ---------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Rotate180 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Rotate180(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + + Data_Shape shape1 = { src_n, src_h, src_w, src_c }; + Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; + cmd->Rotate180(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate270.c b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c new file mode 100644 index 000000000..f82561830 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c @@ -0,0 +1,31 @@ +//===------------------------ rotate270.c ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Rotate270 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Rotate270(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + + Data_Shape shape1 = { src_n, src_h, src_w, src_c }; + Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; + cmd->Rotate270(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate90.c b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c new file mode 100644 index 000000000..9e8480470 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c @@ -0,0 +1,30 @@ +//===------------------------ rotate90.c ----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Rotate90 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Rotate90(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + // Create command buffer. + TsmDataMove *cmd = TsmNewDataMove(); + TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + + Data_Shape shape1 = { src_n, src_h, src_w, src_c }; + Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; + cmd->Rotate90(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/satrelu.c b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c new file mode 100644 index 000000000..338f2852d --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c @@ -0,0 +1,26 @@ +//===------------------------ satrelu.c -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Satrelu see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Satrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->Satrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c new file mode 100644 index 000000000..03761fcd6 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c @@ -0,0 +1,26 @@ +//===------------------------ sigmoid.c -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Sigmoid see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Sigmoid(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmActivation *cmd = TsmNewActivation(); + TsmActivationInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->Sigmoid(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteActivation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/sin.c b/third_party/tsingmicro/crt/lib/Tx81/sin.c new file mode 100644 index 000000000..065f57e85 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/sin.c @@ -0,0 +1,26 @@ +//===------------------------ Sin.c ---------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Sin see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __Sin(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmTranscendental *cmd = TsmNewTranscendental(); + TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->Sin(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+  TsmDeleteTranscendental(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/softplus.c b/third_party/tsingmicro/crt/lib/Tx81/softplus.c
new file mode 100644
index 000000000..af1f1f0ee
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/softplus.c
@@ -0,0 +1,27 @@
+//===------------------------ softplus.c ----------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Softplus, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Softplus(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmActivation *cmd = TsmNewActivation();
+  TsmActivationInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Softplus(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteActivation(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/tanh.c b/third_party/tsingmicro/crt/lib/Tx81/tanh.c
new file mode 100644
index 000000000..aecc93431
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/tanh.c
@@ -0,0 +1,26 @@
+//===------------------------ tanh.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Tanh, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Tanh(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmActivation *cmd = TsmNewActivation();
+  TsmActivationInstr inst = {I_CGRA, {0,}, {0,}};
+
+  cmd->Tanh(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteActivation(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c
new file mode 100644
index 000000000..a2f023620
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c
@@ -0,0 +1,31 @@
+//===------------------------ tensornorm.c --------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::TensorNorm, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __TensorNorm(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
+                  uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+                  uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
+  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
+  // (sic: TensorNom is the spelling of the runtime entry point.)
+  cmd->TensorNom(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
+                 (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+ TsmDeleteDataMove(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c new file mode 100644 index 000000000..b01999f77 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c @@ -0,0 +1,26 @@ +//===------------------------ tf32_bf16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_BF16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->TF32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c new file mode 100644 index 000000000..7369e13d7 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c @@ -0,0 +1,25 @@ +//===------------------------ tf32_fp16.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_FP16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->TF32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c new file mode 100644 index 000000000..40f8036f3 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c @@ -0,0 +1,25 @@ +//===------------------------ tf32_fp32.c ---------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_FP32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->TF32_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+ TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c new file mode 100644 index 000000000..1c0b07546 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c @@ -0,0 +1,26 @@ +//===------------------------ tf32_int16.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_INT16 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->TF32_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c new file mode 100644 index 000000000..5ca965e47 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c @@ -0,0 +1,26 @@ +//===------------------------ tf32_int32.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_INT32 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->TF32_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteConvert(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c new file mode 100644 index 000000000..fec628c19 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c @@ -0,0 +1,26 @@ +//===------------------------ tf32_int8.c --------------------------------===// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::TF32_INT8 see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __TF32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count, + RND_MODE round) { + // Create command buffer. + TsmConvert *cmd = TsmNewConvert(); + TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + + cmd->TF32_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+  TsmDeleteConvert(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/transpose.c b/third_party/tsingmicro/crt/lib/Tx81/transpose.c
new file mode 100644
index 000000000..c9725c56c
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/transpose.c
@@ -0,0 +1,31 @@
+//===------------------------ transpose.c ---------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Transpose, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __Transpose(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
+                 uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+                 uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+  // Create command buffer.
+  TsmDataMove *cmd = TsmNewDataMove();
+  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+
+  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
+  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
+  cmd->Transpose(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
+                 (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteDataMove(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/wdma.c b/third_party/tsingmicro/crt/lib/Tx81/wdma.c
new file mode 100644
index 000000000..1fee20152
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/wdma.c
@@ -0,0 +1,41 @@
+//===------------------------ wdma.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::Wdma, see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+// The argument list is aligned with tx::Wdma in Tx81Ops.td
+void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w,
+            int shape_c, int stride_n, int stride_h, int stride_w,
+            uint32_t fmt) {
+  // Create WDMA command buffer.
+  TsmWdma *wdma = TsmNewWdma();
+  TsmWdmaInstr inst = {I_WDMA,
+                       {
+                           0,
+                       },
+                       {
+                           0,
+                       }};
+
+  wdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt);
+
+  wdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h,
+                              shape_h, stride_n, shape_n);
+
+  // wdma->Wdma1d(&inst, (uint64_t)src, (uint64_t)dst, shape_c,
+  // (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteWdma(wdma);
+}
diff --git a/third_party/tsingmicro/include/CMakeLists.txt b/third_party/tsingmicro/include/CMakeLists.txt
new file mode 100644
index 000000000..76c90b65b
--- /dev/null
+++ b/third_party/tsingmicro/include/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_subdirectory(triton-shared)
+add_subdirectory(magic-kernel)
+add_subdirectory(tsingmicro-tx81)
+# The following 2 dialects are currently unused.
+#add_subdirectory(magic-kernel-func)
+#add_subdirectory(magic-kernel-instr)
diff --git a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp
new file mode 100644
index 000000000..48e2afbf5
--- /dev/null
+++ b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp
@@ -0,0 +1,192 @@
+//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements basic functions to manipulate structured MLIR types at
+// runtime. Entities in this file are meant to be retargetable, including on
+// targets without a C++ runtime, and must be kept C compatible.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CRunnerUtils.h"
+#include "Msan.h"
+
+#ifndef _WIN32
+#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
+    defined(__DragonFly__)
+#include <cstdlib>
+#else
+#include <malloc.h>
+#endif
+#include <sys/time.h>
+#else
+#include "malloc.h"
+#endif // _WIN32
+
+#include <algorithm>
+#include <cinttypes>
+#include <cstdio>
+#include <cstdlib>
+#include <random>
+#include <string.h>
+
+#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
+
+namespace {
+template <typename V>
+void stdSort(uint64_t n, V *p) {
+  std::sort(p, p + n);
+}
+
+} // namespace
+
+// Small runtime support "lib" for vector.print lowering.
+// By providing elementary printing methods only, this
+// library can remain fully unaware of low-level implementation
+// details of our vectors. Also useful for direct LLVM IR output.
+extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
+extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
+extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
+extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
+extern "C" void printString(char const *s) { fputs(s, stdout); }
+extern "C" void printOpen() { fputs("( ", stdout); }
+extern "C" void printClose() { fputs(" )", stdout); }
+extern "C" void printComma() { fputs(", ", stdout); }
+extern "C" void printNewline() { fputc('\n', stdout); }
+
+extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *srcArg,
+                           UnrankedMemRefType<char> *dstArg) {
+  DynamicMemRefType<char> src(*srcArg);
+  DynamicMemRefType<char> dst(*dstArg);
+
+  int64_t rank = src.rank;
+  MLIR_MSAN_MEMORY_IS_INITIALIZED(src.sizes, rank * sizeof(int64_t));
+
+  // Handle empty shapes -> nothing to copy.
+  for (int rankp = 0; rankp < rank; ++rankp)
+    if (src.sizes[rankp] == 0)
+      return;
+
+  char *srcPtr = src.data + src.offset * elemSize;
+  char *dstPtr = dst.data + dst.offset * elemSize;
+
+  if (rank == 0) {
+    memcpy(dstPtr, srcPtr, elemSize);
+    return;
+  }
+
+  int64_t *indices = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+  int64_t *srcStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+  int64_t *dstStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+
+  // Initialize index and scale strides.
+  for (int rankp = 0; rankp < rank; ++rankp) {
+    indices[rankp] = 0;
+    srcStrides[rankp] = src.strides[rankp] * elemSize;
+    dstStrides[rankp] = dst.strides[rankp] * elemSize;
+  }
+
+  int64_t readIndex = 0, writeIndex = 0;
+  for (;;) {
+    // Copy over the element, byte by byte.
+    memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize);
+    // Advance index and read position.
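+    // The axis walk below works like an odometer: advance the innermost
+    // index; on wrap-around, reset it, undo its contribution to the linear
+    // offsets, and carry into the next outer axis.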
+    for (int64_t axis = rank - 1; axis >= 0; --axis) {
+      // Advance at current axis.
+      auto newIndex = ++indices[axis];
+      readIndex += srcStrides[axis];
+      writeIndex += dstStrides[axis];
+      // If this is a valid index, we have our next index, so continue copying.
+      if (src.sizes[axis] != newIndex)
+        break;
+      // We reached the end of this axis. If this is axis 0, we are done.
+      if (axis == 0)
+        return;
+      // Else, reset to 0 and undo the advancement of the linear index that
+      // this axis had. Then continue with the axis one outer.
+      indices[axis] = 0;
+      readIndex -= src.sizes[axis] * srcStrides[axis];
+      writeIndex -= dst.sizes[axis] * dstStrides[axis];
+    }
+  }
+}
+
+/// Prints GFLOPS rating.
+extern "C" void printFlops(double flops) {
+  fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9);
+}
+
+/// Returns the number of seconds since Epoch 1970-01-01 00:00:00 +0000 (UTC).
+extern "C" double rtclock() {
+#ifndef _WIN32
+  struct timeval tp;
+  int stat = gettimeofday(&tp, nullptr);
+  if (stat != 0)
+    fprintf(stderr, "Error returning time from gettimeofday: %d\n", stat);
+  return (tp.tv_sec + tp.tv_usec * 1.0e-6);
+#else
+  fprintf(stderr, "Timing utility not implemented on Windows\n");
+  return 0.0;
+#endif // _WIN32
+}
+
+extern "C" void *mlirAlloc(uint64_t size) { return malloc(size); }
+
+extern "C" void *mlirAlignedAlloc(uint64_t alignment, uint64_t size) {
+#ifdef _WIN32
+  return _aligned_malloc(size, alignment);
+#elif defined(__APPLE__)
+  // aligned_alloc was added in MacOS 10.15. Fall back to posix_memalign to also
+  // support older versions.
+  void *result = nullptr;
+  (void)::posix_memalign(&result, alignment, size);
+  return result;
+#else
+  return aligned_alloc(alignment, size);
+#endif
+}
+
+extern "C" void mlirFree(void *ptr) { free(ptr); }
+
+extern "C" void mlirAlignedFree(void *ptr) {
+#ifdef _WIN32
+  _aligned_free(ptr);
+#else
+  free(ptr);
+#endif
+}
+
+extern "C" void *rtsrand(uint64_t s) {
+  // Standard mersenne_twister_engine seeded with s.
+  return new std::mt19937(s);
+}
+
+extern "C" uint64_t rtrand(void *g, uint64_t m) {
+  std::mt19937 *generator = static_cast<std::mt19937 *>(g);
+  std::uniform_int_distribution<uint64_t> distrib(0, m);
+  return distrib(*generator);
+}
+
+extern "C" void rtdrand(void *g) {
+  std::mt19937 *generator = static_cast<std::mt19937 *>(g);
+  delete generator;
+}
+
+#define IMPL_STDSORT(VNAME, V)                                                 \
+  extern "C" void _mlir_ciface_stdSort##VNAME(uint64_t n,                      \
+                                              StridedMemRefType<V, 1> *vref) { \
+    assert(vref);                                                              \
+    assert(vref->strides[0] == 1);                                             \
+    V *values = vref->data + vref->offset;                                     \
+    stdSort(n, values);                                                        \
+  }
+IMPL_STDSORT(I64, int64_t)
+IMPL_STDSORT(F64, double)
+IMPL_STDSORT(F32, float)
+#undef IMPL_STDSORT
+
+#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
diff --git a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h
new file mode 100644
index 000000000..76b04145b
--- /dev/null
+++ b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h
@@ -0,0 +1,499 @@
+//===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares basic classes and functions to manipulate structured MLIR
+// types at runtime. Entities in this file must be compliant with C++11 and be
+// retargetable, including on targets without a C++ runtime.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
+#define MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
+
+#ifdef _WIN32
+#ifndef MLIR_CRUNNERUTILS_EXPORT
+#ifdef mlir_c_runner_utils_EXPORTS
+// We are building this library
+#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllexport)
+#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
+#else
+// We are using this library
+#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllimport)
+#endif // mlir_c_runner_utils_EXPORTS
+#endif // MLIR_CRUNNERUTILS_EXPORT
+#else // _WIN32
+// Non-windows: use visibility attributes.
+#define MLIR_CRUNNERUTILS_EXPORT __attribute__((visibility("default")))
+#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
+#endif // _WIN32
+
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <initializer_list>
+#include <vector>
+
+//===----------------------------------------------------------------------===//
+// Codegen-compatible structures for Vector type.
+//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace detail {
+
+constexpr bool isPowerOf2(int n) { return (!(n & (n - 1))); }
+
+constexpr unsigned nextPowerOf2(int n) {
+  return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2)));
+}
+
+template <typename T, int Dim, bool IsPowerOf2>
+struct Vector1D;
+
+template <typename T, int Dim>
+struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
+  Vector1D() {
+    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]),
+                  "size error");
+  }
+  inline T &operator[](unsigned i) { return vector[i]; }
+  inline const T &operator[](unsigned i) const { return vector[i]; }
+
+private:
+  T vector[Dim];
+};
+
+// 1-D vector, padded to the next power of 2 allocation.
+// Specialization occurs to avoid zero size arrays (which fail in -Werror).
+template <typename T, int Dim>
+struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
+  Vector1D() {
+    static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
+    static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
+                  "size error");
+  }
+  inline T &operator[](unsigned i) { return vector[i]; }
+  inline const T &operator[](unsigned i) const { return vector[i]; }
+
+private:
+  T vector[Dim];
+  char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
+};
+} // namespace detail
+} // namespace mlir
+
+// N-D vectors recurse down to 1-D.
+template <typename T, int Dim, int... Dims>
+struct Vector {
+  inline Vector<T, Dims...> &operator[](unsigned i) { return vector[i]; }
+  inline const Vector<T, Dims...> &operator[](unsigned i) const {
+    return vector[i];
+  }
+
+private:
+  Vector<T, Dims...> vector[Dim];
+};
+
+// 1-D vectors in LLVM are automatically padded to the next power of 2.
+// We insert explicit padding in to account for this.
+template <typename T, int Dim>
+struct Vector<T, Dim>
+    : public mlir::detail::Vector1D<T, Dim,
+                                    mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
+};
+
+template <typename T, int Dim>
+using Vector1D = Vector<T, Dim>;
+template <typename T, int Dim1, int Dim2>
+using Vector2D = Vector<T, Dim1, Dim2>;
+template <typename T, int Dim1, int Dim2, int Dim3>
+using Vector3D = Vector<T, Dim1, Dim2, Dim3>;
+template <typename T, int Dim1, int Dim2, int Dim3, int Dim4>
+using Vector4D = Vector<T, Dim1, Dim2, Dim3, Dim4>;
+
+template <int N>
+void dropFront(int64_t arr[N], int64_t *res) {
+  for (unsigned i = 1; i < N; ++i)
+    *(res + i - 1) = arr[i];
+}
+
+//===----------------------------------------------------------------------===//
+// Codegen-compatible structures for StridedMemRef type.
+//===----------------------------------------------------------------------===//
+template <typename T, int Rank>
+class StridedMemrefIterator;
+
+/// StridedMemRef descriptor type with static rank.
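+/// This mirrors the descriptor produced by MLIR's memref lowering: a base and
+/// data pointer, an offset, and size/stride arrays of length N. For a
+/// memref<3x4xf32>, sizes are {3, 4} and the (row-major) strides {4, 1}, so
+/// element (i, j) lives at data[offset + 4 * i + j].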
+template <typename T, int N>
+struct StridedMemRefType {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  int64_t sizes[N];
+  int64_t strides[N];
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range &&indices) {
+    assert(indices.size() == N &&
+           "indices should match rank in memref subscript");
+    int64_t curOffset = offset;
+    for (int dim = N - 1; dim >= 0; --dim) {
+      int64_t currentIndex = *(indices.begin() + dim);
+      assert(currentIndex < sizes[dim] && "Index overflow");
+      curOffset += currentIndex * strides[dim];
+    }
+    return data[curOffset];
+  }
+
+  StridedMemrefIterator<T, N> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, N> end() { return {*this, -1}; }
+
+  // This operator[] is extremely slow and only for sugaring purposes.
+  StridedMemRefType<T, N - 1> operator[](int64_t idx) {
+    StridedMemRefType<T, N - 1> res;
+    res.basePtr = basePtr;
+    res.data = data;
+    res.offset = offset + idx * strides[0];
+    dropFront<N>(sizes, res.sizes);
+    dropFront<N>(strides, res.strides);
+    return res;
+  }
+};
+
+/// StridedMemRef descriptor type specialized for rank 1.
+template <typename T>
+struct StridedMemRefType<T, 1> {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  int64_t sizes[1];
+  int64_t strides[1];
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert(indices.size() == 1 &&
+           "indices should match rank in memref subscript");
+    return (*this)[*indices.begin()];
+  }
+
+  StridedMemrefIterator<T, 1> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
+
+  T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
+};
+
+/// StridedMemRef descriptor type specialized for rank 0.
+template <typename T>
+struct StridedMemRefType<T, 0> {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert((indices.size() == 0) &&
+           "Expect empty indices for 0-rank memref subscript");
+    return data[offset];
+  }
+
+  StridedMemrefIterator<T, 0> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, 0> end() { return {*this, offset + 1}; }
+};
+
+/// Iterate over all elements in a strided memref.
+template <typename T, int Rank>
+class StridedMemrefIterator {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
+                        int64_t offset = 0)
+      : offset(offset), descriptor(&descriptor) {}
+  StridedMemrefIterator &operator++() {
+    int dim = Rank - 1;
+    while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor->strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+    if (dim < 0) {
+      offset = -1;
+      return *this;
+    }
+    ++indices[dim];
+    offset += descriptor->strides[dim];
+    return *this;
+  }
+
+  reference operator*() { return descriptor->data[offset]; }
+  pointer operator->() { return &descriptor->data[offset]; }
+
+  const std::array<int64_t, Rank> &getIndices() { return indices; }
+
+  bool operator==(const StridedMemrefIterator &other) const {
+    return other.offset == offset && other.descriptor == descriptor;
+  }
+
+  bool operator!=(const StridedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor.
+  int64_t offset = 0;
+
+  /// Array of indices in the multi-dimensional memref.
+  std::array<int64_t, Rank> indices = {};
+
+  /// Descriptor for the strided memref.
+  StridedMemRefType<T, Rank> *descriptor;
+};
+
+/// Iterate over all elements in a 0-ranked strided memref.
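+/// A rank-0 memref holds exactly one element, so this iterator degenerates to
+/// a bare element pointer; begin() and end() span a single position.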
+template <typename T>
+class StridedMemrefIterator<T, 0> {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
+      : elt(descriptor.data + offset) {}
+
+  StridedMemrefIterator &operator++() {
+    ++elt;
+    return *this;
+  }
+
+  reference operator*() { return *elt; }
+  pointer operator->() { return elt; }
+
+  // There are no indices for a 0-ranked memref, but this API is provided for
+  // consistency with the general case.
+  const std::array<int64_t, 0> &getIndices() {
+    // Since this is a 0-array of indices we can keep a single global const
+    // copy.
+    static const std::array<int64_t, 0> indices = {};
+    return indices;
+  }
+
+  bool operator==(const StridedMemrefIterator &other) const {
+    return other.elt == elt;
+  }
+
+  bool operator!=(const StridedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Pointer to the single element in the zero-ranked memref.
+  T *elt;
+};
+
+//===----------------------------------------------------------------------===//
+// Codegen-compatible structure for UnrankedMemRef type.
+//===----------------------------------------------------------------------===//
+// Unranked MemRef
+template <typename T>
+struct UnrankedMemRefType {
+  int64_t rank;
+  void *descriptor;
+};
+
+//===----------------------------------------------------------------------===//
+// DynamicMemRefType type.
+//===----------------------------------------------------------------------===//
+template <typename T>
+class DynamicMemRefIterator;
+
+// A reference to one of the StridedMemRef types.
+template <typename T>
+class DynamicMemRefType {
+public:
+  int64_t rank;
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  const int64_t *sizes;
+  const int64_t *strides;
+
+  explicit DynamicMemRefType(const StridedMemRefType<T, 0> &memRef)
+      : rank(0), basePtr(memRef.basePtr), data(memRef.data),
+        offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
+  template <int N>
+  explicit DynamicMemRefType(const StridedMemRefType<T, N> &memRef)
+      : rank(N), basePtr(memRef.basePtr), data(memRef.data),
+        offset(memRef.offset), sizes(memRef.sizes), strides(memRef.strides) {}
+  explicit DynamicMemRefType(const ::UnrankedMemRefType<T> &memRef)
+      : rank(memRef.rank) {
+    auto *desc = static_cast<StridedMemRefType<T, 1> *>(memRef.descriptor);
+    basePtr = desc->basePtr;
+    data = desc->data;
+    offset = desc->offset;
+    sizes = rank == 0 ? nullptr : desc->sizes;
+    strides = sizes + rank;
+  }
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range &&indices) {
+    assert(indices.size() == rank &&
+           "indices should match rank in memref subscript");
+    if (rank == 0)
+      return data[offset];
+
+    int64_t curOffset = offset;
+    for (int dim = rank - 1; dim >= 0; --dim) {
+      int64_t currentIndex = *(indices.begin() + dim);
+      assert(currentIndex < sizes[dim] && "Index overflow");
+      curOffset += currentIndex * strides[dim];
+    }
+    return data[curOffset];
+  }
+
+  DynamicMemRefIterator<T> begin() { return {*this, offset}; }
+  DynamicMemRefIterator<T> end() { return {*this, -1}; }
+
+  // This operator[] is extremely slow and only for sugaring purposes.
+  DynamicMemRefType<T> operator[](int64_t idx) {
+    assert(rank > 0 && "can't make a subscript of a zero ranked array");
+
+    DynamicMemRefType<T> res(*this);
+    --res.rank;
+    res.offset += idx * res.strides[0];
+    ++res.sizes;
+    ++res.strides;
+    return res;
+  }
+
+  // This operator* can be used in conjunction with the previous operator[] in
+  // order to access the underlying value in case of zero-ranked memref.
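+  // For example, for a rank-1 DynamicMemRefType<float> m, *m[2] reads element
+  // 2: m[2] yields a rank-0 view that operator* then dereferences.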
+  T &operator*() {
+    assert(rank == 0 && "not a zero-ranked memRef");
+    return data[offset];
+  }
+};
+
+/// Iterate over all elements in a dynamic memref.
+template <typename T>
+class DynamicMemRefIterator {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
+      : offset(offset), descriptor(&descriptor) {
+    indices.resize(descriptor.rank, 0);
+  }
+
+  DynamicMemRefIterator &operator++() {
+    if (descriptor->rank == 0) {
+      offset = -1;
+      return *this;
+    }
+
+    int dim = descriptor->rank - 1;
+
+    while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor->strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+
+    if (dim < 0) {
+      offset = -1;
+      return *this;
+    }
+
+    ++indices[dim];
+    offset += descriptor->strides[dim];
+    return *this;
+  }
+
+  reference operator*() { return descriptor->data[offset]; }
+  pointer operator->() { return &descriptor->data[offset]; }
+
+  const std::vector<int64_t> &getIndices() { return indices; }
+
+  bool operator==(const DynamicMemRefIterator &other) const {
+    return other.offset == offset && other.descriptor == descriptor;
+  }
+
+  bool operator!=(const DynamicMemRefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor.
+  int64_t offset = 0;
+
+  /// Array of indices in the multi-dimensional memref.
+  std::vector<int64_t> indices = {};
+
+  /// Descriptor for the dynamic memref.
+  DynamicMemRefType<T> *descriptor;
+};
+
+//===----------------------------------------------------------------------===//
+// Small runtime support library for memref.copy lowering during codegen.
+//===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void
+memrefCopy(int64_t elemSize, ::UnrankedMemRefType<char> *src,
+           ::UnrankedMemRefType<char> *dst);
+
+//===----------------------------------------------------------------------===//
+// Small runtime support library for vector.print lowering during codegen.
+//===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
+
+//===----------------------------------------------------------------------===//
+// Small runtime support library for timing execution and printing GFLOPS
+//===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops);
+extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock();
+
+//===----------------------------------------------------------------------===//
+// Runtime support library for random number generation.
+//===----------------------------------------------------------------------===//
+// Uses a seed to initialize a random generator and returns the generator.
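+// The handle is heap-allocated (a std::mt19937 in this implementation) and
+// must be released with rtdrand, e.g.:
+//   void *g = rtsrand(42);
+//   uint64_t r = rtrand(g, 100); // uniform draw bounded by 100
+//   rtdrand(g);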
+extern "C" MLIR_CRUNNERUTILS_EXPORT void *rtsrand(uint64_t s); +// Returns a random number in the range of [0, m). +extern "C" MLIR_CRUNNERUTILS_EXPORT uint64_t rtrand(void *, uint64_t m); +// Deletes the random number generator. +extern "C" MLIR_CRUNNERUTILS_EXPORT void rtdrand(void *); + +//===----------------------------------------------------------------------===// +// Runtime support library to allow the use of std::sort in MLIR program. +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortI64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF32(uint64_t n, StridedMemRefType *vref); +#endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H diff --git a/third_party/tsingmicro/include/ExecutionEngine/Msan.h b/third_party/tsingmicro/include/ExecutionEngine/Msan.h new file mode 100644 index 000000000..ee94660ae --- /dev/null +++ b/third_party/tsingmicro/include/ExecutionEngine/Msan.h @@ -0,0 +1,35 @@ +//===- Msan.h - Utils related to the memory sanitizer ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares and defines macros related to msan. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_MSAN_H +#define MLIR_EXECUTIONENGINE_MSAN_H + +// Memory sanitizer currently can't be enabled for the jit-compiled code, and +// to suppress msan warnings we need to unpoison pointers and pointed-to +// datastructures before they can be accessed. 
+
+#ifndef __has_feature
+#define __has_feature(x) 0
+#endif
+
+#if __has_feature(memory_sanitizer) && !defined(MLIR_MEMORY_SANITIZER)
+#define MLIR_MEMORY_SANITIZER
+#endif
+
+#if defined(MLIR_MEMORY_SANITIZER)
+#include <sanitizer/msan_interface.h>
+#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s) __msan_unpoison((p), (s))
+#else // Memory sanitizer: OFF
+#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s)
+#endif // MLIR_MEMORY_SANITIZER
+
+#endif // MLIR_EXECUTIONENGINE_MSAN_H
diff --git a/third_party/tsingmicro/include/ExecutionEngine/version.txt b/third_party/tsingmicro/include/ExecutionEngine/version.txt
new file mode 100644
index 000000000..c3f15e55e
--- /dev/null
+++ b/third_party/tsingmicro/include/ExecutionEngine/version.txt
@@ -0,0 +1 @@
+https://github.com/llvm/llvm-project/commit/3be3883e6d67bf908fd12b51219075293ebb3dff
diff --git a/third_party/tsingmicro/include/magic-kernel-func/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-func/CMakeLists.txt
new file mode 100644
index 000000000..629c08af6
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel-func/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Conversion)
+add_subdirectory(Dialect)
diff --git a/third_party/tsingmicro/include/magic-kernel-func/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-func/Dialect/CMakeLists.txt
new file mode 100644
index 000000000..f33061b2d
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel-func/Dialect/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td b/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td
new file mode 100644
index 000000000..a351a091b
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td
@@ -0,0 +1,19 @@
+//===------------------- MagicKernelFuncOps.td ----------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Common abstraction layer for non-instruction-driven ML accelerators.
+//
+// The target glue layer that translates target-independent kernel operations
+// into NPU-like API calls (other jargon includes "intrinsics" or "driver
+// functions").
+//
+// The NPU APIs are categorized by data type: as in traditional compilers, the
+// integer and floating-point function units are separate, so every MK
+// (MagicKernel) op is lowered to two MKF (MagicKernelFunc) ops, an integer
+// version and a floating-point version.
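+//
+// For example, an element-wise MK add would fan out to one integer NPU entry
+// point covering the integer formats and one floating-point entry point
+// covering the float formats (the concrete entry-point names here are
+// illustrative, not a fixed ABI).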
+// +//===----------------------------------------------------------------------===// \ No newline at end of file diff --git a/third_party/tsingmicro/include/magic-kernel-instr/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-instr/CMakeLists.txt new file mode 100644 index 000000000..629c08af6 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel-instr/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/third_party/tsingmicro/include/magic-kernel-instr/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td new file mode 100644 index 000000000..1053afae5 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td @@ -0,0 +1,13 @@ +//===------------------- MagicKernelInstrOps.td ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Common abstraction layer for instruction driven ML accelerator. +// +// The target glue layer that translates target independent kernel operations +// into intrinsics which fits LLVM dialect lowering path. +// +//===----------------------------------------------------------------------===// \ No newline at end of file diff --git a/third_party/tsingmicro/include/magic-kernel/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/CMakeLists.txt new file mode 100644 index 000000000..629c08af6 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt new file mode 100644 index 000000000..cece7d89b --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(LinalgToMK) +add_subdirectory(CoreDialectsToMK) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CMakeLists.txt new file mode 100644 index 000000000..69690dec7 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CMakeLists.txt @@ -0,0 +1,10 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +# All rights reserved. 
+#
+#===------------------------------------------------------------------------===#
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name CoreDialectsMK)
+add_public_tablegen_target(CoreDialectsToMKConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h
new file mode 100644
index 000000000..69750d402
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h
@@ -0,0 +1,27 @@
+//===------------------- CoreDialectsToMK.h -------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// The umbrella pass that populates all the conversion patterns from core
+// dialects such as linalg, memref, bufferization etc. to the mk dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_CONVERSION_CORE_DIALECTS_TO_MK_H
+#define TRITON_CONVERSION_CORE_DIALECTS_TO_MK_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace triton {
+
+std::unique_ptr<OperationPass<ModuleOp>> createCoreDialectsToMKPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_CORE_DIALECTS_TO_MK_H
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.h b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.h
new file mode 100644
index 000000000..7e1982c3b
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.h
@@ -0,0 +1,26 @@
+//===------------------- Passes.h -----------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Wraps all the conversions from core dialects to backend dialects (MK etc.).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CORE_DIALECTS_TO_MK_CONVERSION_PASSES_H
+#define CORE_DIALECTS_TO_MK_CONVERSION_PASSES_H
+
+#include "magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif // CORE_DIALECTS_TO_MK_CONVERSION_PASSES_H
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.td b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.td
new file mode 100644
index 000000000..d4f5fa677
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CoreDialectsToMK/Passes.td
@@ -0,0 +1,18 @@
+//===------------------- Passes.td ----------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CORE_DIALECTS_TO_MK_CONVERSION_PASSES
+#define CORE_DIALECTS_TO_MK_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def CoreDialectsToMK : Pass<"core-dialects-to-mk", "mlir::ModuleOp"> {
+  let summary = "Convert core dialects including Linalg, Memref etc to MK";
+  let constructor = "triton::createCoreDialectsToMKPass()";
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/CMakeLists.txt
new file mode 100644
index 000000000..76b9d9114
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name LinalgToMK)
+add_public_tablegen_target(LinalgToMKConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
new file mode 100644
index 000000000..1a2721d38
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
@@ -0,0 +1,36 @@
+//===------------------- LinalgToMK.h -------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Lowering all linalg ops into mk ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ZTC_CONVERSION_LINALG_TO_MK_H
+#define ZTC_CONVERSION_LINALG_TO_MK_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc"
+
+void populateLinalgToMKCanonicalizationPatterns(
+    RewritePatternSet &patterns);
+
+void populateLinalgToMKConversionPatterns(RewritePatternSet &patterns);
+
+std::unique_ptr<OperationPass<ModuleOp>> createLinalgToMKPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // ZTC_CONVERSION_LINALG_TO_MK_H
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.h b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.h
new file mode 100644
index 000000000..7c45210e6
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.h
@@ -0,0 +1,22 @@
+//===------------------- Passes.h -----------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_TO_MK_CONVERSION_PASSES_H +#define LINALG_TO_MK_CONVERSION_PASSES_H + +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // LINALG_TO_MK_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.td b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.td new file mode 100644 index 000000000..b4f39500c --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/Passes.td @@ -0,0 +1,19 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_TO_MK_CONVERSION_PASSES +#define LINALG_TO_MK_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def LinalgToMK : Pass<"linalg-to-mk", "mlir::ModuleOp"> { + let summary = "Convert linalg operations into magic kernel operations"; + + let options = []; +} + +#endif diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Dialect/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/CMakeLists.txt new file mode 100644 index 000000000..437811f2a --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +set(LLVM_TARGET_DEFINITIONS MagicKernelOps.td) +mlir_tablegen(MagicKernelDialect.h.inc -gen-dialect-decls -dialect=mk) +mlir_tablegen(MagicKernelDialect.cpp.inc -gen-dialect-defs -dialect=mk) +mlir_tablegen(MagicKernelOps.h.inc -gen-op-decls) +mlir_tablegen(MagicKernelOps.cpp.inc -gen-op-defs) + +set(LLVM_TARGET_DEFINITIONS MagicKernelTypes.td) +mlir_tablegen(MagicKernelTypes.h.inc -gen-typedef-decls) +mlir_tablegen(MagicKernelTypes.cpp.inc -gen-typedef-defs) + +add_public_tablegen_target(MagicKernelTableGen) diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td new file mode 100644 index 000000000..0f8018678 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td @@ -0,0 +1,15 @@ +//===------------------- MagicKernelAttrDefs.td ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MAGIC_KERNEL_ATTR_DEFS
+#define MAGIC_KERNEL_ATTR_DEFS
+
+include "mlir/IR/EnumAttr.td"
+
+
+
+#endif // MAGIC_KERNEL_ATTR_DEFS
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h
new file mode 100644
index 000000000..c9bb47440
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h
@@ -0,0 +1,34 @@
+//===------------------- MagicKernelDialect.h -----------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_
+#define MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+//===----------------------------------------------------------------------===//
+// MagicKernel Operations
+//===----------------------------------------------------------------------===//
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h.inc"
+
+// Include the auto-generated header file containing the declarations of the
+// MagicKernel operations.
+#define GET_OP_CLASSES
+#include "magic-kernel/Dialect/IR/MagicKernelOps.h.inc"
+
+
+#endif // MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td
new file mode 100644
index 000000000..ade7c11c3
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td
@@ -0,0 +1,44 @@
+//===------------------- MagicKernelDialect.td ----------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MAGIC_KERNEL_DIALECT
+#define MAGIC_KERNEL_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def MagicKernelDialect : Dialect {
+  let name = "mk";
+
+  let cppNamespace = "::mlir::mk";
+
+  let summary = "The Magic Kernel IR in MLIR";
+
+  let description = [{
+    Magic Kernel Dialect.
+
+    Dependent Dialects:
+      * Memref
+        * copy, alloc
+      * Bufferization
+        * to_tensor
+  }];
+
+  let dependentDialects = [
+  ];
+
+  let extraClassDeclaration = [{
+    void registerTypes();
+  }];
+
+  // let hasConstantMaterializer = 1;
+  // let useDefaultTypePrinterParser = 1;
+  let usePropertiesForAttributes = 1;
+}
+
+include "magic-kernel/Dialect/IR/MagicKernelTypes.td"
+
+#endif // MAGIC_KERNEL_DIALECT
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
new file mode 100644
index 000000000..a9643ba9b
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
@@ -0,0 +1,284 @@
+//===------------------- MagicKernelOps.td --------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// The abstract layer between the MLIR core dialects and the lower,
+// target-specific dialects MagicKernelFunc and MagicKernelInstr.
+//
+// Compared to higher-level MLIR dialects such as memref, arith, affine, etc.,
+// the granularity of the MK dialect is better suited to mapping onto ML
+// accelerators. For example, tt.load is lowered to arith +
+// memref.reinterpret_cast + memref.alloc + memref.copy +
+// bufferization.to_tensor by decoding the hidden high-level info into
+// detailed info carried in those core MLIR dialects.
+// If we converted tt.load directly to mk.alloc + mk.load, we would have to
+// redo all the analysis and info construction that triton-shared already
+// does, so we instead generate mk.alloc + mk.load from the core dialects to
+// avoid reconstructing that information.
+// By doing so, we can lower arith + memref.reinterpret_cast + memref.copy +
+// buf.to_tensor into mk.load, and lower arith + memref.alloc into mk.alloc.
+//
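+// An illustrative sketch of the intended matching (types and SSA names are
+// hypothetical):
+//
+//   %view = memref.reinterpret_cast %base to offset: [%off], sizes: [128],
+//           strides: [1] : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>>
+//   %buf  = memref.alloc() : memref<128xf32>
+//   memref.copy %view, %buf
+//   %t    = bufferization.to_tensor %buf restrict : memref<128xf32>
+//
+// folds into a single mk.load from the reinterpreted view, while the
+// arith + memref.alloc pair folds into mk.alloc.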
+//===----------------------------------------------------------------------===//
+
+#ifndef MAGIC_KERNEL_OPS
+#define MAGIC_KERNEL_OPS
+
+include "magic-kernel/Dialect/IR/MagicKernelTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
+include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Bufferable type.
+//===----------------------------------------------------------------------===//
+
+def TensorOrMemref :
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
+
+//
+// Interfaces
+//
+def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
+
+
+class MKOp<string mnemonic, list<Trait> traits = []> :
+    Op<MagicKernelDialect, mnemonic, traits> {
+}
+
+class MKUnElemWiseOp<string mnemonic> : MKOp<mnemonic> {
+  let summary = "Element-wise unary operation: $mnemonic";
+
+  let arguments = (
+    ins
+    AnyTensor:$src,
+    BoolAttr:$is_atomic
+  );
+
+  let results = (outs AnyTensor:$dst);
+}
+
+class MKBinElemWiseOp<string mnemonic> : MKOp<mnemonic> {
+  let summary = "Element-wise binary operation: $mnemonic";
+
+  let arguments = (
+    ins
+    AnyTensor:$src0,
+    AnyTensor:$src1,
+    BoolAttr:$is_atomic
+  );
+
+  let results = (outs AnyTensor:$dst);
+}
+
+class MKTerElemWiseOp<string mnemonic> : MKOp<mnemonic> {
+  let summary = "Element-wise ternary operation: $mnemonic";
+
+  let arguments = (
+    ins
+    AnyTensor:$src0,
+    AnyTensor:$src1,
+    AnyTensor:$src2,
+    BoolAttr:$is_atomic
+  );
+
+  let results = (outs AnyTensor:$dst);
+}
+
+
+// =============================================================================
+// Memory allocation ops
+// =============================================================================
+
+def AllocOp : MKOp<"alloc", []> {
+  let summary = "Allocate consecutive memory from the given addressing space";
+
+  let description = [{
+    It may or may not generate a target intrinsic call or instruction; the
+    lowering from this operator to lower-level operators is target specific.
+  }];
+
+  let arguments = (
+    ins
+    I32Attr:$addr_space, // The addressing space
+    I64ArrayAttr:$dims   // The size of memory to be allocated
+  );
+
+  // Return the pointer of the allocated memory
+  let results = (outs AnyRankedOrUnrankedMemRef:$ptr);
+}
+
+// =============================================================================
+// Load/Store Ops
+// =============================================================================
+
+// Unit and strided memory load
+def LoadOp : MKOp<"load", []> {
+  let summary = "Load from memory with optional strides";
+
+  let description = [{ See RISC-V RVV unit/strided memory load for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space,   // The address space
+    I64ArrayAttr:$dims,    // The shape to be loaded, can be dynamic
+    I64ArrayAttr:$strides, // The strides in each rank, can be dynamic
+    BoolAttr:$mask         // element is not loaded if mask[i] == 0
+  );
+
+  let results = (outs AnyTensor:$result);
+
+  let assemblyFormat = [{
+    $ptr `,` attr-dict `:` type($ptr) `->` type($result)
+  }];
+}
+
+// Index memory load
+def IndexLoadOp : MKOp<"iload", []> {
+  let summary = "Load from memory with indexed offsets";
+
+  let description = [{ See RISC-V RVV index memory load for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space, // The address space
+    I64ArrayAttr:$dims,  // The shape to be loaded
+    AnyTensor:$index,    // The tensor containing a memory offset for each element
+    BoolAttr:$mask       // element is not loaded if mask[i] == 0
+  );
+
+  let results = (outs MKType:$result);
+}
+
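+// An illustrative use of the load form above (attribute values hypothetical):
+//
+//   %t = mk.load %view, {addr_space = 1 : i32, dims = [64, 64],
+//                        strides = [64, 1], mask = false}
+//        : memref<64x64xf32> -> tensor<64x64xf32>
+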
+// Unit and strided memory store
+def StoreOp : MKOp<"store", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Store to memory with optional strides";
+
+  let description = [{ See RISC-V RVV unit/strided memory store for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space,   // The address space
+    I64ArrayAttr:$dims,    // The shape to be stored
+    I64ArrayAttr:$strides, // The strides in each rank
+    BoolAttr:$mask         // element is not written to dest if mask[i] == 0
+  );
+
+  let assemblyFormat = [{
+    $ptr `,` attr-dict `:` type($ptr)
+  }];
+}
+
+// Index memory store
+def IndexStoreOp : MKOp<"istore", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Store to memory with indexed offsets";
+
+  let description = [{ See RISC-V RVV index memory store for detail }];
+
+  let arguments = (
+    ins
+    AnyRankedOrUnrankedMemRef:$ptr, // The base ptr
+    I32Attr:$addr_space, // The address space
+    I64ArrayAttr:$dims,  // The shape to be stored
+    AnyTensor:$index,    // The tensor containing a memory offset for each element
+    BoolAttr:$mask       // element is not written to dest if mask[i] == 0
+  );
+}
+
+// =============================================================================
+// Dot op
+// =============================================================================
+
+def DotOp : MKOp<"dot", [DestinationStyleOpInterface]> {
+  let summary = "Inner product of two vectors";
+
+  let description = [{
+    TODO: It is currently a one-to-one mapping from the upper-dialect tt.dot.
+  }];
+
+  let arguments = (
+    ins
+    TensorOrMemref:$a,           // Matrix A
+    TensorOrMemref:$b,           // Matrix B
+    Optional<TensorOrMemref>:$c, // Optional accumulation matrix C
+    // Zeroes buffer which can be used to fill $d
+    // FIXME: Do we need to add a side effect to the source operands?
+    Arg<TensorOrMemref, "", [MemWrite]>:$zeroes
+    //DefaultValuedAttr:$inputPrecision,
+    // DefaultValuedAttr:$maxNumImpreciseAcc
+  );
+
+  let results = (outs Variadic<TensorOrMemref>:$d);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() {
+      return getZeroesMutable();
+    }
+  }];
+
+  // let hasVerifier = 1;
+}
+
+// =============================================================================
+// Reduction ops
+// =============================================================================
+
+def ArgMaxOp : MKOp<"argmax", [Pure]> {}
+def ArgMinOp : MKOp<"argmin", [Pure]> {}
+def ReduceMaxOp : MKOp<"reduce_max", [Pure]> {}
+def ReduceMinOp : MKOp<"reduce_min", [Pure]> {}
+def ReduceOp : MKOp<"reduce", [Pure]> {}
+def SumOp : MKOp<"sum", [Pure]> {}
+def XorSumOp : MKOp<"xor_sum", [Pure]> {}
+
+
+// =============================================================================
+// Scan/Sort Ops
+// =============================================================================
+
+def SortOp : MKOp<"sort", [Pure]> {}
+def GatherOp : MKOp<"gather", [Pure]> {}
+
+
+// =============================================================================
+// Unary/Binary/Ternary Element-wise Math Ops
+// =============================================================================
+
+def AbsOp : MKUnElemWiseOp<"abs">;
+def AddOp : MKBinElemWiseOp<"add">;
+def AndOp : MKBinElemWiseOp<"and">;
+def CDivOp : MKBinElemWiseOp<"cdiv">;
+def CeilOp : MKUnElemWiseOp<"ceil">;
+def ClampOp : MKUnElemWiseOp<"clamp">;
+def CosOp : MKUnElemWiseOp<"cos">;
+def DivOp : MKBinElemWiseOp<"div">;
+def ErfOp : MKUnElemWiseOp<"erf">;
+def ExpOp : MKUnElemWiseOp<"exp">;
+def Exp2Op : MKUnElemWiseOp<"exp2">;
+def FdivOp : MKBinElemWiseOp<"fdiv">;
+def FloorOp : MKUnElemWiseOp<"floor">;
+def FmaOp : MKTerElemWiseOp<"fma">;
+def LogOp : MKUnElemWiseOp<"log">;
+def Log2Op : MKUnElemWiseOp<"log2">;
+def MaxOp : MKUnElemWiseOp<"max">;
+def MinOp : MKUnElemWiseOp<"min">;
+def OrOp : MKBinElemWiseOp<"or">;
+def RsqrtOp : MKUnElemWiseOp<"rsqrt">;
+def SigmoidOp : MKUnElemWiseOp<"sigmoid">;
+def SinOp : MKUnElemWiseOp<"sin">;
+def SqrtOp : MKUnElemWiseOp<"sqrt">;
+def SqrtRnOp : MKUnElemWiseOp<"sqrt_rn">;
+def XorOp : MKBinElemWiseOp<"xor">;
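+// An illustrative use of a binary element-wise op, shown in MLIR's generic
+// form since no custom assembly is defined (names hypothetical):
+//
+//   %d = "mk.add"(%x, %y) {is_atomic = false}
+//        : (tensor<128xf32>, tensor<128xf32>) -> tensor<128xf32>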
+// def UmulhiOp : MKOp<"umulhi", [Pure]> {}
+
+
+#endif // MAGIC_KERNEL_OPS
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td
new file mode 100644
index 000000000..7c87e5cf2
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td
@@ -0,0 +1,102 @@
+//===------------------- MagicKernelTypes.td ------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MAGIC_KERNEL_TYPES_TD
+#define MAGIC_KERNEL_TYPES_TD
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "magic-kernel/Dialect/IR/MagicKernelDialect.td"
+
+//
+// Types
+//
+class MKTypeDef<string name, string _mnemonic, list<Trait> traits = []>
+    : TypeDef<MagicKernelDialect, name, traits> {
+  // Used by printer/parser
+  let mnemonic = _mnemonic;
+}
+
+// Floating-point Type
+def MKFloat : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
+def MKFloatTensor : RankedTensorOf<[MKFloat]>;
+def MKFloatLike : AnyTypeOf<[MKFloat, MKFloatTensor]>;
+
+// Boolean Type
+// TT_Bool -> I1
+def MKBoolTensor : RankedTensorOf<[I1]>;
+def MKBoolLike : AnyTypeOf<[I1, MKBoolTensor]>;
+
+// Integer Type
+def I4 : I<4>;
+def MKInt : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">;
+def MKIntTensor : RankedTensorOf<[MKInt]>;
+def MKIntLike : AnyTypeOf<[MKInt, MKIntTensor]>;
+
+// I32 Type
+// MKI32 -> I32
+// MKI32Tensor -> I32Tensor
+def MKI32Like : AnyTypeOf<[I32, I32Tensor]>;
+
+// I64 Type
+// MKI64 -> I64
+// MKI64Tensor -> I64Tensor
+def MKI64Like : AnyTypeOf<[I64, I64Tensor]>;
+
+// Pointer Type in TableGen
+class MKPtrOf<list<Type> pointeeTypes> :
+    DialectType<MagicKernelDialect,
+        And<[CPred<"isa<::mlir::triton::PointerType>($_self)">,
+             Concat<"[](::mlir::Type pointeeType) { return ",
+                    SubstLeaves<"$_self", "pointeeType",
+                                AnyTypeOf<pointeeTypes>.predicate>,
+                    "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>,
+        "ptr", "::mlir::triton::PointerType">;
+
+// Pointer Type in C++ (corresponding to `MKPtrOf`)
+def MKPtrType : MKTypeDef<"Pointer", "ptr"> {
+  let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system";
+
+  let description = [{
+    Pointer type in Triton IR type system, which could be pointing to scalars or tensors.
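+
+    For example (illustrative): `!tt.ptr<f32>` points to a single f32 scalar,
+    while `!tt.ptr<tensor<128x64xf32>>` points to a tensor block, as produced
+    by tt.make_tensor_ptr.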
+  }];
+
+  let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "Type":$pointeeType,
+      "int":$addressSpace
+    ), [{
+      return $_get(pointeeType.getContext(), pointeeType, addressSpace);
+    }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+
+  let skipDefaultBuilders = 1;
+}
+
+// Scalar Pointer Type: `ptr<>`
+def MKPtr : MKPtrOf<[AnyType]>;
+
+// Tensor of Pointer Type: `tensor<ptr<>>`
+def MKPtrTensor : RankedTensorOf<[MKPtr]>;
+
+// Tensor of Pointer Type or Pointer Type: `tensor<ptr<>>` or `ptr<>`
+def MKPtrLike : AnyTypeOf<[MKPtr, MKPtrTensor]>;
+
+// Tensor Type
+def MKFpIntTensor : RankedTensorOf<[MKFloat, MKInt]>;
+def MKTensor : RankedTensorOf<[MKFloat, MKInt, MKPtr]>;
+
+// Pointer Type to Tensor Type: `ptr<tensor<>>`
+def MKTensorPtr : MKPtrOf<[MKTensor]>;
+
+// Any Type in Magic Kernel IR
+def MKType : AnyTypeOf<[MKFloatLike, MKIntLike, MKPtrLike, MKTensorPtr]>;
+
+#endif // MAGIC_KERNEL_TYPES_TD
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/magic-kernel/Transforms/BufferizableOpInterfaceImpl.h b/third_party/tsingmicro/include/magic-kernel/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 000000000..789a0c8a8
--- /dev/null
+++ b/third_party/tsingmicro/include/magic-kernel/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,26 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the implementation of the BufferizableOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _MK_DIALECT_BUFFERIZABLEOPINTERFACEIMPL_H
+#define _MK_DIALECT_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace mk {
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace mk
+} // namespace mlir
+
+#endif // _MK_DIALECT_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h b/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h
new file mode 100644
index 000000000..6d310d93d
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h
@@ -0,0 +1,163 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_ANALYSIS_MASKANALYSIS_H
+#define TRITON_ANALYSIS_MASKANALYSIS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+#include "mlir/Support/LogicalResult.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "llvm/Support/LogicalResult.h"
+
+#include <utility>
+
+namespace mlir {
+
+class OpBuilder;
+
+namespace triton {
+// Data structure used to decode the pattern in a mask used for load and store.
+// The start and end fields represent the start and end index of a range
+// (produced by make_range, addi, etc.).
+// While multi-dimensional data is possible, we assume range comparison can
+// only be done on one dimension at a time (and results of range comparisons
+// across dimensions can be combined), hence start and end are not vectors.
+// dims represents the real access size for ld/st (instead of the
+// tensor/memref size specified by the IR). scalar is a shortcut used when the
+// entire state contains a single scalar value.
+//
+// The general lifetime of this data structure is roughly:
+// 1. A range is created by make_range and optionally operated on by addi w/
+// result of splat, expand_dims, etc. During this phase, either (1) both start
+// and end are populated, or (2) scalar is populated. Only one of the
+// dimensions (that contains the range) can have dim > 1.
+// 2. Result from step 1 is compared with another MaskState that represents a
+// scalar value. The resulting state only has dims populated.
+// 3. Optionally, result from step 2 can be broadcasted and anded with other
+// results from step 2. The resulting state only has dims populated.
+//
+// Example of creating a 2D mask:
+// mask = (rows[:, None] < M) & (cols[None, :] < N)
+struct MaskState {
+  OpFoldResult start;
+  OpFoldResult end;
+  SmallVector<OpFoldResult> dims;
+  OpFoldResult scalar;
+  const bool useUnsafeMask;
+
+  void dump() const;
+
+  MaskState(bool useUnsafeMask = false) : useUnsafeMask(useUnsafeMask) {}
+
+  int64_t getRank() const { return dims.size(); }
+
+  bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; }
+
+  bool isMask() const { return !start && !end && !scalar && dims.size() != 0; }
+
+  // Recursively parse a Value; call the corresponding function based on the
+  // defining operation and Value type
+  LogicalResult parse(Value operand, const Location loc, OpBuilder &builder);
+
+  tensor::ExtractSliceOp getExtractSlice(Value source, const Location loc,
+                                         OpBuilder &builder) const;
+
+  memref::SubViewOp getSubview(Value source, const Location loc,
+                               OpBuilder &builder) const;
+
+  std::pair<memref::SubViewOp, memref::SubViewOp>
+  getSideBySideSubviews(Value block1, Value block2, const Location loc,
+                        OpBuilder &builder) const;
+
+  std::pair<memref::SubViewOp, memref::SubViewOp>
+  getStackedSubviews(Value block1, Value block2, const Location loc,
+                     OpBuilder &builder) const;
+
+private:
+  // -------
+  // Utility functions to operate on MaskState
+  // -------
+  LogicalResult addStateScalar(const MaskState &state,
+                               const OpFoldResult scalar, Location loc,
+                               OpBuilder &builder);
+
+  LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState,
+                          Location loc, OpBuilder &builder);
+
+  LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState,
+                          Location loc, OpBuilder &builder);
+  // -------
+  // Helper functions to parse values to populate MaskState
+  // -------
+
+  LogicalResult parseExtSI(arith::ExtSIOp op, const Location loc,
+                           OpBuilder &builder);
+
+  // Operand is the result of a constant
+  // Get the value of the constant and assign it to scalar.
+  LogicalResult parseConstant(arith::ConstantOp constOp, const Location loc,
+                              OpBuilder &builder);
+
+  // Operand is an integer scalar
+  LogicalResult parseIntScalar(Value scalar, const Location loc,
+                               OpBuilder &builder);
+
+  // Operand is the result of addi
+  // One and only one of the operands should be a scalar. Increment both start
+  // and end, dims remains unchanged, and scalar is empty.
+  LogicalResult parseAdd(arith::AddIOp addOp, const Location loc,
+                         OpBuilder &builder);
+  // Operand is the result of andi
+  // Each of the result state dims is the smaller of the two operands' dims.
+  // Insert instructions if needed to get the new dims.
+  LogicalResult parseAnd(arith::AndIOp andOp, const Location loc,
+                         OpBuilder &builder);
+
+  // Operand is the result of cmpi
+  // Assume only one of the dimensions has size > 1. Only slt/ult, and sge
+  // against 0, are supported for now. For that dimension, we have three cases:
+  // 1. Constant comparison with both left and right-hand sides being scalars.
+  //    Calculate this new dim as a compare and select.
+  //    I.e. dim = lhs < rhs ? end : 0
+  // 2. Left-hand side is not a scalar, and the right-hand side is.
+  //  2.a. Predicate is slt/ult. Calculate this new dim as:
+  //       dim = max(min(end, value), start) - start
+  //  2.b. Predicate is sge against 0. Mask analysis already has an
+  //       assumption that the mask starts at 0, so evaluate this to true
+  //       and calculate this new dim as: dim = end
+  LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc,
+                         OpBuilder &builder);
+  // Operand is the result of make_range
+  // Set start and end accordingly; step size must be 1.
+  LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc,
+                               OpBuilder &builder);
+  // Operand is the result of broadcast
+  // Change dims only; assume this only applies to tensors.
+  LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp,
+                               const Location loc, OpBuilder &builder);
+  // Operand is the result of splat
+  // Assume this only applies to scalars. start and end are left empty; scalar
+  // will be assigned, and dims will be updated.
+  LogicalResult parseSplat(triton::SplatOp splatOp, const Location loc,
+                           OpBuilder &builder);
+  // Operand is the result of expand_dims
+  // Insert additional dims; start and end do not change and correspond to the
+  // dimension that contains the range.
+  LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp,
+                                const Location loc, OpBuilder &builder);
+
+  LogicalResult parseLoopIterArg(Value v, const Location loc,
+                                 OpBuilder &builder);
+};
+
+} // namespace triton
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h b/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h
new file mode 100644
index 000000000..2e32b7894
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h
@@ -0,0 +1,67 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H
+#define TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/OpDefinition.h"
+
+#include <optional>
+
+namespace mlir {
+
+class OpBuilder;
+
+// Return the integer if ofr is an IntegerAttr. Note that this function
+// differs from getConstantIntValue, which also returns an integer if ofr is
+// the constant result of an operation.
+std::optional<int64_t> getIntAttr(const OpFoldResult ofr);
+
+// Return whether ofr contains a constant zero, either represented by an
+// integer attribute or a constant value.
+bool hasConstZero(const OpFoldResult ofr);
+
+// Create a value of index type if necessary from an OpFoldResult.
+Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b);
+
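+// For example (illustrative): getIntAttr returns 42 for an OpFoldResult that
+// holds IntegerAttr 42, but std::nullopt when the OpFoldResult holds the
+// Value produced by an arith.constant 42; getConstantIntValue also handles
+// the latter case.
+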
+// Create a vector of values of index type if necessary from an array of
+// OpFoldResults.
+SmallVector<Value> ofrsToIndexValues(ArrayRef<OpFoldResult> ofrs,
+                                     const Location loc, OpBuilder &b);
+
+// Process addition of two OFRs. If both OFRs are Integer Attributes, the
+// result is an Integer Attribute. Otherwise, insert the arith.addi
+// instruction if needed and use its result Value.
+OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+// Produce result = lhs - rhs. If both OFRs are Integer Attributes, the
+// result is an Integer Attribute. Otherwise, insert the arith.subi
+// instruction if needed and use its result Value.
+OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+// Process multiplication of two OFRs. If both OFRs are Integer Attributes,
+// the result is an Integer Attribute. Otherwise, insert the arith.muli
+// instruction if needed and use its result Value.
+OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
+                         const Location loc, OpBuilder &b);
+
+OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
+OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                         const arith::CmpIPredicate pred,
+                         const OpFoldResult trueVal,
+                         const OpFoldResult falseVal, const Location loc,
+                         OpBuilder &b);
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/PtrAnalysis.h b/third_party/tsingmicro/include/triton-shared/Analysis/PtrAnalysis.h
new file mode 100644
index 000000000..5a95ebda9
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Analysis/PtrAnalysis.h
@@ -0,0 +1,271 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_ANALYSIS_PTRANALYSIS_H
+#define TRITON_ANALYSIS_PTRANALYSIS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include <optional>
+#include <set>
+
+namespace mlir {
+
+class ConversionPatternRewriter;
+
+namespace triton {
+
+struct ModuloState {
+  Value size;
+
+  // offset is used to determine the wraparound point for patterns like:
+  //   offset + (tl.arange(0, 256) % 12)
+  // The current code assumes that the modulo operator always runs last, e.g.:
+  //   (offset + tl.arange(0, 256)) % 12
+  // This is not used at the moment as there haven't been enough use cases and
+  // the implementation is quite complex.
+  // OpFoldResult offset;
+
+  static constexpr char const *WraparoundAttr = "ptr.wraparound_type";
+  static constexpr char const *WraparoundStacked = "stacked";
+  static constexpr char const *WraparoundSideBySide = "side_by_side";
+};
+
+// Data structure used to decode pointer arithmetic and potentially translate
+// it into memref. offsets, sizes, and strides are in units of elements in a
+// linearly laid-out memory, which is the same as pointer arithmetic
+// operations in the Triton language. scalar is a shortcut used when the
+// entire state describes a single scalar value. source is the base pointer.
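+//
+// For example (illustrative): the Triton expression
+//   ptr + 6 + tl.arange(0, 128) * 2
+// decodes to source = ptr, offsets = [6], sizes = [128], strides = [2].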
+class PtrState {
+
+  OpFoldResult
+  accumulateTargetOffset(Location loc,
+                         ConversionPatternRewriter &rewriter) const;
+
+public:
+  SmallVector<OpFoldResult> offsets;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+
+  SmallVector<std::optional<ModuloState>> modulos;
+
+  Value source;
+  Value scalar;
+
+  int64_t getRank() const;
+
+  bool isEmpty() const;
+
+  bool hasModulo() const;
+
+  MemRefType getResultMemrefType(MLIRContext *context, int64_t offset,
+                                 ArrayRef<int64_t> resultShape,
+                                 bool useDynamicStrides = false) const;
+
+  // Process addition of two PtrStates.
+  void addState(const PtrState &lhsState, const PtrState &rhsState,
+                Location loc, ConversionPatternRewriter &rewriter);
+
+  // Process multiplication of two PtrStates
+  void mulState(const PtrState &lhsState, const PtrState &rhsState,
+                const Location loc, ConversionPatternRewriter &rewriter);
+
+  // Produce a reinterpret cast based on the current PtrState. Additional
+  // instructions may be inserted in calculating the final offset.
+  memref::ReinterpretCastOp
+  createCastOp(ArrayRef<int64_t> resultShape, const Location loc,
+               ConversionPatternRewriter &rewriter) const;
+
+  SmallVector<memref::ReinterpretCastOp>
+  createSideBySideCastOps(ArrayRef<int64_t> resultShape, const Location loc,
+                          ConversionPatternRewriter &rewriter) const;
+
+  SmallVector<memref::ReinterpretCastOp>
+  createStackedCastOps(ArrayRef<int64_t> resultShape, const Location loc,
+                       ConversionPatternRewriter &rewriter) const;
+};
+
+class PtrAnalysis {
+public:
+  using IndexMapSet = std::map<int, std::set<int>>;
+
+  // Recursively parse a Value; call the corresponding
+  // function based on the defining operation and argument type.
+  static void
+  visitOperand(Value operand, PtrState &state, const Location loc,
+               ConversionPatternRewriter &rewriter,
+               const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of arith.addi. Process both arguments and insert any
+  // arith.addi instruction as needed.
+  // Main assumptions:
+  //   Only one of lhsState and rhsState has the source field set
+  //   Current PtrState should be empty
+  // Expected result:
+  //   source = lhsState.source ? lhsState.source : rhsState.source
+  //   sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i])
+  //   offsets[i] = lhsState.offsets[i] + rhsState.offsets[i]
+  //   strides[i] = lhsState.strides[i] + rhsState.strides[i]
+  static void
+  visitOperandAdd(arith::AddIOp addOp, PtrState &state, const Location loc,
+                  ConversionPatternRewriter &rewriter,
+                  const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of arith.muli. Process both arguments and insert any
+  // arith.muli instruction as needed.
+  // Main assumptions:
+  //   Neither lhsState nor rhsState has the source field set
+  //   Current PtrState should be empty
+  //   Currently only supports the case where one of the operands is a scalar
+  //   index
+  // Expected result (scalar and tensorState represent the two operands):
+  //   source = null
+  //   sizes[i] = tensorState.sizes[i]
+  //   offsets[i] = tensorState.offsets[i] * scalar
+  //   strides[i] = tensorState.strides[i] * scalar
+  static void
+  visitOperandMul(arith::MulIOp mulOp, PtrState &state, const Location loc,
+                  ConversionPatternRewriter &rewriter,
+                  const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  static void
+  visitOperandRem(arith::RemSIOp mulOp, PtrState &state, const Location loc,
+                  ConversionPatternRewriter &rewriter,
+                  const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  static void visitOperandUnrealizedCast(
+      UnrealizedConversionCastOp op, PtrState &state, const Location loc,
+      ConversionPatternRewriter &rewriter,
+      const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of make_range.
+  // Main assumptions:
+  //   start, end, and shape are all statically known
+  //   The output of make_range is 1-dimensional
+  //   Does not check validity of inputs (e.g., stride > 0)
+  // Expected result:
+  //   source = null
+  //   sizes[0] = shape[0]
+  //   offset[0] = start
+  //   strides[0] = ceiling( (end - start) / shape[0] )
+  static void
+  visitOperandMakeRange(triton::MakeRangeOp rangeOp, PtrState &state,
+                        Location loc, ConversionPatternRewriter &rewriter,
+                        const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of expand_dims
+  // Main assumptions:
+  //   Only 1 dimension changes for each invocation of reshape
+  //   The changed dimension must have size of 1
+  // Expected result:
+  //   Insert a dimension of size 1, stride 0, and offset 0
+  static void
+  visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, PtrState &state,
+                         const Location loc,
+                         ConversionPatternRewriter &rewriter,
+                         const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of broadcast
+  // Main assumptions:
+  //   Rank of source and result is the same
+  // Expected result:
+  //   Update sizes[i] only, no changes to other fields
+  static void
+  visitOperandBroadcast(triton::BroadcastOp broadcastOp, PtrState &state,
+                        const Location loc, ConversionPatternRewriter &rewriter,
+                        const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of splat
+  // Main assumptions:
+  //   Source is a scalar value (i.e., an integer or a pointer, not a tensor)
+  // Expected result:
+  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0
+  //   if source is an integer, offset[0] = scalar = source
+  static void
+  visitOperandSplat(triton::SplatOp splatOp, PtrState &state,
+                    const Location loc, ConversionPatternRewriter &rewriter,
+                    const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of arith.constant that is a splat
+  // Main assumptions:
+  //   Source is a constant op that produces a constant dense tensor where all
+  //   elements are the same (i.e.: a constant that is splatted)
+  // Expected result:
+  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] =
+  //   splat value if i == 0, otherwise 0
+  static void
+  visitOperandConstSplat(arith::ConstantOp op, PtrState &state,
+                         const Location loc,
+                         ConversionPatternRewriter &rewriter,
+                         const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  static void visitOperandMakeTensorPtr(
+      triton::MakeTensorPtrOp makeTensorPtrOp, PtrState &state,
+      const Location loc, ConversionPatternRewriter &rewriter,
+      const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of addptr.
+  // Main assumptions:
+  //   The ptr field should populate the source field
+  //   ptr and offset fields should result in the same rank
+  // Expected result:
+  //   The resulting states for ptr and offset will be added
+  static void
+  visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state,
+                     const Location loc, ConversionPatternRewriter &rewriter,
+                     const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of reinterpret_cast.
+  // Main assumptions:
+  //   None
+  // Expected result:
+  //   Directly grab all corresponding fields from reinterpret_cast.
+  static void
+  visitOperandReintCast(memref::ReinterpretCastOp reintCastOp, PtrState &state,
+                        const Location loc, ConversionPatternRewriter &rewriter,
+                        const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Operand is the result of tt.advance.
+  // Main assumptions:
+  //   The source of the tt.advance has been mapped to a reinterpret_cast
+  // Expected result:
+  //   Directly grab all corresponding fields from reinterpret_cast.
+  //   Add the offsets multiplied by the strides to the final offsets.
+  static void rewriteAdvanceOp(triton::AdvanceOp op,
+                               ConversionPatternRewriter &rewriter,
+                               llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Parse the state of AddPtrOp, insert any instruction needed to
+  // calculate strides and offsets, build PtrState for this operand, and record
+  // PtrState in knownPtrs.
+  static void rewriteAddptrOp(triton::AddPtrOp op,
+                              ConversionPatternRewriter &rewriter,
+                              llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  // Parse the state of YieldOp, insert any instruction needed to calculate
+  // strides and offsets, build PtrState for this operand, and record PtrState
+  // in knownPtrs.
+  static void
+  rewriteYieldOp(scf::YieldOp op, ConversionPatternRewriter &rewriter,
+                 const IndexMapSet &levelToBlockArgIndex, const int level,
+                 const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  static void rewriteForOp(scf::ForOp op, ConversionPatternRewriter &rewriter,
+                           IndexMapSet &levelToBlockArgIndex, const int level,
+                           llvm::SmallDenseMap<Value, PtrState> &knownPtrs);
+
+  static Value getScalarMemRef(Value ptr, Value memRef, const Location loc,
+                               ConversionPatternRewriter &rewriter);
+};
+
+} // namespace triton
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/UseAnalysis.h b/third_party/tsingmicro/include/triton-shared/Analysis/UseAnalysis.h
new file mode 100644
index 000000000..39c3055a5
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Analysis/UseAnalysis.h
@@ -0,0 +1,119 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_ANALYSIS_USEANALYSIS_H
+#define TRITON_ANALYSIS_USEANALYSIS_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Pass/Pass.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+std::unique_ptr<Pass> createTritonUseAnalysisPass();
+
+enum class UseType {
+  Undefined, // Initial state
+  DataUse,   // value used for tensor computation only
+  MetaUse,   // value used for metadata only
+  MixUse     // value used for both tensor computation and metadata
+};
+
+struct UseInfo : public dataflow::AbstractSparseLattice {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UseInfo)
+  using AbstractSparseLattice::AbstractSparseLattice;
+
+  // Lattice state transfer function
+  ChangeResult meetUseType(const UseType &other) {
+    if (other == UseType::Undefined)
+      return ChangeResult::NoChange;
+
+    switch (type) {
+    case UseType::Undefined:
+      type = other;
+      return ChangeResult::Change;
+    case UseType::DataUse:
+    case UseType::MetaUse:
+      if (type == other) {
+        return ChangeResult::NoChange;
+      } else {
+        type = UseType::MixUse;
+        return ChangeResult::Change;
+      }
+    case UseType::MixUse:
+      return ChangeResult::NoChange;
+    default:
+      llvm_unreachable("bad type");
+    }
+  }
+
+  ChangeResult meet(const AbstractSparseLattice &other) override {
+    auto rhs = reinterpret_cast<const UseInfo *>(&other);
+    return meetUseType(rhs->type);
+  }
+
+  void print(raw_ostream &os) const override {
+    switch (type) {
+    case UseType::DataUse:
+      os << "DataUse";
+      break;
+    case UseType::MetaUse:
+      os << "MetaUse";
+      break;
+    case UseType::MixUse:
+      os << "MixUse";
+      break;
+    default:
+      os << "Undefined";
+    }
+  }
+
+  UseType type = UseType::Undefined;
+};
+
+class UseAnalysis
+    : public dataflow::SparseBackwardDataFlowAnalysis<UseInfo> {
+public:
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+  LogicalResult visitOperation(Operation *op, ArrayRef<UseInfo *> operands,
+                               ArrayRef<const UseInfo *> results) override;
+
+  void visitBranchOperand(OpOperand &operand) override { return; }
+
+  void visitCallOperand(OpOperand &operand) override { return; }
+
+  void setToExitState(UseInfo *lattice) override {
+    lattice->type = UseType::Undefined;
+  }
+
+private:
+  void propagateUse(UseInfo *lattice, const UseType &type) {
+    auto changed = lattice->meetUseType(type);
+    propagateIfChanged(lattice, changed);
+  }
+
+  void propagateResults(UseInfo *lattice, ArrayRef<const UseInfo *> results) {
+    auto changed = ChangeResult::NoChange;
+    for (auto result : results)
+      changed |= lattice->meet(*result);
+    propagateIfChanged(lattice, changed);
+  }
+};
+
+// Use SparseBackwardDataFlowAnalysis to identify operations whose results are
+// used as data tensor operations, meta operations (address calculation,
+// broadcasting/splatting constants, etc.), or both. For operations used for
+// both purposes, clone them so that the remaining pass built on
+// ConversionPatternRewriter can cleanly replace all tensor producers and
+// simply delete metadata producers.
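+//
+// For example (illustrative): a tt.splat whose result feeds both a tt.addptr
+// (address calculation, i.e. MetaUse) and an arith.addf (tensor computation,
+// i.e. DataUse) meets to MixUse, so it is cloned: one copy serves the data
+// path and the other the metadata path.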
+LogicalResult runUseAnalysis(triton::FuncOp &funcOp);
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_ANALYSIS_USEANALYSIS_H
diff --git a/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h
new file mode 100644
index 000000000..e17cacd6a
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h
@@ -0,0 +1,274 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H
+#define TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include <cstddef>
+#include <string>
+
+namespace mlir {
+
+class OpBuilder;
+
+namespace tts {
+
+const extern std::string ptrAnalysisAttr;
+
+// Data structure used to decode pointer arithmetic. offsets, sizes, and
+// strides are in units of elements in a linearly laid-out memory, which is
+// the same as pointer arithmetic operations in the Triton language. scalar is
+// a shortcut used when the entire state describes a single scalar value.
+// source is the base pointer. If order is present, PtrState describes a block
+// pointer; otherwise it describes a non-block pointer. When it describes a
+// block pointer, the shape field means the same as the shape field of
+// tt.make_tensor_ptr; when it describes a non-block pointer, the shape field
+// indicates how the address wraps around (i.e., modulo); a constant 0
+// indicates no modulo for the dimension.
+struct PtrState {
+
+  SmallVector<OpFoldResult> offsets;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+  SmallVector<OpFoldResult> shape;
+  SmallVector<int32_t> order;
+
+  Value source;
+  Value scalar;
+
+  int32_t getRank() const;
+
+  bool isEmpty() const;
+
+  bool hasModulo() const;
+
+  bool dimHasModulo(uint32_t dim) const;
+
+  bool isBlockPtr() const;
+
+  void dump() const;
+
+  // Process addition of two PtrStates.
+  LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState,
+                         Operation *op, OpBuilder &builder);
+
+  // Process multiplication of two PtrStates
+  LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState,
+                         Operation *op, OpBuilder &builder);
+
+  tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder,
+                                                Location loc);
+};
+
+class PtrAnalysis {
+  // This function is internally used by getLoopIterArgPtrState and
+  // getLoopResultPtrState to get the correct PtrState for either an iter-arg
+  // or a loop's result.
+  //
+  // The PtrState of an scf.for's iter-arg is the same as its corresponding
+  // init-arg, except that the strides and offsets have to point to the loop's
+  // iter-args that were created to carry the offsets and strides.
+  //
+  // For instance, for a pointer with index i and rank 2, 4 additional args
+  // starting at index i + 1 are created. The PtrState's strides and offsets
+  // values of the pointer's iter-arg must point to these 4 additionally
+  // created iter-args.
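+  //
+  // Illustrative layout (the exact ordering here is hypothetical): for a
+  // rank-2 pointer at init-arg index i, the extra iter-args could be
+  //   i + 1: offset[0],  i + 2: offset[1],
+  //   i + 3: stride[0],  i + 4: stride[1]
+  // and the iter-arg's PtrState points at exactly these values.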
+  //
+  // A similar process is used for getting the PtrState of the loop's i'th
+  // result: its strides and offsets have to point to the corresponding stride
+  // and offset values returned by the loop.
+  PtrState reconcileLoopPtrState(
+      scf::ForOp forOp, size_t ptrArgIndex, const PtrState &state,
+      llvm::function_ref<Value(scf::ForOp, size_t)> getReplacementVal);
+
+  DenseSet<Value> maybeStructuredArgs;
+
+public:
+  void initializeMaybeStructuredArgs(Operation *op);
+
+  llvm::SmallDenseMap<Value, PtrState> knownPtrs;
+
+  IRMapping ptrMap;
+
+  // Recursively parse a Value; call the corresponding
+  // function based on the defining operation and argument type.
+  LogicalResult visitOperand(Value operand, PtrState &state,
+                             const Location loc, OpBuilder &builder);
+
+  // Operand is a result of an scf.for. Such cases occur when there are
+  // multiple levels of nested loops where the results of the inner scf.for
+  // (pointer) are yielded by the outer loop.
+  LogicalResult visitOperandForOp(scf::ForOp forOp, Value operand,
+                                  PtrState &state, const Location loc,
+                                  OpBuilder &builder);
+
+  // Operand is the result of arith.addi. Process both arguments and insert any
+  // arith.addi instruction as needed.
+  // Main assumptions:
+  //   Only one of lhsState and rhsState has the source field set
+  //   Current PtrState should be empty
+  // Expected result:
+  //   source = lhsState.source ? lhsState.source : rhsState.source
+  //   sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i])
+  //   offsets[i] = lhsState.offsets[i] + rhsState.offsets[i]
+  //   strides[i] = lhsState.strides[i] + rhsState.strides[i]
+  LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrState &state,
+                                const Location loc, OpBuilder &builder);
+
+  // Operand is the result of arith.muli. Process both arguments and insert any
+  // arith.muli instruction as needed.
+  // Main assumptions:
+  //   Neither lhsState nor rhsState has the source field set
+  //   Current PtrState should be empty
+  //   Currently only supports the case where one of the operands is a scalar
+  //   index
+  // Expected result (scalar and tensorState represent the two operands):
+  //   source = null
+  //   sizes[i] = tensorState.sizes[i]
+  //   offsets[i] = tensorState.offsets[i] * scalar
+  //   strides[i] = tensorState.strides[i] * scalar
+  LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrState &state,
+                                const Location loc, OpBuilder &builder);
+
+  LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrState &state,
+                                const Location loc, OpBuilder &builder);
+
+  // Operand is the result of make_range.
+  // Main assumptions:
+  //   start, end, and shape are all statically known
+  //   The output of make_range is 1-dimensional
+  //   Does not check validity of inputs (e.g., stride > 0)
+  // Expected result:
+  //   source = null
+  //   sizes[0] = shape[0]
+  //   offset[0] = start
+  //   strides[0] = ceiling( (end - start) / shape[0] )
+  LogicalResult visitOperandMakeRange(triton::MakeRangeOp rangeOp,
+                                      PtrState &state, Location loc,
+                                      OpBuilder &builder);
+
+  // Operand is the result of expand_dims
+  // Main assumptions:
+  //   Only 1 dimension changes for each invocation of reshape
+  //   The changed dimension must have size of 1
+  // Expected result:
+  //   Insert a dimension of size 1, stride 0, and offset 0
+  LogicalResult visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp,
+                                       PtrState &state, const Location loc,
+                                       OpBuilder &builder);
+
+  // Operand is the result of broadcast
+  // Main assumptions:
+  //   Rank of source and result is the same
+  // Expected result:
+  //   Update sizes[i] only, no changes to other fields
+  LogicalResult visitOperandBroadcast(triton::BroadcastOp broadcastOp,
+                                      PtrState &state, const Location loc,
+                                      OpBuilder &builder);
+
+  // Operand is the result of splat
+  // Main assumptions:
+  //   Source is a scalar value (i.e., an integer or a pointer, not a tensor)
+  // Expected result:
+  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0
+  //   if source is an integer, offset[0] = scalar = source
+  LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrState &state,
+                                  const Location loc, OpBuilder &builder);
+
+  // Operand is the result of arith.constant that is a splat
+  // Main assumptions:
+  //   Source is a constant op that produces a constant dense tensor where all
+  //   elements are the same (i.e.: a constant that is splatted)
+  // Expected result:
+  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] =
+  //   splat value if i == 0, otherwise 0
+  LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrState &state,
+                                       const Location loc,
+                                       OpBuilder &builder);
+
+  LogicalResult visitOperandExtSI(arith::ExtSIOp, PtrState &state,
+                                  const Location loc, OpBuilder &builder);
+
+  // Operand is the result of addptr.
+  // Main assumptions:
+  //   The ptr field should populate the source field
+  //   ptr and offset fields should result in the same rank
+  // Expected result:
+  //   The resulting states for ptr and offset will be added
+  LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state,
+                                   const Location loc, OpBuilder &builder);
+
+  // Operand is the result of tts.make_tptr.
+  // Main assumptions:
+  //   This function is only called when rewriting a loop
+  // Expected result:
+  //   Directly grab all corresponding fields from tts.make_tptr.
+  LogicalResult visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp,
+                                     PtrState &state, const Location loc,
+                                     OpBuilder &builder);
+
+  // Operand is the result of tt.make_tensor_ptr.
+  // Expected result:
+  //   Parse the source pointer and grab the results.
+  LogicalResult visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp,
+                                          PtrState &state, const Location loc,
+                                          OpBuilder &builder);
+
+  // Get the computed PtrState for the forOp's init-arg at the provided index.
+  FailureOr<PtrState> getLoopInitArgPtrState(scf::ForOp forOp, size_t index);
+
+  // Get the computed PtrState for the forOp's iter-arg at the provided index.
+  FailureOr<PtrState> getLoopIterArgPtrState(scf::ForOp forOp, size_t index);
+
+  // Get the computed PtrState for the forOp's result at the provided index.
+  FailureOr<PtrState> getLoopResultPtrState(scf::ForOp forOp, size_t index);
+
+  // After PtrAnalysis finishes, rewrite the GetStructuredStateOp by creating
+  // the correct initialization ops for offsets and strides and passing them to
+  // any loop's init-args.
+  LogicalResult rewriteGetStructuredStateOp(tts::GetStructuredStateOp op);
+
+  // Parse the state of AddPtrOp, insert any instruction needed to
+  // calculate strides and offsets, build PtrState for this operand, and record
+  // PtrState in knownPtrs.
+  LogicalResult rewriteAddptrOp(triton::AddPtrOp op);
+
+  LogicalResult rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op);
+
+  LogicalResult rewriteAdvanceOp(triton::AdvanceOp op);
+
+  // Parse the state of YieldOp, insert any instruction needed to calculate
+  // strides and offsets, build PtrState for this operand, and record PtrState
+  // in knownPtrs.
+  LogicalResult
+  rewriteYieldOp(scf::YieldOp op,
+                 llvm::SmallDenseMap<Value, PtrState> &knownPtrsFor);
+
+  // Rewrite eligible tt.addptr in loop init-args so the loop can update such
+  // pointers across iterations. Insert any instructions needed to calculate
+  // strides, offsets, and modulos.
+  LogicalResult rewriteForOp(scf::ForOp op);
+
+  LogicalResult rewriteLoadOp(triton::LoadOp op, bool useUnsafeMask = false);
+
+  LogicalResult rewriteStoreOp(triton::StoreOp op, bool useUnsafeMask = false);
+
+  LogicalResult rewriteOp(Operation *op, bool useUnsafeMask = false);
+};
+
+} // namespace tts
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/CMakeLists.txt
new file mode 100644
index 000000000..629c08af6
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Conversion)
+add_subdirectory(Dialect)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/CMakeLists.txt
new file mode 100644
index 000000000..60180abfb
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_subdirectory(TritonToLinalg)
+add_subdirectory(TritonToStructured)
+add_subdirectory(TritonArithToLinalg)
+add_subdirectory(StructuredToMemref)
+add_subdirectory(TritonToCoreDialects)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/CMakeLists.txt
new file mode 100644
index 000000000..83ff64d36
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name StructuredToMemref)
+add_public_tablegen_target(StructuredToMemrefConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.h
new file mode 100644
index 000000000..198675b12
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.h
@@ -0,0 +1,15 @@
+#ifndef TRITON_STRUCTURED_TO_MEMREF_CONVERSION_PASSES_H
+#define TRITON_STRUCTURED_TO_MEMREF_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/StructuredToMemref/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.td
new file mode 100644
index 000000000..0f2f08a6d
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/Passes.td
@@ -0,0 +1,10 @@
+#ifndef STRUCTURED_TO_MEMREF_CONVERSION_PASSES
+#define STRUCTURED_TO_MEMREF_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def StructuredToMemref : Pass<"structured-to-memref", "mlir::ModuleOp"> {
+  let summary = "Convert triton structured pointer ops to memref";
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h
new file mode 100644
index 000000000..8c67c9ec0
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h
@@ -0,0 +1,24 @@
+#ifndef TRITON_CONVERSION_STRUCTUREDTOMEMREF_STRUCTUREDTOMEMREF_H
+#define TRITON_CONVERSION_STRUCTUREDTOMEMREF_STRUCTUREDTOMEMREF_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+class TypeConverter;
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "triton-shared/Conversion/StructuredToMemref/Passes.h.inc"
+
+void populateStructuredToMemrefConversionPatterns(RewritePatternSet &patterns,
+                                                  TypeConverter &typeConverter);
+
+std::unique_ptr<OperationPass<ModuleOp>> createStructuredToMemrefPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_STRUCTUREDTOMEMREF_STRUCTUREDTOMEMREF_H
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/CMakeLists.txt
new file mode 100644
index 000000000..85076bd1c
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonArithToLinalg)
+add_public_tablegen_target(TritonArithToLinalgConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h
new file mode 100644
index 000000000..1d8d2696c
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h
@@ -0,0 +1,2119 @@
+#ifndef TRITON_CONVERSION_PATTERNS
+#define TRITON_CONVERSION_PATTERNS
+
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Analysis/MaskAnalysis.h"
+#include "triton-shared/Analysis/OpFoldResultUtils.h"
+#include "triton-shared/Analysis/PtrAnalysis.h"
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <functional>
+#include <numeric>
+#include <optional>
+
+using namespace mlir;
+using namespace triton;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+// Extract a scalar value from v.
+// If v is a scalar, return that directly. Otherwise, parse through operations
+// (currently only support splat, sitofp, and truncf) that produce it to
+// extract the underlying scalar value. We then reconstruct the chain of
+// operations that can produce this constant with the original type. If no
+// scalar value can be extracted, a nullptr is returned.
+static Value getScalarValue(Value operand, Location loc,
+                            ConversionPatternRewriter &rewriter) {
+  SmallVector<Operation *> ops;
+
+  auto reconstructScalarValue = [&](Value src) {
+    for (auto op = ops.rbegin(); op != ops.rend(); ++op) {
+      src = TypeSwitch<Operation *, Value>(*op)
+                .Case<arith::SIToFPOp>([&](Operation *op) {
+                  auto resType = op->getResults()[0].getType();
+                  if (auto shapedType = dyn_cast<ShapedType>(resType)) {
+                    resType = shapedType.getElementType();
+                  }
+                  return rewriter.create<arith::SIToFPOp>(loc, resType, src);
+                })
+                .Case<arith::TruncFOp>([&](Operation *op) {
+                  auto resType = op->getResults()[0].getType();
+                  if (auto shapedType = dyn_cast<ShapedType>(resType)) {
+                    resType = shapedType.getElementType();
+                  }
+                  return rewriter.create<arith::TruncFOp>(loc, resType, src);
+                })
+                .Default([](Operation *op) {
+                  llvm_unreachable("unsupported op in generating ");
+                  return nullptr;
+                });
+    }
+    return src;
+  };
+
+  while (true) {
+    if (!dyn_cast<ShapedType>(operand.getType())) {
+      return reconstructScalarValue(operand);
+    } else if (auto op = operand.getDefiningOp<arith::ConstantOp>()) {
+      if (auto attr = dyn_cast<DenseElementsAttr>(op.getValue())) {
+        if (!attr.isSplat()) {
+          InFlightDiagnostic diag = emitError(loc)
+                                    << "other value used in masked load "
+                                       "produced by unsupported instruction";
+          return nullptr;
+        }
+        auto elemValue = attr.getSplatValue<Attribute>();
+        auto constOp = arith::ConstantOp::materialize(
+            rewriter, elemValue, attr.getElementType(), op.getLoc());
+        return reconstructScalarValue(constOp.getResult());
+      }
+    } else if (auto op = operand.getDefiningOp<triton::SplatOp>()) {
+      operand = op.getSrc();
+    } else if (auto op = operand.getDefiningOp<arith::SIToFPOp>()) {
+      ops.push_back(op.getOperation());
+      operand = op.getIn();
+    } else if (auto op = operand.getDefiningOp<arith::TruncFOp>()) {
+      ops.push_back(op.getOperation());
+      operand = op.getIn();
+    } else {
+      InFlightDiagnostic diag = emitError(loc)
+                                << "other value used in masked load produced "
+                                   "by unsupported instruction";
+      return nullptr;
+    }
+  }
+  return nullptr;
+}
+
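+// For example (illustrative): given
+//   %i = arith.constant dense<42> : tensor<128xi32>
+//   %f = arith.sitofp %i : tensor<128xi32> to tensor<128xf32>
+// getScalarValue(%f, ...) records the sitofp, extracts the i32 splat value
+// 42, materializes it as a scalar arith.constant, and rebuilds the sitofp on
+// the scalar, returning an f32 value equal to 42.0.
+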
utils::IteratorType::parallel); +} + +static Value getTransposedValue(Value source, const Location loc, + ConversionPatternRewriter &rewriter) { + + auto sourceType = cast(source.getType()); + auto sourceRank = sourceType.getRank(); + + SmallVector perm(sourceRank); + std::iota(std::begin(perm), std::end(perm), 0); + std::swap(perm[sourceRank - 1], perm[sourceRank - 2]); + + SmallVector transposedShape(sourceType.getShape()); + std::swap(transposedShape[sourceRank - 1], transposedShape[sourceRank - 2]); + + Value transposeInit = rewriter.create( + loc, transposedShape, sourceType.getElementType()); + + Value transpose = + rewriter.create(loc, source, transposeInit, perm) + .getResults()[0]; + + return transpose; +} + +// for IntLike and FloatLike types +static std::optional getBitWidth(Type a) { + if (auto type = dyn_cast(a)) { + auto elementType = type.getElementType(); + if (elementType.isIntOrFloat()) { + return type.getElementType().getIntOrFloatBitWidth(); + } + return std::nullopt; + } + + if (a.isIntOrFloat()) + return a.getIntOrFloatBitWidth(); + + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// Op Lowering Patterns +//===----------------------------------------------------------------------===// + +namespace { + +//----------------------------- +// Begin of monolithic only +//----------------------------- +struct AdvanceConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap knownPtrs; + PtrState pointerState; + PtrAnalysis::rewriteAdvanceOp(op, rewriter, knownPtrs); + return success(); + } +}; + +struct MakeTensorPtrConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + void populateVectorAsIndex(SmallVector &vec, + Operation::operand_range ops, + ConversionPatternRewriter &rewriter, + Location loc) const { + for (auto opnd : ops) { + if (isa(opnd.getType())) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), opnd); + vec.push_back(castOp.getResult()); + } else { + assert(isa(opnd.getType())); + vec.push_back(opnd); + } + } + } + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + PtrState pointerState; + + auto orderSize = op.getOrder().size(); + if (orderSize > 1) { + for (auto [first, second] : + llvm::zip(op.getOrder().slice(0, orderSize - 2), + op.getOrder().slice(1, orderSize - 1))) { + assert(first == second + 1 && + "Currently only support default order on block pointers"); + } + } + + pointerState.source = rewriter.getRemappedValue(op.getBase()); + populateVectorAsIndex(pointerState.offsets, op.getOffsets(), rewriter, loc); + populateVectorAsIndex(pointerState.strides, op.getStrides(), rewriter, loc); + + SmallVector newOffsets; + for (auto [offset, stride] : + llvm::zip(pointerState.offsets, pointerState.strides)) { + auto mulOp = rewriter.create(loc, cast(offset), + cast(stride)); + newOffsets.push_back(mulOp.getResult()); + } + + pointerState.offsets.clear(); + + for (auto offset : newOffsets) { + pointerState.offsets.push_back(offset); + } + + ArrayRef resultShape; + auto pointerType = + cast(op.getResult().getType()); + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + for (auto dim_size : 
resultShape) { + pointerState.sizes.push_back( + IntegerAttr::get(IntegerType::get(op.getContext(), 64), dim_size)); + } + } else { + // scalar pointer, should produce a one dimensional memref + SmallVector scalarShape(1, 1); + resultShape = scalarShape; + assert(pointerState.getRank() == 1); + } + + auto castOp = pointerState.createCastOp(resultShape, loc, rewriter); + rewriter.replaceOp(op, castOp.getResult()); + return success(); + } +}; + +struct LegacyAddPtrConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap knownPtrs; + PtrAnalysis::rewriteAddptrOp(op, rewriter, knownPtrs); + return success(); + } +}; + +struct LoadConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + void createSideBySideCopies(Value block1, Value block2, Value dst, + Location loc, + ConversionPatternRewriter &rewriter) const { + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); + } + + void createStackedCopies(Value block1, Value block2, Value dst, Location loc, + ConversionPatternRewriter &rewriter) const { + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); + } + +public: + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ptr = adaptor.getPtr(); + auto mask = op.getMask(); + auto other = op.getOther(); + auto loc = op.getLoc(); + + // 0. Shortcut for scalar loads + if (!isa(op.getResult().getType())) { + auto sMemRef = PtrAnalysis::getScalarMemRef(op.getPtr(), adaptor.getPtr(), + loc, rewriter); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + auto loadOp = rewriter.create( + op.getLoc(), sMemRef, zeroMap, std::nullopt); + rewriter.replaceOp(op, loadOp.getResult()); + return success(); + } + + // 1. Simple case where no mask is used. 
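+    // As an illustrative sketch (not verbatim output), an unmasked load of a
+    // tensor<4x4xf32> block lowers to roughly:
+    //   %buf = memref.alloc() : memref<4x4xf32>
+    //   memref.copy %src, %buf : memref<4x4xf32> to memref<4x4xf32>
+    //   %t = bufferization.to_tensor %buf restrict writable : memref<4x4xf32>
+    // so the loaded block re-enters the tensor world for downstream patterns.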
+ auto type = dyn_cast(ptr.getType()); + if (!type) { + // Seen when implicit broadcasting is done late in a chain of operations. + // The workaround is to broadcast the pointers early in the address + // calculation. A proper fix is complicated, but at least we can provide a + // better error message. + return rewriter.notifyMatchFailure( + op, "LoadOp expects a memref, not a memref of pointers"); + } + + auto tensorType = + RankedTensorType::get(type.getShape(), type.getElementType()); + auto alloc = rewriter.create( + loc, MemRefType::get(type.getShape(), type.getElementType())); + + if (!mask) { + assert(!other && "other value used in non-masked load"); + if (auto unrealizedCast = + ptr.getDefiningOp()) { + if (auto wrapType = unrealizedCast->getAttrOfType( + ModuloState::WraparoundAttr)) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (wrapType.getValue() == ModuloState::WraparoundSideBySide) { + createSideBySideCopies(block1, block2, alloc, loc, rewriter); + } else if (wrapType.getValue() == ModuloState::WraparoundStacked) { + createStackedCopies(block1, block2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + } else { + llvm_unreachable("unexpected unrealized cast op"); + } + + } else { + rewriter.create(loc, ptr, alloc); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } + + // 2. Continuous masked loads. + // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) { + return rewriter.notifyMatchFailure( + op, "Cannot lower continuous masked loads"); + } + + // fill load destination with other value + if (other) { + auto scalarOther = getScalarValue(other, loc, rewriter); + assert(scalarOther && "other value used in masked load produced by " + "unsupported instruction"); + + // For each dimension check if mstate.dims[i] < shape[i], or-accumulate + // the result + auto shape = type.getShape(); + auto accBase = + rewriter.create(loc, rewriter.getBoolAttr(false)) + .getResult(); + for (size_t i = 0; i < type.getShape().size(); i++) { + auto shapei = rewriter.create( + loc, rewriter.getIndexAttr(shape[i])); + + Value dimi = dyn_cast(mstate.dims[i]); + if (!dimi) { + dimi = rewriter.create( + loc, cast(cast(mstate.dims[i]))); + } + + auto cmpOp = rewriter.create( + loc, arith::CmpIPredicate::slt, dimi, shapei); + accBase = rewriter.create(loc, accBase, cmpOp.getResult()) + .getResult(); + } + + // condition the memset on the or-accumulation + // initialize with padding prior to CopyOp + rewriter.create( + loc, accBase, [&](OpBuilder &builder, Location loc) { + builder.create(loc, ValueRange{scalarOther}, + ValueRange{alloc}); + builder.create(loc); + }); + } + + if (auto unrealizedCast = ptr.getDefiningOp()) { + if (auto wrapType = unrealizedCast->getAttrOfType( + ModuloState::WraparoundAttr)) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (wrapType.getValue() == ModuloState::WraparoundSideBySide) { + auto [subview1, subview2] = + mstate.getSideBySideSubviews(block1, block2, loc, rewriter); + + createSideBySideCopies(subview1, subview2, alloc, loc, rewriter); + } else if (wrapType.getValue() == ModuloState::WraparoundStacked) { + auto [subview1, subview2] = + 
mstate.getStackedSubviews(block1, block2, loc, rewriter); + + createStackedCopies(subview1, subview2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + + } else { + llvm_unreachable("unexpected unrealized cast op"); + } + + } else { + memref::SubViewOp srcSubview = mstate.getSubview(ptr, loc, rewriter); + memref::SubViewOp dstSubview = mstate.getSubview(alloc, loc, rewriter); + rewriter.create(loc, srcSubview, dstSubview); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } +}; + +struct StoreConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ptr = adaptor.getPtr(); + auto val = adaptor.getValue(); + auto mask = op.getMask(); + auto loc = op.getLoc(); + + // 0. Shortcut for scalar stores + if (!isa(val.getType())) { + auto sMemRef = + PtrAnalysis::getScalarMemRef(op.getPtr(), ptr, loc, rewriter); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + rewriter.create(loc, val, sMemRef, zeroMap, + std::nullopt); + rewriter.eraseOp(op); + return success(); + } + + // 1. Simple case where no mask is used. + if (!mask) { + auto storeOp = rewriter.create( + loc, val, ptr); + storeOp.setWritable(true); + rewriter.eraseOp(op); + return success(); + } + + // 2. Continuous masked stores. + // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) + return failure(); + + auto srcSlice = mstate.getExtractSlice(val, loc, rewriter); + auto dstSubview = mstate.getSubview(ptr, loc, rewriter); + + auto storeOp = rewriter.create( + loc, srcSlice, dstSubview); + storeOp.setWritable(true); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct LoopConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap knownPtrs; + PtrAnalysis::IndexMapSet + levelToBlockArgIndex; // level -> set of block arg index to be replaced + + PtrAnalysis::rewriteForOp(op, rewriter, levelToBlockArgIndex, 0, knownPtrs); + return success(); + } +}; + +struct YieldConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +// Remove all Meta ops except for AddPtr which is handled by AddPtrConverter. +// Use benefit == 10 to ensure that this pattern always takes precedence over +// other patterns. +struct MetaOpConverter : public RewritePattern { +private: + // UseAnalysis will tag operations whose results are used only as meta-data + // with "MetaUse" tag. 
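+  // For example, a tt.make_range whose result only feeds pointer arithmetic
+  // is tagged {MetaUse}; it is erased here because PtrAnalysis has already
+  // folded its contribution into the reconstructed memref accesses.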
+ bool isMetaUse(Operation *op) const { return op->hasAttr("MetaUse"); } + +public: + MetaOpConverter(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + + if (isa(op)) { + return rewriter.notifyMatchFailure(op, + "AddPtrOp will be handled separately"); + } + + if (isMetaUse(op)) { + rewriter.eraseOp(op); + return success(); + } + + return rewriter.notifyMatchFailure(op, "requires meta ops"); + } +}; + +struct UnrealizedCastConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +//----------------------------- +// End of monolithic only +//----------------------------- + +struct SplatConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto opType = cast(op.getType()); + auto loc = op.getLoc(); + + auto init = rewriter.create(loc, opType.getShape(), + opType.getElementType()); + + auto filledTensor = + rewriter + .create(loc, ValueRange{adaptor.getSrc()}, + ValueRange{init}) + .result(); + + rewriter.replaceOp(op, filledTensor); + return success(); + } +}; + +struct BroadcastConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + SmallVector getBroadcastDims(RankedTensorType src, + RankedTensorType dst) const { + SmallVector broadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); i++) { + if (dstShape[i] != srcShape[i]) { + assert(srcShape[i] == 1); + broadcastDims.push_back(i); + } + } + assert(!broadcastDims.empty() && "cannot identify broadcast dimension"); + return broadcastDims; + } + + // Broadcasts input tensor based on TosaToLinalg's broadcastToShape + AffineMap getBroadcastAffineMap(MLIRContext *context, + ArrayRef inputShape, + ArrayRef broadcastToShape) const { + + assert(broadcastToShape.size() >= inputShape.size()); + + // Create affine map and shapes for tensor initialization. 
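+    // Worked example: broadcasting from shape [1, 8] to [4, 8] yields the
+    // map (d0, d1) -> (0, d1); a size-1 input dimension always reads
+    // element 0, while matching dimensions pass through unchanged.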
+ SmallVector outExpr; + + size_t diff = broadcastToShape.size() - inputShape.size(); + for (size_t i = 0; i < broadcastToShape.size(); i++) { + if (i < diff) { + continue; + } + size_t j = i - diff; + if (inputShape[j] == 1) { + // Broadcast singleton dimension + outExpr.push_back(mlir::getAffineConstantExpr(0, context)); + continue; + } + // Non-broadcast case + outExpr.push_back(mlir::getAffineDimExpr(i, context)); + } + return AffineMap::get(broadcastToShape.size(), 0, outExpr, context); + } + +public: + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + assert(op->getNumResults() == 1 && "code assumes single result!"); + RankedTensorType sourceType = + cast(adaptor.getSrc().getType()); + RankedTensorType resultType = cast(op.getType()); + auto elementType = resultType.getElementType(); + size_t resultRank = resultType.getRank(); + + SmallVector indexingMaps; + indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); + + indexingMaps.push_back(getBroadcastAffineMap( + op->getContext(), sourceType.getShape(), resultType.getShape())); + indexingMaps.append(op->getNumResults(), + rewriter.getMultiDimIdentityMap(resultRank)); + + assert(op->getNumResults() == 1 && "code assumes single result!"); + auto init = rewriter.create(loc, resultType.getShape(), + elementType); + + auto linalgOp = rewriter.create( + loc, op->getResultTypes(), ValueRange{adaptor.getSrc()}, + ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(resultRank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value opResult = blockArgs[0]; + nestedBuilder.create(loc, opResult); + }); + + linalgOp->setAttr("broadcastDims", + rewriter.getDenseI64ArrayAttr( + getBroadcastDims(sourceType, resultType))); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } +}; + +struct ExpandDimsConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto src = adaptor.getSrc(); + auto srcRank = cast(src.getType()).getRank(); + auto resType = cast(op->getResultTypes()[0]); + SmallVector reassoc; + int64_t c = 0; + for (int64_t i = 0; i < srcRank; i++) { + ReassociationIndices g; + g.push_back(c++); + if (op.getAxis() == i) { + g.push_back(c++); + } else if (op.getAxis() == i + 1 && i == srcRank - 1) { + g.push_back(c++); + } + reassoc.push_back(g); + } + + auto expandShapeOp = rewriter.create( + op.getLoc(), resType, src, reassoc); + + rewriter.replaceOp(op, expandShapeOp.getResult()); + return success(); + } +}; + +struct TransposeConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto src = adaptor.getSrc(); + auto srcRank = cast(src.getType()).getRank(); + assert(srcRank == 2 && "only expect transposing 2D data"); + + auto res = getTransposedValue(src, op.getLoc(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct MakeRangeConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto type = 
cast(op.getResult().getType()); + auto shape = type.getShape(); + auto elementType = type.getElementType(); + auto context = rewriter.getContext(); + + assert(type.getShape().size() == 1 && + type.getElementType().getIntOrFloatBitWidth() == 32 && + "make range can only return 1D int32 tensor"); + + SmallVector indexingMaps{AffineMap::get( + /* dimCount */ 1, /* symbolCount */ 0, + SmallVector{mlir::getAffineDimExpr(0, context)}, context)}; + + auto init = rewriter.create(loc, shape, elementType); + auto linalgOp = rewriter.create( + loc, op->getResultTypes(), /* operands */ ValueRange{}, + ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(1), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value index = nestedBuilder.create(loc, 0); + Value res = nestedBuilder.create( + loc, type.getElementType(), index); + nestedBuilder.create(loc, res); + }); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } +}; + +struct AssertConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value condVal = op.getCondition(); + + if (isa(condVal.getType())) { + auto scalarVal = getScalarValue(op.getCondition(), op.getLoc(), rewriter); + condVal = scalarVal ? scalarVal : condVal; + } + assert(condVal && isa(condVal.getType()) && + "Only asserts on scalars are currently supported"); + + if (!condVal.getType().isInteger(1)) { + auto zero = + rewriter.create(op.getLoc(), 0, 32); + auto newCond = rewriter.create( + op.getLoc(), arith::CmpIPredicate::ne, condVal, zero); + condVal = newCond.getResult(); + } + + auto assertMessage = llvm::formatv("FIXME: assertion!"); + rewriter.create(op.getLoc(), condVal, + assertMessage.str()); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct BitcastConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto arithBitcast = rewriter.create( + op.getLoc(), op.getType(), op.getOperand()); + + rewriter.replaceOp(op, arithBitcast.getResult()); + return success(); + } +}; + +struct CallConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector args = adaptor.getOperands(); + + // We need to pass extra arguments added by addProgramInfo which are num_programs and program_ids + if (FuncOp parentFunc = op->getParentOfType()) { + SymbolRefAttr calleeAttr = op.getCalleeAttr(); + StringRef calleeName = calleeAttr.getRootReference(); + + if (ModuleOp module = op->getParentOfType()) { + if (FuncOp calleeFunc = module.lookupSymbol(calleeName)) { + size_t argsNeed = calleeFunc.getFunctionType().getInputs().size(); + Block &entryBlock = parentFunc.front(); + auto parentInputs = entryBlock.getArguments(); + size_t argsParent = parentInputs.size(); + + if (argsNeed > args.size()) { + int missing = argsNeed - args.size(); + for (int i = 0; i < missing; i++) { + args.push_back(parentInputs[args.size()]); + } + } + } + } + } + + auto call = rewriter.create( + op.getLoc(), op.getCallee(), op.getResultTypes(), args); + + if (!call) { + op.emitError("Failed to create func::CallOp"); + return failure(); + } + + 
rewriter.replaceOp(op, call); + return success(); + } +}; + +struct FpToFpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto roundingMode = triton::RoundingMode::RTNE; // default + + auto roundingModeAttr = op.getRounding(); + if (roundingModeAttr.has_value()) { + roundingMode = roundingModeAttr.value(); + } + + assert(roundingMode != triton::RoundingMode::RTZ && + "Rounding Towards Zero is not supported"); + + Type resultType = op.getResult().getType(); + + auto operandWidth = getBitWidth(op.getOperand().getType()); + auto resultWidth = getBitWidth(resultType); + + assert(operandWidth.has_value() && resultWidth.has_value() && + "Not a float-like operand or result"); + + if (operandWidth.value() > resultWidth.value()) { + Value truncatedValue = rewriter.create(op.getLoc(), resultType, op.getOperand()); + rewriter.replaceOp(op, truncatedValue); + return success(); + } + + Value extendedValue = rewriter.create(op.getLoc(), resultType, op.getOperand()); + rewriter.replaceOp(op, extendedValue); + + return success(); + } +}; + +struct ClampConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool propagateNan = op.getPropagateNan() == triton::PropagateNan::ALL; + + assert(!propagateNan && + "PropagateNan is not supported"); + + Location loc = op.getLoc(); + Value x = adaptor.getOperands()[0]; + Value min = adaptor.getOperands()[1]; + Value max = adaptor.getOperands()[2]; + + Value maxMin = rewriter.create(loc, x, min); + Value clamp = rewriter.create(loc, maxMin, max); + rewriter.replaceOp(op, clamp); + + return success(); + } +}; + +struct PreciseSqrtConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = rewriter.create( + op.getLoc(), adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + +struct PreciseDivConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = rewriter.create( + op.getLoc(), adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + +struct CatConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = rewriter.create( + op.getLoc(), 0 /* concat dimension */, adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + + return success(); + } +}; + +struct SplitConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getOperand(); + auto inputType = cast(input.getType()); + + Type resultType = op.getResults().front().getType(); + auto resultTensor = cast(resultType); + auto shape = 
inputType.getShape(); + + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + SmallVector sizes = + llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + SmallVector results; + + for (int i = 0; i < 2; ++i) { + offsets.pop_back(); + sizes.pop_back(); + + offsets.push_back(rewriter.getIndexAttr(i)); + sizes.push_back(rewriter.getIndexAttr(1)); + Value slice = rewriter.create( + loc, resultTensor, input, offsets, sizes, strides); + results.push_back(slice); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct JoinConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange inputs = op.getOperands(); + + auto resultType = cast(op.getResult().getType()); + + auto loc = op.getLoc(); + Value result = rewriter.create(loc, resultType.getShape(), resultType.getElementType()); + + auto shape = resultType.getShape(); + + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + SmallVector sizes = + llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + for (int i = 0; i < 2; ++i) { + offsets.pop_back(); + sizes.pop_back(); + + offsets.push_back(rewriter.getIndexAttr(i)); + sizes.push_back(rewriter.getIndexAttr(1)); + result = rewriter.create(loc, inputs[i], result, offsets, sizes, strides); + } + + rewriter.replaceOp(op, result); + + return success(); + } +}; + +struct MulHiUIOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + auto mulResult = rewriter.create(loc, adaptor.getOperands()); + rewriter.replaceOp(op, mulResult.getHigh()); + + return success(); + } +}; + +// TODO: Move this MatmulConverter to MK related folder as it converts +// triton::DotOp directly into mk::DotOp which carries more information than +// linalg.matmul. 
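+// As a sketch of the intended lowering (assuming mk::DotOp keeps the
+// accumulator as an explicit operand), a tt.dot over tensor<MxKxf32> and
+// tensor<KxNxf32> becomes roughly:
+//   %zeroes = linalg.fill ins(%c0 : f32) outs(%empty : tensor<MxNxf32>)
+//   %d = mk.dot %a, %b, %c, %zeroes -> tensor<MxNxf32>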
+struct MatmulConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // true means tensor elements are zeros + // false means not zero or it cannot be determined + bool isZeroTensor(Value &v, bool integers) const { + if (auto splatOp = v.getDefiningOp()) { + if (auto constOp = splatOp.getSrc().getDefiningOp()) { + if (auto val = dyn_cast(constOp.getValue())) { + return val.getValueAsDouble() == 0.; + } + if (auto val = dyn_cast(constOp.getValue())) { + return val.getValue() == 0; + } + } + return false; + } + + if (auto constOp = v.getDefiningOp()) { + if (auto denseAttr = dyn_cast(constOp.getValue())) { + if (denseAttr.isSplat()) { + if (integers) + return denseAttr.getSplatValue().isZero(); + return denseAttr.getSplatValue().isZero(); + } + } + } + + return false; + } + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto opa = op.getA(); + auto opb = op.getB(); + auto opc = op.getC(); + + auto dstType = cast(op.getType()); + auto elementType = dstType.getElementType(); + bool integers = elementType.isInteger(); + + auto init = + rewriter.create(loc, dstType.getShape(), elementType); + TypedAttr constantAttr = integers ? + static_cast(rewriter.getIntegerAttr(elementType, 0)) : + static_cast(rewriter.getFloatAttr(elementType, 0)); + + auto zero = rewriter.create( + op.getLoc(), elementType, constantAttr); + + auto zeroes = + rewriter.create(loc, ValueRange{zero}, ValueRange{init}) + .result(); + + auto dotOp = rewriter.create( + loc, dstType, ValueRange{opa, opb, opc, zeroes}); + + rewriter.replaceOp(op, dotOp); + + return success(); + } +}; + +struct ReduceConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + llvm::SmallVector getRedOps(triton::ReduceOp redOp) const { + auto reduceBlock = redOp.getBody(); + return llvm::map_to_vector(reduceBlock->without_terminator(), + [](Operation &op) { return &op; }); + } + + bool isReductionOpSupported(Operation *redOp) const { + return isa( + redOp); + } + + arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, + Operation *redOp, + Type constantType) const { + const int64_t bitWidth = constantType.getIntOrFloatBitWidth(); + + auto attr = + llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + return rewriter.getFloatAttr(constantType, 0.f); + }) + .Case([&](arith::AddIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](auto) { + return rewriter.getFloatAttr( + constantType, -std::numeric_limits::infinity()); + }) + .Case([&](auto) { + return rewriter.getFloatAttr( + constantType, std::numeric_limits::infinity()); + }) + .Case([&](arith::MinSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxIntN(bitWidth)); + }) + .Case([&](arith::MinUIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxUIntN(bitWidth)); + }) + .Case([&](arith::MaxSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::minIntN(bitWidth)); + }) + .Case([&](arith::MaxUIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); + + return rewriter.create(redOp->getLoc(), constantType, + attr); + } + + bool requiresF32Conversion(const Type elemType, Operation *redOp) const { + return isa(elemType) && + elemType.getIntOrFloatBitWidth() < + llvm::cast(Float32Type::get(elemType.getContext())) + 
.getWidth() && + isa(redOp); + } + + Value getRedElement(Value lhs, Value rhs, const Location loc, + Operation *redOp, OpBuilder &b, + const bool convertLhsToF32Precision) const { + return llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + if (convertLhsToF32Precision) { + lhs = b.create(loc, Float32Type::get(b.getContext()), + lhs); + } + return b.create(loc, lhs, rhs); + }) + .Case([&](auto redOp) { + return b.create(loc, lhs, rhs); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); + } + + LogicalResult + convertToLinalgReduce(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto source = adaptor.getOperands().front(); + auto sourceType = cast(source.getType()); + auto elemType = sourceType.getElementType(); + auto resType = op.getResult().front().getType(); + auto loc = op.getLoc(); + auto reductionOps = getRedOps(op); + + // Reduction of arbitrary operations isn't supported because using the first + // element across the reduction dimension requires us to iterate over a + // subview that skips over each first element. + if (reductionOps.size() != 1 || + !isReductionOpSupported(reductionOps.front())) { + return rewriter.notifyMatchFailure( + op, "Only support lowering reduction with body " + "containing 1 max(i/f) or addf."); + } + + auto rop = reductionOps.front(); + auto axis = op.getAxis(); + auto isVectorReduce = sourceType.getRank() == 1; + + if (axis == sourceType.getRank() - 1 && !isVectorReduce) { + source = getTransposedValue(source, op.getLoc(), rewriter); + axis = sourceType.getRank() - 2; + } + + bool convertToF32Precision = requiresF32Conversion(resType, rop); + + auto constantType = convertToF32Precision + ? Float32Type::get(rewriter.getContext()) + : elemType; + + auto accBaseConstOp = getRedBaseConstOp(rewriter, rop, constantType); + Value initTensor; + + if (isVectorReduce) { + // The affine vectorizer cannot vectorize affine loops generated from + // linalg.reduce for the vector reduce case, so we must rewrite the + // linalg.reduce to affine loops manually. Here we lower to AllocTensor + // directly instead of EmptyOp so that the subsequent pass can recognize + // the patterns (EmptyOp is susceptible to being CSE'd away, making it + // harder to match the patterns correctly). 
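+      // Sketch of the IR built below for this rank-1 path:
+      //   %acc = bufferization.alloc_tensor() : tensor<f32>
+      //   %init = tensor.insert %cst into %acc[] : tensor<f32>
+      // The linalg.reduce then consumes %init, and the scalar result is
+      // recovered with tensor.extract further down.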
+ initTensor = rewriter.create( + loc, RankedTensorType::get({}, constantType), ValueRange{}); + initTensor = rewriter.create(loc, accBaseConstOp, + initTensor, ValueRange{}); + } else { + Value init = rewriter.create( + loc, cast(resType).getShape(), constantType); + initTensor = rewriter + .create(loc, ValueRange{accBaseConstOp}, + ValueRange{init}) + .result(); + } + + Value finalResult = + rewriter + .create( + loc, ValueRange{source}, ValueRange{initTensor}, + SmallVector{axis}, + [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { + assert(inputs.size() == 2); + Value result = + getRedElement(inputs[0], inputs[1], loc, rop, opBuilder, + convertToF32Precision); + opBuilder.create(loc, result); + }) + .getResult(0); + + if (sourceType.getRank() == 1) { + finalResult = + rewriter.create(loc, constantType, finalResult); + } + + if (convertToF32Precision) { + finalResult = rewriter.create(loc, resType, finalResult); + } + + rewriter.replaceOp(op, finalResult); + return success(); + } + +public: + LogicalResult + matchAndRewrite(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = + cast(adaptor.getOperands().front().getType()); + assert(sourceType.hasRank() && "Expected input is " + "ranked"); + + int64_t axis = op.getAxis(); + assert(axis >= 0 && axis < sourceType.getRank() && + "Expected reduction " + "axis is within " + "operand's rank"); + + return convertToLinalgReduce(op, adaptor, rewriter); + } +}; + +template +class ArgMinMaxBaseConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // We're looking for an op that looks like this: + // + // %9:2 = "tt.reduce"(%8, %3) <{axis = 0 : i32}> ({ + // ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + // ------------------------------------------------- + // `matchTieBreakValue` | + // %11 = arith.cmpf oeq, %arg9, %arg11 : f32 | + // %12 = arith.cmpi slt, %arg10, %arg12 : i32 | 1. + // %13 = arith.andi %11, %12 : i1 | + // ------------------------------------------------- |-> `matchShouldUpdate` + // `matchUpdateCondition` | + // %14 = arith.cmpf ogt, %arg9, %arg11 : f32 | 2. + // ------------------------------------------------- | + // %15 = arith.ori %14, %13 : i1 | + // ------------------------------------------------- + // %16 = arith.select %15, %arg9, %arg11 : f32 + // %17 = arith.select %15, %arg10, %arg12 : i32 + // tt.reduce.return %16, %17 : f32, i32 + // }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) + // + // The above mlir code is lowered from this combinator in triton's + // standard.py: + // + // def _argmax_combine(value1, index1, value2, index2, tie_break_left): + // if tie_break_left: + // tie = value1 == value2 and index1 < index2 + // else: + // tie = False + // gt = value1 > value2 or tie + // v_ret = core.where(gt, value1, value2) + // i_ret = core.where(gt, index1, index2) + // return v_ret, i_ret + + LogicalResult matchTieBreakResult(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &tileBreakValue) const { + // Match the following (section 1. 
of the above) + // + // %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + // %12 = arith.cmpi slt, %arg10, %arg12 : i32 + // %13 = arith.andi %11, %12 : i1 + // + // which is equivalent to the following python code + // + // tie = value1 == value2 and index1 < index2 + + // matching: %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto eqCmpOp = dyn_cast(*it++); + if (eqCmpOp) { + if (eqCmpOp.getPredicate() != arith::CmpFPredicate::OEQ) { + return failure(); + } + if (currValue != eqCmpOp.getLhs() || reduceValue != eqCmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + // matching: %12 = arith.cmpi slt, %arg10, %arg12 : i32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto sltCmpOp = dyn_cast(*it++); + if (sltCmpOp) { + if (sltCmpOp.getPredicate() != arith::CmpIPredicate::slt) { + return failure(); + } + if (currIndex != sltCmpOp.getLhs() || reduceIndex != sltCmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + // matching: %13 = arith.andi %11, %12 : i1 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto andOp = dyn_cast(*it++); + if (andOp) { + if (andOp.getLhs() != eqCmpOp || andOp.getRhs() != sltCmpOp) { + return failure(); + } + } else { + return failure(); + } + + tileBreakValue = andOp; + return success(); + } + + LogicalResult matchShouldUpdateValue(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &shouldUpdate) const { + Value tieResult; + if (failed(matchTieBreakResult(currValue, currIndex, reduceValue, + reduceIndex, it, tieResult))) { + LLVM_DEBUG(llvm::dbgs() << "Tie break result match failed\n"); + return failure(); + } + + Value comparisonResult; + if (failed(T::matchComparisonResult(currValue, currIndex, reduceValue, + reduceIndex, it, comparisonResult))) { + LLVM_DEBUG(llvm::dbgs() << "Comparison result match failed\n"); + return failure(); + } + + // matching: %15 = arith.ori %14, %13 : i1 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto orOp = dyn_cast(*it++); + if (orOp) { + if (orOp.getLhs() != comparisonResult || orOp.getRhs() != tieResult) { + return failure(); + } + } else { + return failure(); + } + + shouldUpdate = orOp; + return success(); + } + + Value getInitTensor(ConversionPatternRewriter &rewriter, + ArrayRef shape, Value fillValue, + Location loc) const { + Value initTensor = + rewriter.create(loc, shape, fillValue.getType()); + return rewriter + .create(loc, ValueRange{fillValue}, + ValueRange{initTensor}) + .result(); + } + +public: + ArgMinMaxBaseConverter(MLIRContext *context) : OpConversionPattern(context) {} + + LogicalResult match(ReduceOp op) const override final { + if (op.getBody()->getNumArguments() != 4) { + return failure(); + } + + auto block = op.getBody(); + auto ops = block->without_terminator(); + + Value currValue = block->getArgument(0); + Value currIndex = block->getArgument(1); + Value reduceValue = block->getArgument(2); + Value reduceIndex = block->getArgument(3); + + auto opsIt = ops.begin(); + Value shouldUpdate; + if (failed(matchShouldUpdateValue(currValue, currIndex, reduceValue, + reduceIndex, opsIt, shouldUpdate))) { + return failure(); + } + + // matching: %16 = arith.select %15, %arg9, %arg11 : f32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto valueSelectOp = dyn_cast(*opsIt++); + if (valueSelectOp) { + if (valueSelectOp.getCondition() != shouldUpdate || + currValue != 
valueSelectOp.getTrueValue() || + reduceValue != valueSelectOp.getFalseValue()) { + return failure(); + } + } else { + return failure(); + } + + // matching:%17 = arith.select %15, %arg10, %arg12 : i32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto indexSelectOp = dyn_cast(*opsIt++); + if (indexSelectOp) { + if (indexSelectOp.getCondition() != shouldUpdate || + currIndex != indexSelectOp.getTrueValue() || + reduceIndex != indexSelectOp.getFalseValue()) { + return failure(); + } + } else { + return failure(); + } + + // matching: tt.reduce.return %16, %17 : f32, i32 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto termOp = dyn_cast(*opsIt++); + if (termOp && termOp == block->getTerminator()) { + auto opnds = termOp.getOperands(); + if (opnds != ArrayRef{valueSelectOp, indexSelectOp}) { + return failure(); + } + } else { + return failure(); + } + + return success(); + } + + void rewrite(ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override final { + auto loc = op.getLoc(); + + auto elemTypes = op.getElementTypes(); + + // Set the initial value of the rank-0 tensor containing + // the result value to either -inf or +inf depending on + // whether we're dealing with argmax or argmin + auto valueType = elemTypes[0]; + auto valuesAccBaseVal = rewriter.create( + loc, valueType, + rewriter.getFloatAttr(valueType, T::getBaseReductionValue())); + + // Set the initial value of the rank-0 tensor containing the index of the + // min or max value to -1 + auto indexType = elemTypes[1]; + auto indicesAccBaseVal = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, -1)); + + // Get the shape of the resulting tensors (both for values and indices). If + // we are reducing to a single scalar, then the result's type is a tensor of + // rank-0, otherwise we can reuse the original result shape + auto valueResultType = dyn_cast(op.getType(0)); + const auto isScalarReduce = valueResultType == nullptr; + SmallVector reductionResultShape{ + isScalarReduce ? SmallVector{} + : SmallVector(valueResultType.getShape())}; + + SmallVector outputs{ + getInitTensor(rewriter, reductionResultShape, valuesAccBaseVal, loc), + getInitTensor(rewriter, reductionResultShape, indicesAccBaseVal, loc)}; + + auto linalgOp = rewriter.create( + loc, adaptor.getOperands(), outputs, + SmallVector{adaptor.getAxis()}, + [&](OpBuilder &b, Location loc, ValueRange inputs) { + assert(inputs.size() == 4); + + auto tritonReduceBlock = op.getBody(); + IRMapping mapping; + mapping.map(tritonReduceBlock->getArguments(), inputs); + + for (auto &op : tritonReduceBlock->without_terminator()) { + b.clone(op, mapping); + } + + auto tritonYield = tritonReduceBlock->getTerminator(); + auto results = + llvm::map_to_vector(tritonYield->getOperands(), [&](Value val) { + return mapping.lookup(val); + }); + b.create(loc, results); + }); + + if (isScalarReduce) { + SmallVector reduceResults{ + rewriter.create( + loc, valueType, linalgOp.getResults()[0], ValueRange{}), + rewriter.create( + loc, indexType, linalgOp.getResults()[1], ValueRange{})}; + rewriter.replaceOp(op, reduceResults); + } else { + rewriter.replaceOp(op, linalgOp); + } + } +}; + +struct ArgMaxConverter : public ArgMinMaxBaseConverter { + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult) { + // %14 = arith.cmpf ogt, %arg9, %arg11 : f32 + // This corresponds to section 2. 
of the sample snippet in + // ArgMinMaxBaseConverter + auto cmpOp = dyn_cast(*it++); + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OGT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + comparisonResult = cmpOp; + return success(); + } + + static float getBaseReductionValue() { + return -std::numeric_limits::infinity(); + } + + ArgMaxConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +struct ArgMinConverter : public ArgMinMaxBaseConverter { + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult) { + // %14 = arith.cmpf olt, %arg9, %arg11 : f32 + // This corresponds to section 2. of the sample snippet in + // ArgMinMaxBaseConverter + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto cmpOp = dyn_cast(*it++); + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OLT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + } else { + return failure(); + } + + comparisonResult = cmpOp; + return success(); + } + + static float getBaseReductionValue() { + return std::numeric_limits::infinity(); + } + + ArgMinConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +// get_program_id and get_num_programs: +// When launching triton kernels, we pass 6 additional arguments to indicate +// num_programs and program_id. Amongst those six, we have 3 arguments +// correspond to each axis for num_programs followed by 3 additional arguments +// for program_id. +// +// For instance, with triton kernel example_kernel(a, b, c), we have: +// example_kernel( +// a, b, c, +// num_programs_axis_0, +// num_programs_axis_1, +// num_programs_axis_2, +// program_id_axis_0, +// program_id_axis_1, +// program_id_axis_2, +// ) +// +struct GetProgramIDConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto axis = (uint32_t)op.getAxis(); + assert(axis < LAUNCH_GRID_RANK && "program_id expects " + "axis to be either 0, " + "1, or 2"); + + auto func = op->getParentOfType(); + auto numArgs = func.getNumArguments(); + auto id = func.getArgument(numArgs - LAUNCH_GRID_RANK + axis); + + rewriter.replaceOp(op, id); + return success(); + } +}; + +struct GetNumProgramsConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + GetNumProgramsConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto axis = (uint32_t)op.getAxis(); + assert(axis < LAUNCH_GRID_RANK && "program_id expects " + "axis to be either 0, " + "1, or 2"); + + auto func = op->getParentOfType(); + auto numArgs = func.getNumArguments(); + auto id = func.getArgument(numArgs - LAUNCH_GRID_RANK * 2 + axis); + + rewriter.replaceOp(op, id); + return success(); + } +}; + +// Convert a pair of cmpf and select to either min or max. 
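+// For example, on the float path:
+//   %0 = arith.cmpf ogt, %a, %b : f32
+//   %1 = arith.select %0, %a, %b : f32
+// collapses into a single arith maximum of %a and %b.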
+// Leave the pattern as simple as possible because triton has plans to emit +// min and max directly. +template +struct MinMaxConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + MinMaxConverter(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/10) {} + + LogicalResult matchAndRewrite(CmpOp cmpOp, + PatternRewriter &rewriter) const final { + if (!cmpOp.getResult().hasOneUse()) { + return failure(); + } + auto selectOp = + dyn_cast(*cmpOp.getResult().getUsers().begin()); + if (!selectOp) { + return failure(); + } + + if (!(cmpOp.getResult() == selectOp.getCondition() && + cmpOp.getLhs() == selectOp.getTrueValue() && + cmpOp.getRhs() == selectOp.getFalseValue())) { + return failure(); + } + + rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate()); + rewriter.eraseOp(cmpOp); + + return success(); + } + + void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpFOp cmpOp, + arith::SelectOp selectOp, + arith::CmpFPredicate pred) const { + switch (pred) { + case arith::CmpFPredicate::OGT: + case arith::CmpFPredicate::OGE: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpFPredicate::OLT: + case arith::CmpFPredicate::OLE: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + default: + llvm_unreachable("Unhandled predicate"); + } + } + + void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpIOp cmpOp, + arith::SelectOp selectOp, + arith::CmpIPredicate pred) const { + switch (pred) { + case arith::CmpIPredicate::sgt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::ugt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::slt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::ult: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + default: + llvm_unreachable("Unhandled predicate"); + } + } +}; + +struct DenseConstantConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto attr = cast(op.getValue()); + auto loc = op.getLoc(); + + auto splatConst = arith::ConstantOp::materialize( + rewriter, attr.getSplatValue(), attr.getElementType(), loc); + + auto init = rewriter.create( + loc, cast(op.getResult().getType()).getShape(), + attr.getElementType()); + + rewriter.replaceOpWithNewOp(op, ValueRange{splatConst}, + ValueRange{init}); + + return success(); + } +}; + +class CumSumConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // CumSum is a specific instance of Scan that looks like the following: + // %1 = "tt.scan"(%0) <{axis = 1 : i32}> ({ + // ^bb0(%arg0: f32, %arg1: f32): + // %2 = arith.addf %arg0, %arg1 : f32 + // tt.scan.return %2 : f32 + // }) : (tensor<4x4xf32>) -> tensor<4x4xf32> + bool isCumSum(triton::ScanOp op) const { + auto scanBlock = op.getBody(); + auto ops = llvm::map_to_vector(scanBlock->without_terminator(), + [](Operation &op) { return &op; }); + + if (ops.size() != 1) { + return false; + } + + auto addOp = ops.front(); + if (isa(addOp)) { + if (addOp->getResult(0) != scanBlock->getTerminator()->getOperand(0)) { + return false; + } + + auto blockArgs = + llvm::map_range(scanBlock->getArguments(), 
[](BlockArgument arg) { + return dyn_cast(arg); + }); + + auto addArgs = addOp->getOperands(); + + return DenseSet(blockArgs.begin(), blockArgs.end()) == + DenseSet(addArgs.begin(), addArgs.end()); + } + + return false; + } + +public: + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isCumSum(op)) { + return rewriter.notifyMatchFailure( + op, "Only support cumsum variant of scan op"); + } + + auto input = op.getOperand(0); + auto axis = op.getAxis(); + auto type = dyn_cast(input.getType()); + + if (type.getRank() != 1 && type.getRank() != 2 && + axis != type.getRank() - 1) { + return rewriter.notifyMatchFailure( + op, "Only support lowering scan op to cumsum with rank " + "= {1, 2} and axis = rank - 1"); + } + + Value init = rewriter.create(op.getLoc(), type.getShape(), + type.getElementType()); + + rewriter.replaceOpWithNewOp( + op, input, rewriter.getUI32IntegerAttr(axis), init); + + return success(); + } +}; + +class AddPtrConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resType = op.getResult().getType(); + assert(isa(resType)); + auto rank = cast(resType).getRank(); + SmallVector indexingMaps( + /*numResult + numOperands*/ 3, rewriter.getMultiDimIdentityMap(rank)); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + SmallVector outputs = {op.getPtr()}; + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperands(), outputs, indexingMaps, + iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { + auto resultTypes = llvm::to_vector<6>( + llvm::map_range(op->getResultTypes(), [](Type type) { + return cast(type).getElementType(); + })); + auto *scalarOp = + builder.create(loc, op->getName().getIdentifier(), + regionArgs.take_front(op->getNumOperands()), + resultTypes, op->getAttrs()); + builder.create(loc, scalarOp->getResults()); + }); + return success(); + } +}; + +class ReshapeConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getSrc(); + auto output = op.getResult(); + + auto inputType = input.getType(); + auto outputType = output.getType(); + if (!outputType.hasStaticShape()) { + return failure(); + } + + if (auto maybeReassociationMap = + getReassociationIndicesForReshape(inputType, outputType)) { + auto reassociationMap = *maybeReassociationMap; + if (outputType.getRank() < inputType.getRank()) { + rewriter.replaceOpWithNewOp( + op, outputType, input, reassociationMap); + } else { + rewriter.replaceOpWithNewOp( + op, outputType, input, reassociationMap); + } + return success(); + } + + ArrayRef outputShape = outputType.getShape(); + + auto shape = rewriter.create( + loc, rewriter.getI64TensorAttr(outputShape)); + rewriter.replaceOpWithNewOp(op, outputType, input, + shape); + + return success(); + } +}; + +class ExternElementwiseBinaryOpConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + if (!op.getPure() || 
op.getSrcs().size() != 2)
+      return failure();
+#define POPULATE_BINARY_OP(FUNC_NAME, DST_OP)                                 \
+  if (!op.getSymbol().compare(FUNC_NAME)) {                                   \
+    rewriter.replaceOpWithNewOp<DST_OP>(op, op.getSrcs()[0],                  \
+                                        op.getSrcs()[1]);                     \
+    return success();                                                         \
+  }
+
+    POPULATE_BINARY_OP("__nv_atan2f", math::Atan2Op);
+    POPULATE_BINARY_OP("__nv_atan2", math::Atan2Op);
+    POPULATE_BINARY_OP("__nv_powf", math::PowFOp);
+    POPULATE_BINARY_OP("__nv_pow", math::PowFOp);
+
+#undef POPULATE_BINARY_OP
+    return failure();
+  }
+};
+
+class ExternElementwiseUnaryOpConverter
+    : public OpConversionPattern<triton::ExternElementwiseOp> {
+  using OpConversionPattern<triton::ExternElementwiseOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    if (!op.getPure() || op.getSrcs().size() != 1)
+      return failure();
+#define POPULATE_UNARY_OP(FUNC_NAME, DST_OP)                                  \
+  if (!op.getSymbol().compare(FUNC_NAME)) {                                   \
+    rewriter.replaceOpWithNewOp<DST_OP>(op, op.getSrcs()[0]);                 \
+    return success();                                                         \
+  }
+
+    POPULATE_UNARY_OP("__nv_fabsf", math::AbsFOp);
+    POPULATE_UNARY_OP("__nv_fabs", math::AbsFOp);
+    POPULATE_UNARY_OP("__nv_sinf", math::SinOp);
+    POPULATE_UNARY_OP("__nv_sin", math::SinOp);
+    POPULATE_UNARY_OP("__nv_cosf", math::CosOp);
+    POPULATE_UNARY_OP("__nv_cos", math::CosOp);
+    POPULATE_UNARY_OP("__nv_tanf", math::TanOp);
+    POPULATE_UNARY_OP("__nv_tan", math::TanOp);
+    POPULATE_UNARY_OP("__nv_asinf", math::AsinOp);
+    POPULATE_UNARY_OP("__nv_asin", math::AsinOp);
+    POPULATE_UNARY_OP("__nv_acosf", math::AcosOp);
+    POPULATE_UNARY_OP("__nv_acos", math::AcosOp);
+    POPULATE_UNARY_OP("__nv_atanf", math::AtanOp);
+    POPULATE_UNARY_OP("__nv_atan", math::AtanOp);
+    POPULATE_UNARY_OP("__nv_sinhf", math::SinhOp);
+    POPULATE_UNARY_OP("__nv_sinh", math::SinhOp);
+    POPULATE_UNARY_OP("__nv_coshf", math::CoshOp);
+    POPULATE_UNARY_OP("__nv_cosh", math::CoshOp);
+    POPULATE_UNARY_OP("__nv_tanhf", math::TanhOp);
+    POPULATE_UNARY_OP("__nv_tanh", math::TanhOp);
+    POPULATE_UNARY_OP("__nv_acoshf", math::AcoshOp);
+    POPULATE_UNARY_OP("__nv_acosh", math::AcoshOp);
+    POPULATE_UNARY_OP("__nv_asinhf", math::AsinhOp);
+    POPULATE_UNARY_OP("__nv_asinh", math::AsinhOp);
+    POPULATE_UNARY_OP("__nv_atanhf", math::AtanhOp);
+    POPULATE_UNARY_OP("__nv_atanh", math::AtanhOp);
+    POPULATE_UNARY_OP("__nv_logf", math::LogOp);
+    POPULATE_UNARY_OP("__nv_log", math::LogOp);
+    POPULATE_UNARY_OP("__nv_log10f", math::Log10Op);
+    POPULATE_UNARY_OP("__nv_log10", math::Log10Op);
+    POPULATE_UNARY_OP("__nv_log1pf", math::Log1pOp);
+    POPULATE_UNARY_OP("__nv_log1p", math::Log1pOp);
+    POPULATE_UNARY_OP("__nv_expf", math::ExpOp);
+    POPULATE_UNARY_OP("__nv_exp", math::ExpOp);
+    POPULATE_UNARY_OP("__nv_exp2f", math::Exp2Op);
+    POPULATE_UNARY_OP("__nv_exp2", math::Exp2Op);
+    POPULATE_UNARY_OP("__nv_erff", math::ErfOp);
+    POPULATE_UNARY_OP("__nv_erf", math::ErfOp);
+    POPULATE_UNARY_OP("__nv_sqrtf", math::SqrtOp);
+    POPULATE_UNARY_OP("__nv_sqrt", math::SqrtOp);
+    POPULATE_UNARY_OP("__nv_rsqrtf", math::RsqrtOp);
+    POPULATE_UNARY_OP("__nv_rsqrt", math::RsqrtOp);
+    POPULATE_UNARY_OP("__nv_ceilf", math::CeilOp);
+    POPULATE_UNARY_OP("__nv_ceil", math::CeilOp);
+    POPULATE_UNARY_OP("__nv_floorf", math::FloorOp);
+    POPULATE_UNARY_OP("__nv_floor", math::FloorOp);
+    POPULATE_UNARY_OP("__nv_truncf", math::TruncOp);
+    POPULATE_UNARY_OP("__nv_trunc", math::TruncOp);
+
+#undef POPULATE_UNARY_OP
+    return failure();
+  }
+};
+
+static void populateExternElementwiseOpToMLIROps(RewritePatternSet &patterns) {
+  
+
+static void populateExternElementwiseOpToMLIROps(RewritePatternSet &patterns) {
+  patterns.add<ExternElementwiseBinaryOpConverter,
+               ExternElementwiseUnaryOpConverter>(patterns.getContext());
+}
+
+} // namespace
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.h
new file mode 100644
index 000000000..b95cbde73
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.h
@@ -0,0 +1,15 @@
+#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES_H
+#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.td
new file mode 100644
index 000000000..590c02de7
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/Passes.td
@@ -0,0 +1,20 @@
+#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES
+#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def TritonArithToLinalg : Pass<"triton-arith-to-linalg", "mlir::ModuleOp"> {
+  let summary = "Convert Triton arithmetic operations into linalg";
+  let options = [
+    Option<"pidsToFuncArgs", "pids-to-func-args", "bool", /*default*/"true",
+           "Convert tt.get_program_id and tt.get_num_programs to reference to function arguments">,
+    Option<"ttToFuncFunc", "tt-to-func-func", "bool", /*default*/"true",
+           "Convert tt.func to func.func">,
+    Option<"addptrToLinalg", "addptr-to-linalg", "bool", /*default*/"true",
+           "Convert tt.addptr on tensors to linalg">,
+    Option<"assertToCf", "assert-to-cf", "bool", /*default*/"true",
+           "Convert tt.assert to cf.assert">,
+  ];
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h
new file mode 100644
index 000000000..8e5e5822a
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h
@@ -0,0 +1,28 @@
+#ifndef TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H
+#define TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc"
+
+void populateTritonArithToLinalgCanonicalizationPatterns(
+    RewritePatternSet &patterns);
+
+void populateTritonArithToLinalgConversionPatterns(bool pidsToFuncArgs,
+                                                   bool addptrToLinalg,
+                                                   bool assertToCf,
+                                                   RewritePatternSet &patterns);
+
+std::unique_ptr<OperationPass<ModuleOp>> createTritonArithToLinalgPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/CMakeLists.txt
new file mode 100644
index 000000000..3cc51fcb2
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/CMakeLists.txt
@@ -0,0 +1,9 @@
+#===------------------------------------------------------------------------===#
+#
+# Copyright (c) Triton Project Contributors.
+#
+#===------------------------------------------------------------------------===#
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToCoreDialects)
+add_public_tablegen_target(TritonToCoreDialectsConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.h
new file mode 100644
index 000000000..32fc0104d
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.h
@@ -0,0 +1,22 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES_H
+#define TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.td
new file mode 100644
index 000000000..6d10cfb6f
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/Passes.td
@@ -0,0 +1,18 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES
+#define TRITON_TO_CORE_DIALECTS_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def TritonToCoreDialects : Pass<"triton-to-core-dialects", "mlir::ModuleOp"> {
+  let summary = "Convert Triton to core dialects including Linalg, Memref etc";
+  let constructor = "triton::createTritonToCoreDialectsPass()";
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h
new file mode 100644
index 000000000..d968cc055
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h
@@ -0,0 +1,27 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the wrap-all pass that populates all the conversion patterns from
+// Triton to core dialects such as linalg, memref, bufferization, etc.
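+//
+// A typical (illustrative) use from a pass pipeline:
+//   pm.addPass(mlir::triton::createTritonToCoreDialectsPass());
+// after which the module contains only core-dialect ops ready for
+// target-specific lowering.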
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_CONVERSION_TRITON_TO_CORE_DIALECTS_H
+#define TRITON_CONVERSION_TRITON_TO_CORE_DIALECTS_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace triton {
+
+std::unique_ptr<OperationPass<ModuleOp>> createTritonToCoreDialectsPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_TRITON_TO_CORE_DIALECTS_H
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt
new file mode 100644
index 000000000..74ccdd390
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt
@@ -0,0 +1,9 @@
+#===------------------------------------------------------------------------===#
+#
+# Copyright (c) Triton Project Contributors.
+#
+#===------------------------------------------------------------------------===#
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg)
+add_public_tablegen_target(TritonToLinalgConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.h
new file mode 100644
index 000000000..404af0802
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.h
@@ -0,0 +1,22 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_TO_LINALG_CONVERSION_PASSES_H
+#define TRITON_TO_LINALG_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.td
new file mode 100644
index 000000000..627077e3a
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/Passes.td
@@ -0,0 +1,18 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_TO_LINALG_CONVERSION_PASSES
+#define TRITON_TO_LINALG_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> {
+  let summary = "Convert Triton to Linalg dialect";
+  let constructor = "triton::createTritonToLinalgPass()";
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h
new file mode 100644
index 000000000..4c58e9921
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h
@@ -0,0 +1,33 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H
+#define TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinalgPass();
+
+void populateTritonToLinalgCanonicalizationPatterns(
+    RewritePatternSet &patterns);
+
+void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter,
+                                              RewritePatternSet &patterns,
+                                              unsigned int launchGridRank);
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt
new file mode 100644
index 000000000..5762c1f69
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToStructured)
+add_public_tablegen_target(TritonToStructuredConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.h
new file mode 100644
index 000000000..3c3b81ca4
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.h
@@ -0,0 +1,15 @@
+#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES_H
+#define TRITON_TO_STRUCTURED_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.td b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.td
new file mode 100644
index 000000000..e2702464b
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/Passes.td
@@ -0,0 +1,19 @@
+#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES
+#define TRITON_TO_STRUCTURED_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> {
+  let summary = "Convert Triton non-block pointer to TritonStructured dialect";
+  let constructor = "triton::createTritonToStructuredPass()";
+  let options = [
+    Option<"runPrepassOnly", "run-prepass-only", "bool", /*default*/"false",
+           "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">,
+    Option<"skipPrepass", "skip-prepass", "bool", /*default*/"false",
+           "Skip the prepass">,
+    Option<"useUnsafeMask", "use-unsafe-mask", "bool", /*default*/"false",
+           "Assume that the mask bounds are never less than starting offsets. May produce incorrect results.">
+  ];
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h
new file mode 100644
index 000000000..0ee1a6d53
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h
@@ -0,0 +1,17 @@
+#ifndef TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H
+#define TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+std::unique_ptr<OperationPass<ModuleOp>> createTritonToStructuredPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/CMakeLists.txt
new file mode 100644
index 000000000..68066ab63
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(TritonTilingExt)
+add_subdirectory(TritonStructured)
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt
new file mode 100644
index 000000000..f33061b2d
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt
new file mode 100644
index 000000000..9c32c97c8
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(LLVM_TARGET_DEFINITIONS TritonStructuredDialect.td)
+mlir_tablegen(TritonStructuredDialect.h.inc -gen-dialect-decls -dialect=tts)
+mlir_tablegen(TritonStructuredDialect.cpp.inc -gen-dialect-defs -dialect=tts)
+mlir_tablegen(TritonStructuredOps.h.inc -gen-op-decls)
+mlir_tablegen(TritonStructuredOps.cpp.inc -gen-op-defs)
+
+
+add_public_tablegen_target(TritonStructuredTableGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h
new file mode 100644
index 000000000..bd01afd05
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h
@@ -0,0 +1,27 @@
+#ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_
+#define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/IR/Dialect.h"
+
+//===----------------------------------------------------------------------===//
+// TritonStructured Operations
+//===----------------------------------------------------------------------===//
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h.inc"
"triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h.inc" + +// Include the auto-generated header file containing the declarations of the +// TritonStructured operations. +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.h.inc" + +#endif diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td new file mode 100644 index 000000000..c0f89bfc1 --- /dev/null +++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -0,0 +1,213 @@ +#ifndef TRITON_STRUCTURED_DIALECT +#define TRITON_STRUCTURED_DIALECT + +include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def Triton_Structured_Dialect : Dialect { + let name = "tts"; + + let cppNamespace = "::mlir::tts"; + + let summary = "Structured Triton operations"; + + let description = [{ + Triton Structured Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect" + ]; + + let usePropertiesForAttributes = 1; +} + +// +// Op Base +// +class TTS_Op traits = []> : + Op { +} + +def TTS_MakeTensorPtrOp + : TTS_Op<"make_tptr", [AttrSizedOperandSegments, Pure]> { + let summary = "create a pointer that points to a tensor in memory"; + + // base: Base pointer used to contruct the tensor of pointers or pointer to tensor. + // sizes: Size of the data being loaded or stored. + // strides: The strides of the parent tensor, which means how much to increase the pointer + // by when moving by 1 element in a specific axis. + // order: The order of the block, which means how the block is laid out in memory. + // It contains the same info as order in tt.make_tensor_ptr. + // shape: If order is present, this field signifies the shape of the parent tensor in + // memory; if order is not present, it signifies the boundary by which addresses + // wraps around (constant zero indicates no wrap-around in the corresponding dimension). + // offsets: Offset of the block along each dimension from base. + // result: If order is present, this op produces a pointer to a tensor; otherwise, + // it produces a tensor of pointers. + + let arguments = (ins TT_Ptr:$base, + DenseI64ArrayAttr:$sizes, + Variadic:$strides, + Variadic:$offsets, + Variadic:$shape, + DenseI64ArrayAttr:$static_strides, + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_shape, + DenseI32ArrayAttr:$order); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = [{ + $base `to` `sizes` `` `:` $sizes + `` `,` `strides` `` `:` + custom($strides, $static_strides) + `` `,` `offsets` `` `:` + custom($offsets, $static_offsets) + `` `,` `shape` `` `:` + custom($shape, $static_shape) + `` `,` `order` `` `:` $order + attr-dict `:` type($base) `to` type($result) + }]; + + + let builders = [ + // Build with mixed static and dynamic entries. 
+
+  let builders = [
+    // Build with mixed static and dynamic entries.
+    OpBuilder<(ins
+      "Value":$base,
+      "ArrayRef<int64_t>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      "ArrayRef<OpFoldResult>":$offsets,
+      "ArrayRef<OpFoldResult>":$shape,
+      "ArrayRef<int32_t>":$order)>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return a vector of all the static or dynamic fields
+    SmallVector<OpFoldResult> getMixedSizes() {
+      Builder b(getContext());
+      SmallVector<Value> dynSizes; // sizes are always static
+      return ::mlir::getMixedValues(getSizes(), dynSizes, b);
+    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticStrides(), getStrides(), b);
+    }
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticOffsets(), getOffsets(), b);
+    }
+    SmallVector<OpFoldResult> getMixedShape() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticShape(), getShape(), b);
+    }
+    bool isBlockPtr() {
+      return !getOrder().empty();
+    }
+    bool isStructuredPtr() {
+      return !isBlockPtr() &&
+             llvm::all_of(getStaticShape(), [](auto shape) { return shape == 0; });
+    }
+    bool isSplitPtr() {
+      return !isBlockPtr() &&
+             !isStructuredPtr();
+    }
+  }];
+
+  // TODO
+  //let hasVerifier = 1;
+  //let hasCanonicalizer = 1;
+}
+
+// SameVariadicResultSize
+// AttrSizedResultSegments
+def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> {
+  let summary = "Placeholder for the structured pointer states computed during PtrAnalysis.";
+  let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites.";
+
+  let arguments = (ins AnyTypeOf<[TT_PtrLike, I32Tensor]>:$input);
+  let results = (outs AnyTypeOf<[TT_PtrLike, I32Tensor]>:$structured, Variadic<Index>:$offsets, Variadic<Index>:$strides);
+
+  let builders = [
+    OpBuilder<(ins "Value":$input)>,
+  ];
+
+  let extraClassDeclaration = [{
+    static std::optional<std::pair<SmallVector<Type>, SmallVector<Type>>>
+    getOffsetAndStrideTypes(MLIRContext *context, Type ptrLikeType);
+
+    static std::optional<std::pair<int, int>>
+    getOffsetAndStrideSegmentSizes(Type ptrLikeType);
+  }];
+
+  let hasFolder = 0;
+  let hasVerifier = 1;
+}
+
+def TTS_LoadOp : TTS_Op<"load", [
+  MemoryEffects<[MemRead]>,
+  AttrSizedOperandSegments
+]> {
+  let summary = "optionally load data from memory to fill a portion of the tensor";
+
+  let arguments = (ins TT_PtrLike:$ptr,
+                   Variadic<Index>:$mask_dims,
+                   DenseI64ArrayAttr:$static_mask_dims,
+                   Optional<AnyTypeOf<[TT_Float, TT_Int]>>:$other);
+
+  let results = (outs TT_Tensor:$result);
+
+  let builders = [
+    OpBuilder<(ins "Value":$ptr, "ArrayRef<OpFoldResult>":$mask_dims, "Value":$other)>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return a vector of all the static or dynamic fields
+    SmallVector<OpFoldResult> getMixedMaskDims() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticMaskDims(), getMaskDims(), b);
+    }
+
+    bool hasMask() {
+      return !getStaticMaskDims().empty();
+    }
+  }];
+
+  // TODO
+  //let hasCustomAssemblyFormat = 1;
+  //let hasVerifier = 1;
+}
+
+def TTS_StoreOp : TTS_Op<"store", [
+  MemoryEffects<[MemWrite]>
+]> {
+  let summary = "optionally store a portion of the tensor to memory";
+
+  let arguments = (ins TT_PtrLike:$ptr,
+                   TT_Tensor:$value,
+                   Variadic<Index>:$mask_dims,
+                   DenseI64ArrayAttr:$static_mask_dims);
+
+  let builders = [
+    OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef<OpFoldResult>":$dims)>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return a vector of all the static or dynamic fields
+    SmallVector<OpFoldResult> getMixedMaskDims() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticMaskDims(), getMaskDims(), b);
+    }
+
+    bool hasMask() {
+      return !getStaticMaskDims().empty();
+    }
+  }];
+
+  // TODO
+  //let hasCustomAssemblyFormat = 1;
+  //let hasVerifier = 1;
+}
+
+#endif // TRITON_STRUCTURED_DIALECT
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/CMakeLists.txt
new file mode 100644
index 000000000..f33061b2d
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt
new file mode 100644
index 000000000..ba67b25a7
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt
@@ -0,0 +1,11 @@
+set(LLVM_TARGET_DEFINITIONS TritonTilingExtOps.td)
+mlir_tablegen(TritonTilingExtOpsDialect.h.inc -gen-dialect-decls -dialect=ttx)
+mlir_tablegen(TritonTilingExtOpsDialect.cpp.inc -gen-dialect-defs -dialect=ttx)
+mlir_tablegen(TritonTilingExtOps.h.inc -gen-op-decls)
+mlir_tablegen(TritonTilingExtOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(TritonTilingExtOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TritonTilingExtInterfaces.td)
+mlir_tablegen(TritonTilingExtInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TritonTilingExtInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(TritonTilingExtInterfacesIncGen)
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h
new file mode 100644
index 000000000..53e031db3
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h
@@ -0,0 +1,107 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_
+#define MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+//===----------------------------------------------------------------------===//
+// TritonTilingExt Operations
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOpsDialect.h.inc"
+
+// Include the generated interface declarations.
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.h.inc"
+
+// Include the auto-generated header file containing the declarations of the
+// TritonTilingExt operations.
+#define GET_OP_CLASSES
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.h.inc"
+
+namespace mlir {
+
+namespace ttx {
+
+// -----------------------------------------------------------------------------
+// BufferizableOpInterface
+// -----------------------------------------------------------------------------
+// All TritonTilingExtOps need to support bufferization: the process of
+// allocating buffers for tensors, thereby converting inputs and outputs of
+// tensor type to memref. This process is done by implementing the
+// "BufferizableOpInterface". We implement the interface for TritonTilingExtOps
+// through an external model instead of directly in TritonTilingExtOps.td to be
+// consistent with other ops in the mlir project. See some examples here:
+// - mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+// - mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
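+//
+// For example (illustrative), a tool that bufferizes these ops would register
+// the external model while assembling its dialect registry:
+//   DialectRegistry registry;
+//   ttx::registerBufferizableOpInterfaceExternalModels(registry);
+//   context.appendDialectRegistry(registry);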
+
+// -----------------------------------------------------------------------------
+// TilingInterface
+// -----------------------------------------------------------------------------
+// The three methods `getTiledImplementation`, `getResultTilePosition`, and
+// `generateResultTileValue` are implemented as part of the TilingInterface.
+// (see TilingInterface.td). These three methods are re-used across
+// all TritonTilingExtOps, while other methods are implemented individually by
+// each operator depending on their use cases.
+template <typename TritonTilingExtOpTy>
+FailureOr<TilingResult> getTiledImplementation(TritonTilingExtOpTy op,
+                                               OpBuilder &b,
+                                               ArrayRef<OpFoldResult> offsets,
+                                               ArrayRef<OpFoldResult> sizes);
+
+template <typename TritonTilingExtOpTy>
+LogicalResult getResultTilePosition(TritonTilingExtOpTy op, OpBuilder &b,
+                                    unsigned resultNumber,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVector<OpFoldResult> &resultOffsets,
+                                    SmallVector<OpFoldResult> &resultSizes);
+
+template <typename TritonTilingExtOpTy>
+FailureOr<TilingResult>
+generateResultTileValue(TritonTilingExtOpTy op, OpBuilder &b,
+                        unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes);
+
+// -----------------------------------------------------------------------------
+// MemoryEffectsOpInterface
+// -----------------------------------------------------------------------------
+// Implementation of the MemoryEffectsOpInterface for TritonTilingExtOps.
+// This allows the DCE pass to determine if a TritonTilingExtOp is safe to be
+// removed. See TritonTilingExtOps.td for more details.
+template <typename TritonTilingExtOpTy>
+void getEffects(
+    TritonTilingExtOpTy op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects);
+
+// -----------------------------------------------------------------------------
+// Utilities
+// -----------------------------------------------------------------------------
+// Utility method to extract a slice from the input source using either
+// tensor::ExtractSlice or memref::SubView
+Value getSlice(OpBuilder &b, Location loc, Value source,
+               ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+               ArrayRef<OpFoldResult> strides);
+
+} // namespace ttx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td
new file mode 100644
index 000000000..e74fbb6cc
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td
@@ -0,0 +1,102 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES
+#define MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//
+// Linalg operators require providing affine maps that define how input / output
+// buffers are accessed together with a region that defines how each output
+// element is computed; this requirement doesn't work well for operations such as
+// `scan`.
+//
+// Fortunately, the introduction of the TilingInterface allows us to add tiling
+// and fusion support to operations that don't fit into the linalg dialect.
+// This fits our purpose perfectly: our `scan` operators can be treated as an
+// "opaque" / "completely abstract" operation that can be tiled on the batch
+// dimensions -- we don't need to provide any associated body together with it.
+//
+// However, this doesn't mean that we entirely forgo the "indexing map" concept.
+// For example, consider the following:
+//
+// - ttx.scan ins(%1 : tensor<128x768xbf16>)
+//            outs(%2 : tensor<128x768xbf16>) -> tensor<128x768xbf16>
+//
+// Tiling the batch dimension gives us:
+//
+// for (i = 0 to 128) {
+//    %sliceIn = extract slice from input: tensor<1x768xbf16>
+//    %sliceOut = extract slice from output: tensor<1x768xbf16>
+//    %res = ttx.scan ins(slice : tensor<1x768xbf16>)
+//                    outs(%2 : tensor<1x768xbf16>) -> tensor<1x768xbf16>
+//    insert %res into output
+// }
+//
+// Now our `scan` op has the semantic of running `scan` on a rank-1 tensor and
+// can be lowered further to other hardware-specific ops or external library
+// calls.
+//
+// This tiling pattern is essentially the same as tiling a linalg.generic op
+// with an identity map. The only difference is we don't need a body associated
+// with our `scan` op.
+//
+// With this idea in mind, the TritonTilingExtInterface exposes methods
+// that will be implemented individually by each TritonTilingExtOp, providing
+// the indexing map for each input / output that can then be used to generate
+// the correct slices during tiling and fusion.
+//
+// There might be other ops in the future that won't fit in this "indexing map"
+// approach; we will consider making TritonTilingExtInterface an optional
+// interface for such ops.
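+//
+// For instance (illustrative), for the rank-2 ttx.scan above, both the input
+// and output indexing maps would be the identity map (d0, d1) -> (d0, d1);
+// tiling the batch dimension d0 with tile size 1 is what produces the
+// tensor<1x768xbf16> slices in the loop sketch.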
+//
+
+def TritonTilingExtInterface : OpInterface<"TritonTilingExtInterface"> {
+  let cppNamespace = "::mlir::ttx";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the indexing map for the input operand with the given `index`.
+          The `tileSizes` input indicates the requested tile size during tiling
+          in case the indexing map for the operator is dependent on it.
+        }],
+        /*retTy=*/"AffineMap",
+        /*methodName=*/"getInputIndexingMap",
+        /*args=*/(ins "MLIRContext*":$context,
+                      "unsigned int":$index,
+                      "ArrayRef<int64_t>":$tileSizes)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the indexing map for the output operand with the given `index`.
+          The `tileSizes` input indicates the requested tile size during tiling
+          in case the indexing map for the operator is dependent on it.
+        }],
+        /*retTy=*/"AffineMap",
+        /*methodName=*/"getOutputIndexingMap",
+        /*args=*/(ins "MLIRContext*":$context,
+                      "unsigned int":$index,
+                      "ArrayRef<int64_t>":$tileSizes)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the indexing map for the operand with the given `index`.
+          This method returns the operand in order of inputs followed by outputs.
+          The `tileSizes` input indicates the requested tile size during tiling
+          in case the indexing map for the operator is dependent on it.
+        }],
+        /*retTy=*/"AffineMap",
+        /*methodName=*/"getIndexingMap",
+        /*args=*/(ins "MLIRContext*":$context,
+                      "unsigned int":$index,
+                      "ArrayRef<int64_t>":$tileSizes)
+      >
+  ];
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.td b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.td
new file mode 100644
index 000000000..d3a4268a5
--- /dev/null
+++ b/third_party/tsingmicro/include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.td
@@ -0,0 +1,242 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_TILING_EXT_BASE
+#define TRITON_TILING_EXT_BASE
+
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinAttributes.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TilingInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+
+include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td"
+
+
+//===----------------------------------------------------------------------===//
+// TritonTilingExt dialect definition
+//===----------------------------------------------------------------------===//
+
+def TritonTilingExt_Dialect : Dialect {
+  let name = "ttx";
+  let cppNamespace = "::mlir::ttx";
+}
+
+//===----------------------------------------------------------------------===//
+// TritonTilingExt op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for TritonTilingExt dialect ops.
+class TritonTilingExt_Op<string mnemonic, list<Trait> traits = []>
+    : Op<TritonTilingExt_Dialect, mnemonic, traits> {
+}
+
+class TritonTilingExt_TilingOp<string mnemonic> : Op<TritonTilingExt_Dialect, mnemonic, [
+    DeclareOpInterfaceMethods<TilingInterface>,
+
+    // All TritonTilingExtOps implement TritonTilingExtInterface, which provides a standardized
+    // way of providing indexing maps for input and output operands.
+    DeclareOpInterfaceMethods<TritonTilingExtInterface>,
+
+    // MemoryEffectsOpInterface provides analysis passes such as DCE to determine
+    // whether an operation has no memory side effects and therefore is safe to
+    // be deleted. This interface is important during tile and fuse where we
+    // create copies of TilingInterface ops with smaller tile sizes but leave the
+    // original ops intact.
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+
+    // DestinationStyleOpInterface describes ops that have similar semantics to
+    // linalg ops, with a separate ins (input) and outs (output) operand groups.
+    // Implementing this op gives us access to a wide variety of useful methods
+    // to query the inputs and outputs of an op.
+    DestinationStyleOpInterface,
+
+    // AttrSizedOperandSegments supports having multiple groups of operands.
+    // For example, linalg ops (as well as TritonTilingExtOps) all look like this:
+    // ttx.some_op ins(%1) outs(%2) -> resultType
+    AttrSizedOperandSegments
+]>
+{
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+
+  let hasCustomAssemblyFormat = 1;
+
+  code baseClassDecls = [{
+    // Implemented as part of DestinationStyleOpInterface
+    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+  }];
+
+  // Custom print() and parse() methods to make the TritonTilingExt ops have similar looks
+  // to the linalg ops.
+  // Borrowed from llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+  let extraClassDefinition = [{
+    void $cppClass::print(OpAsmPrinter &p) {
+      p.printOptionalAttrDict(this->getOperation()->getAttrs(),
+                              /*elidedAttrs=*/{"operand_segment_sizes"});
+
+      if (!getInputs().empty())
+        p << " ins(" << getInputs() << " : " << getInputs().getTypes() << ")";
+      if (!getOutputs().empty())
+        p << " outs(" << getOutputs() << " : " << getOutputs().getTypes() << ")";
+
+      if (!getResultTypes().empty())
+        p.printOptionalArrowTypeList(getResultTypes());
+    }
+
+    ParseResult $cppClass::parse(OpAsmParser &parser,
+                                 OperationState &result) {
+      SmallVector<Type, 1> inputTypes;
+      SmallVector<Type, 1> outputTypes;
+      SMLoc inputsOperandsLoc, outputsOperandsLoc;
+      SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
+          outputsOperands;
+      if (parser.parseOptionalAttrDict(result.attributes))
+        return failure();
+
+      if (succeeded(parser.parseOptionalKeyword("ins"))) {
+        if (parser.parseLParen())
+          return failure();
+
+        inputsOperandsLoc = parser.getCurrentLocation();
+        if (parser.parseOperandList(inputsOperands) ||
+            parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+          return failure();
+      }
+
+      if (succeeded(parser.parseOptionalKeyword("outs"))) {
+        outputsOperandsLoc = parser.getCurrentLocation();
+        if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
+            parser.parseColonTypeList(outputTypes) || parser.parseRParen())
+          return failure();
+      }
+
+      if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
+                                 result.operands) ||
+          parser.resolveOperands(outputsOperands, outputTypes,
+                                 outputsOperandsLoc, result.operands))
+        return failure();
+
+      result.addAttribute("operand_segment_sizes",
+                          parser.getBuilder().getDenseI32ArrayAttr(
+                              {static_cast<int32_t>(inputsOperands.size()),
+                               static_cast<int32_t>(outputsOperands.size())}));
+
+      SmallVector<Type, 1> resultTypes;
+      if (parser.parseOptionalArrowTypeList(resultTypes))
+        return failure();
+      result.addTypes(resultTypes);
+
+      return success();
+    }
+
+    AffineMap $cppClass::getIndexingMap(MLIRContext *context,
+                                        unsigned int index,
+                                        ArrayRef<int64_t> sizes) {
+      assert(index < this->getNumOperands());
+      if (index < getNumDpsInputs()) {
+        return getInputIndexingMap(context, index, sizes);
+      }
+      return getOutputIndexingMap(context, index - getNumDpsInputs(), sizes);
+    }
+
+    // Forward each of the implementation to the shared implementation
+    FailureOr<TilingResult> $cppClass::getTiledImplementation(
+        OpBuilder &b,
+        ArrayRef<OpFoldResult> offsets,
+        ArrayRef<OpFoldResult> sizes
+    ) {
+      return mlir::ttx::getTiledImplementation<$cppClass>(
+        *this, b, offsets, sizes
+      );
+    }
+
+    // Forward each of the implementation to the shared implementation
+    LogicalResult $cppClass::getResultTilePosition(
+        OpBuilder &b,
+        unsigned resultNumber,
+        ArrayRef<OpFoldResult> offsets,
+        ArrayRef<OpFoldResult> sizes,
+        SmallVector<OpFoldResult> &resultOffsets,
+        SmallVector<OpFoldResult> &resultSizes
+    ) {
+      return mlir::ttx::getResultTilePosition<$cppClass>(
+        *this, b, resultNumber, offsets, sizes, resultOffsets, resultSizes
+      );
+    }
+
+    // Forward each of the implementation to the shared implementation
+    FailureOr<TilingResult> $cppClass::generateResultTileValue(
+        OpBuilder &b,
+        unsigned resultNumber,
+        ArrayRef<OpFoldResult> offsets,
+        ArrayRef<OpFoldResult> sizes
+    ) {
+      return mlir::ttx::generateResultTileValue<$cppClass>(
+        *this, b, resultNumber, offsets, sizes
+      );
+    }
+
+    // Implemented as part of MemoryEffectsOpInterface
+    void $cppClass::getEffects(
+        SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+            &effects
+    ) {
+      return mlir::ttx::getEffects<$cppClass>(*this, effects);
+    }
+  }];
+}
+
+def TritonTilingExt_CumSumOp : TritonTilingExt_TilingOp<"cumsum"> {
+  let arguments = (ins
+    Variadic<AnyShapedType>:$inputs,
+    Variadic<AnyShapedType>:$outputs,
+    UI32Attr:$axis
+  );
+
+  let hasVerifier = 1;
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<(ins
+      "Value":$input,
+      "IntegerAttr":$axis,
+      "Value":$output,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes
+    )>
+  ];
+
+  let extraClassDeclaration = baseClassDecls # [{
+    int64_t getRank() {
+      return cast<ShapedType>(getInput().getType()).getRank();
+    }
+
+    Value getInput() {
+      return getInputs()[0];
+    }
+
+    Value getOutput() {
+      return getOutputs()[0];
+    }
+
+    static StringRef getAxisAttrStrName() { return "axis"; }
+  }];
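+
+  // Printed form (illustrative, using the linalg-style assembly defined above):
+  //   %res = ttx.cumsum {axis = 1 : ui32}
+  //          ins(%in : tensor<4x256xf32>) outs(%init : tensor<4x256xf32>)
+  //          -> tensor<4x256xf32>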
+}
+
+#endif // TRITON_TILING_EXT_BASE
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/CMakeLists.txt
new file mode 100644
index 000000000..629c08af6
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Conversion)
+add_subdirectory(Dialect)
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/CMakeLists.txt
new file mode 100644
index 000000000..923f5b7e7
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_subdirectory(MKToTx81)
+add_subdirectory(Tx81ToLLVM)
+add_subdirectory(Tx81MemrefToLLVM)
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/CMakeLists.txt
new file mode 100644
index 000000000..a69d0ceb2
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name MKToTx81)
+add_public_tablegen_target(MKToTx81ConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h
new file mode 100644
index 000000000..279d671e7
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h
@@ -0,0 +1,37 @@
+//===------------------- MKToTx81.h ---------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Lowering magic kernel ops to TsingMicro Tx81 target.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ZTC_CONVERSION_MK_TO_TX81_H
+#define ZTC_CONVERSION_MK_TO_TX81_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc"
+
+void populateMKToTx81CanonicalizationPatterns(
+    RewritePatternSet &patterns);
+
+void populateMKToTx81ConversionPatterns(RewritePatternSet &patterns);
+
+std::unique_ptr<OperationPass<ModuleOp>> createMKToTx81Pass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // ZTC_CONVERSION_MK_TO_TX81_H
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.h
new file mode 100644
index 000000000..c9a8f51c0
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.h
@@ -0,0 +1,22 @@
+//===------------------- Passes.h -----------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MK_TO_TX81_CONVERSION_PASSES_H
+#define MK_TO_TX81_CONVERSION_PASSES_H
+
+#include "tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif // MK_TO_TX81_CONVERSION_PASSES_H
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.td
new file mode 100644
index 000000000..295fc05bd
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/Passes.td
@@ -0,0 +1,18 @@
+//===------------------- Passes.td ----------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+// +//===----------------------------------------------------------------------===// + +#ifndef MK_TO_TX81_CONVERSION_PASSES +#define MK_TO_TX81_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def MKToTx81 : Pass<"mk-to-tx81", "mlir::ModuleOp"> { + let summary = "Convert magic kernel operations into TsingMicro Tx81 operations"; + let constructor = "triton::createMKToTx81Pass()"; +} + +#endif diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/CMakeLists.txt new file mode 100644 index 000000000..fbc6e31df --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name Tx81MemrefToLLVM) +add_public_tablegen_target(Tx81MemrefToLLVMConversionPassIncGen) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h new file mode 100644 index 000000000..54079c976 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h @@ -0,0 +1,22 @@ +//===------------------- Passes.h -----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MEMREF_TO_MK_CONVERSION_PASSES_H +#define MEMREF_TO_MK_CONVERSION_PASSES_H + +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // MEMREF_TO_MK_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.td new file mode 100644 index 000000000..b2a0d2c9b --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.td @@ -0,0 +1,19 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef MEMREF_TO_MK_CONVERSION_PASSES +#define MEMREF_TO_MK_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def Tx81MemrefToLLVM : Pass<"tx81-memref-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert memref and bufferization operations into custom llvm function call."; + let constructor = "triton::createTx81MemrefToLLVMPass()"; + let options = []; +} + +#endif diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h new file mode 100644 index 000000000..957c7e957 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h @@ -0,0 +1,40 @@ +//===------------------- Tx81MemrefToLLVM.h -------------------------*- C++ +//-*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+//
+//===----------------------------------------------------------------------===//
+//
+// Lowering memref.copy, memref.alloc to mk.load, mk.alloc etc.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ZTC_CONVERSION_MEMREF_TO_MK_H
+#define ZTC_CONVERSION_MEMREF_TO_MK_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc"
+
+void populateTx81MemrefToLLVMCanonicalizationPatterns(
+    RewritePatternSet &patterns);
+
+void populateTx81MemrefToLLVMConversionPatterns(RewritePatternSet &patterns,
+                                                LLVMTypeConverter &converter);
+
+std::unique_ptr<OperationPass<ModuleOp>> createTx81MemrefToLLVMPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // ZTC_CONVERSION_MEMREF_TO_MK_H
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt
new file mode 100644
index 000000000..f8257f56b
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name Tx81ToLLVM)
+add_public_tablegen_target(Tx81ToLLVMConversionPassIncGen)
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h
new file mode 100644
index 000000000..d877e0287
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h
@@ -0,0 +1,31 @@
+//===- KernelArgBufferPass.h ----------------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass transforms kernel function signatures by converting multiple
+// arguments into a single void* buffer containing all the arguments.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_KERNEL_ARG_BUFFER_PASS_H
+#define MLIR_KERNEL_ARG_BUFFER_PASS_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+class ModuleOp;
+class Pass;
+
+/// Creates a pass that transforms kernel functions by replacing multiple arguments
+/// with a single void* buffer argument.
+std::unique_ptr<Pass> createKernelArgBufferPass();
+
+#define GEN_PASS_DECL_KERNELARGBUFFERPASS
+#include "KernelArgBufferPass.h.inc"
+} // namespace mlir
+
+#endif // MLIR_KERNEL_ARG_BUFFER_PASS_H
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td
new file mode 100644
index 000000000..0d63527ab
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td
@@ -0,0 +1,32 @@
+//===- KernelArgBufferPass.td ---------------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+// +//===----------------------------------------------------------------------===// + +#ifndef KERNEL_ARG_BUFFER_PASS +#define KERNEL_ARG_BUFFER_PASS + +include "mlir/Pass/PassBase.td" + +def KernelArgBufferPass : Pass<"kernel-arg-buffer", "ModuleOp"> { + let summary = "Convert kernel arguments to a single buffer argument"; + let description = [{ + This pass transforms kernel function signatures by converting multiple + arguments into a single void* buffer containing all the arguments. + + For example, a function like: + add_kernel(uint64_t* arg1, uint64_t* arg2, int64_t size, int gridX, int x) + + Will be converted to: + add_kernel(void* args) + + Where the args buffer contains pointers to arg1 and arg2, followed by the scalar + values size, gridX, and x. Each scalar value occupies 8 bytes in the buffer. + }]; + let constructor = "mlir::createKernelArgBufferPass()"; + let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::func::FuncDialect"]; +} + +#endif // KERNEL_ARG_BUFFER_PASS diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h new file mode 100644 index 000000000..f0f0138b8 --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h @@ -0,0 +1,22 @@ +//===------------------- Passes.h -----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef TX81_TO_LLVM_CONVERSION_PASSES_H +#define TX81_TO_LLVM_CONVERSION_PASSES_H + +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TX81_TO_LLVM_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.td new file mode 100644 index 000000000..2ed0159cc --- /dev/null +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.td @@ -0,0 +1,38 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef TX81_TO_LLVM_CONVERSION_PASSES +#define TX81_TO_LLVM_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + + +def Tx81ToLLVM : Pass<"tx81-to-llvm", "ModuleOp"> { + let summary = "Convert Tx81 dialect to LLVM dialect"; + let description = [{ + This pass converts operations in the Tx81 dialect to the LLVM IR dialect. + + It handles the conversion of Tx81-specific operations like tx.rdma, tx.wdma, + tx.gemm etc to appropriate LLVM calls to the Tx81 runtime library. + + The pass also relies on existing conversion patterns for standard dialects + like arith, func, memref, etc. 
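+
+    For example (illustrative only; the exact runtime symbols are defined in
+    crt/Target/Tx81), a tx.rdma data-movement op becomes an llvm.call into the
+    Tx81 runtime library rather than inline LLVM IR.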
+  }];
+
+  let constructor = "triton::createTx81ToLLVMPass()";
+
+  let dependentDialects = [
+    "mlir::LLVM::LLVMDialect",
+    "mlir::arith::ArithDialect",
+    "mlir::func::FuncDialect",
+    "mlir::memref::MemRefDialect",
+    "mlir::scf::SCFDialect",
+    "tx::Tx81Dialect"
+  ];
+}
+
+#endif
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h
new file mode 100644
index 000000000..9d3ac7ffc
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h
@@ -0,0 +1,33 @@
+//===------------------- Tx81ToLLVM.h -------------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_CONVERSION_TX81_TO_LLVM_H
+#define TRITON_CONVERSION_TX81_TO_LLVM_H
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc"
+
+void populateTx81ToLLVMConversionPatterns(RewritePatternSet &patterns,
+                                          ConversionTarget &target,
+                                          LLVMTypeConverter &converter);
+
+std::unique_ptr<OperationPass<ModuleOp>> createTx81ToLLVMPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_TX81_TO_LLVM_H
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/CMakeLists.txt
new file mode 100644
index 000000000..f33061b2d
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/CMakeLists.txt
new file mode 100644
index 000000000..6b74f8677
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+set(LLVM_TARGET_DEFINITIONS Tx81Ops.td)
+mlir_tablegen(Tx81Dialect.h.inc -gen-dialect-decls -dialect=tx)
+mlir_tablegen(Tx81Dialect.cpp.inc -gen-dialect-defs -dialect=tx)
+mlir_tablegen(Tx81Ops.h.inc -gen-op-decls)
+mlir_tablegen(Tx81Ops.cpp.inc -gen-op-defs)
+
+mlir_tablegen(Tx81Enums.h.inc -gen-enum-decls)
+mlir_tablegen(Tx81Enums.cpp.inc -gen-enum-defs)
+
+set(LLVM_TARGET_DEFINITIONS Tx81Types.td)
+mlir_tablegen(Tx81Types.h.inc -gen-typedef-decls)
+mlir_tablegen(Tx81Types.cpp.inc -gen-typedef-defs)
+
+add_public_tablegen_target(Tx81TableGen)
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td
new file mode 100644
index 000000000..c69d3540d
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td
@@ -0,0 +1,24 @@
+//===---------------------- Tx81AttrDefs.td -------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TSINGMICRO_TX81_ATTR_DEFS
+#define TSINGMICRO_TX81_ATTR_DEFS
+
+include "mlir/IR/EnumAttr.td"
+
+// Round mode, aligned with RND_MODE in instr_def.h
+def RoundModeAttr : I32EnumAttr<"RoundMode", "Round mode", [
+  I32EnumAttrCase<"RND_NEAREST_EVEN", 0, "nearest">,
+  I32EnumAttrCase<"RND_ZERO", 1, "zero">,
+  I32EnumAttrCase<"RND_POS_INF", 2, "pos">,
+  I32EnumAttrCase<"RND_NEG_INF", 3, "neg">,
+  I32EnumAttrCase<"RND_STOCHASTIC", 4, "stochastic">
+]> {
+  let cppNamespace = "::mlir::tx";
+}
+
+#endif // TSINGMICRO_TX81_ATTR_DEFS
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h
new file mode 100644
index 000000000..2fd0c9f34
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h
@@ -0,0 +1,33 @@
+//===-------------------------- Tx81Dialect.h -----------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H
+#define MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+//===----------------------------------------------------------------------===//
+// TsingMicro Tx81 Operations
+//===----------------------------------------------------------------------===//
+#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h.inc"
+
+// Include the auto-generated header file containing the declarations of the
+// Tx81 operations.
+#define GET_OP_CLASSES
+#include "tsingmicro-tx81/Dialect/IR/Tx81Enums.h.inc"
+#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h.inc"
+
+#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.td
new file mode 100644
index 000000000..172d2a6ee
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.td
@@ -0,0 +1,43 @@
+//===----------------------- Tx81Dialect.td -------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TSINGMICRO_TX81_DIALECT
+#define TSINGMICRO_TX81_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def Tx81Dialect : Dialect {
+  let name = "tx";
+
+  let cppNamespace = "::mlir::tx";
+
+  let summary = "The TsingMicro Tx81 IR in MLIR";
+
+  let description = [{
+    TsingMicro Tx81 Dialect.
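+
+    Ops in this dialect use the `tx.` mnemonic prefix (for example `tx.rdma`,
+    `tx.wdma`, and `tx.gemm`) and model the intrinsics of the Tx81 ML
+    accelerator runtime.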
+
+    Dependent Dialects:
+    * MK
+    * Memref
+    * Bufferization
+  }];
+
+  let dependentDialects = [
+  ];
+
+  let extraClassDeclaration = [{
+    void registerTypes();
+  }];
+
+  // let hasConstantMaterializer = 1;
+  // let useDefaultTypePrinterParser = 1;
+  let usePropertiesForAttributes = 1;
+}
+
+include "tsingmicro-tx81/Dialect/IR/Tx81Types.td"
+
+#endif // TSINGMICRO_TX81_DIALECT
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h
new file mode 100644
index 000000000..ca9ba4bf9
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h
@@ -0,0 +1,26 @@
+//===-------------------------- Tx81Ops.h ---------------------*- C++ -*---===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TSINGMICRO_TX81_IR_OPS_H
+#define MLIR_DIALECT_TSINGMICRO_TX81_IR_OPS_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#define GET_OP_CLASSES
+#include "tsingmicro-tx81/Dialect/IR/Tx81Enums.h.inc"
+#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h.inc"
+
+#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_OPS_H
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
new file mode 100644
index 000000000..513baf79f
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
@@ -0,0 +1,775 @@
+
+//===---------------------- Tx81Ops.td ------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Definition of TsingMicro's Tx81 ML accelerator operations.
+//
+// Data formats supported by the Tx81 ML accelerator are:
+//   bf16, fp16, tf32, fp32
+//
+// For data types the Tx81 accelerator does not support, we can either convert
+// them using `TsmConvert`, or lower the operations to run on the RISC-V
+// controller instead.
+//
+// NOTE: CHANGING THE ARGUMENTS OR RESULTS OF ANY OP RESULTS IN A CHANGE OF
+// ITS RUNTIME INTERFACE AND IMPLEMENTATION IN crt/Target/Tx81.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TSINGMICRO_TX81_OPS
+#define TSINGMICRO_TX81_OPS
+
+include "tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td"
+include "tsingmicro-tx81/Dialect/IR/Tx81Types.td"
+include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
+include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
+include "mlir/IR/OpBase.td"
+
+//
+// Interfaces
+//
+def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
+
+class Tx81Op<string mnemonic, list<Trait> traits = []> :
+    Op<Tx81Dialect, mnemonic, traits> {
+}
+
+def MemRefOrInt
+    : AnyTypeOf<[AnyMemRef, AnySignlessIntegerOrIndex],
+                "MemRef or Int as address type.", "::mlir::Type">;
+
+// =============================================================================
+// 4.8/4.9 DDR and SPM transfer ops
+// =============================================================================
+
+def RdmaOp : Tx81Op<"rdma", [
+    AttrSizedOperandSegments,
+    PredOpTrait<"Constrain shape to 4d.",
+                CPred<"cast<RdmaOp>($_op).getShape().size() == 4">>,
+    PredOpTrait<"Constrain strides to 3d.",
+                CPred<"cast<RdmaOp>($_op).getStrides().size() == 3">>
+  ]> {
+
+  let summary = "Copy data from global memory DDR (DRAM) to per-thread local SPM (SRAM)";
+
+  let description = [{
+    Copy data from global memory DDR (DRAM) to per-thread local SPM (SRAM).
+  }];
+
+  let arguments = (
+    ins
+    MemRefOrInt:$source,      // The source address in DDR
+    MemRefOrInt:$target,      // The target address in SPM
+    Variadic<Index>:$shape,   // NHWC shape
+    Variadic<Index>:$strides, // 3 dim strides
+    I32Attr:$fmt
+  );
+
+  let builders = [
+    OpBuilder<(ins "MemRefType":$resultType,
+      "Value":$source, "Value":$target,
+      "ArrayRef<Value>":$shape,
+      "ArrayRef<Value>":$strides,
+      "IntegerAttr":$fmt
+    )>
+  ];
+
+  let results = (outs I64:$dst); // The dest address in SPM
+}
+
+def WdmaOp : Tx81Op<"wdma", [
+    AttrSizedOperandSegments,
+    PredOpTrait<"Constrain shape to 4d.",
+                CPred<"cast<WdmaOp>($_op).getShape().size() == 4">>,
+    PredOpTrait<"Constrain strides to 3d.",
+                CPred<"cast<WdmaOp>($_op).getStrides().size() == 3">>
+  ]> {
+  let summary = "Copy data from per-thread local SPM (SRAM) to global memory DDR (DRAM)";
+
+  let description = [{
+    Copy data from per-thread local SPM (SRAM) to global memory DDR (DRAM).
+  }];
+
+  let arguments = (
+    ins
+    MemRefOrInt:$source,      // The source address in SPM
+    MemRefOrInt:$target,      // The target address in DDR
+    Variadic<Index>:$shape,   // NHWC shape
+    Variadic<Index>:$strides, // 3 dim strides
+    I32Attr:$fmt
+  );
+
+  let builders = [
+    OpBuilder<(ins "MemRefType":$resultType,
+      "Value":$source, "Value":$target,
+      "ArrayRef<Value>":$shape,
+      "ArrayRef<Value>":$strides,
+      "IntegerAttr":$fmt
+    )>
+  ];
+
+  let results = (outs I64:$dst); // The dest address in DDR
+}
+
+// =============================================================================
+// 4.4~6 TsmConv, TsmDepthwiseConv, TsmBackwardConv
+// =============================================================================
+
+def ConvOp : Tx81Op<"conv", [Pure]> {
+  let summary = "Convolution engine intrinsic runtime API";
+
+  let description = [{
+    A common convolution op for TsmConv, TsmDepthwiseConv, and TsmBackwardConv.
+    This op is not a 1-to-1 mapping to TsingMicro's TsmConv intrinsic; it
+    wraps all APIs related to TsmConv:
+    TsmNewConv, TsmDeleteConv, AddInput, AddWeight, AddBias, AddOutput,
+    SetOpType, SetNegativeAxisScale, SetPositiveAxisScale, SetSparse, SetPsum,
+    SetPads, SetUnPads, SetKernelStrides, SetDilations, EnableRelu,
+    EnableLeakyRelu, DisableRelu, DisableLeakyRelu, SetQuant.
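+
+    As a sketch of the intended runtime call sequence (the API names are the
+    wrapped functions listed above; the handle variable and the elided
+    argument lists are illustrative assumptions, not the exact runtime
+    signatures):
+
+      handle = TsmNewConv(...);
+      AddInput(handle, ...); AddWeight(handle, ...); AddOutput(handle, ...);
+      // optionally: AddBias, SetPads, SetKernelStrides, SetDilations, ...
+      TsmDeleteConv(handle);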
+  }];
+
+  let arguments = (
+    ins
+    I64Attr:$op_type,            // 0: conv, 1: depthwise conv, 2: backward conv,
+                                 // 3: gemm
+    MemRefOrInt:$src_activation, // Input activation addr in SPM
+    I32ArrayAttr:$src_dims,      // dims of src activation in NHWC format
+    MemRefOrInt:$weight,         // Input weight addr in SPM
+    I16Attr:$weight_dims,        // dims of weight (conv kernel) in Kx, Ky, Sx, Sy,
+                                 // where K and S are short for size (K) and step (S)
+    BoolAttr:$en_bias,           // Enable bias add
+    MemRefOrInt:$src_bias,       // The address of bias in SPM
+    BoolAttr:$en_neg_scale,      // Enable negative axis scale
+    MemRefOrInt:$src_neg_scale,  // The address of negative scale data in SPM
+    BoolAttr:$en_pos_scale,      // Enable positive axis scale
+    MemRefOrInt:$src_pos_scale,  // The address of positive scale data in SPM
+    BoolAttr:$en_sparse,         // Enable sparse
+    MemRefOrInt:$src_sparse,     // The sparse matrix addr in SPM
+    BoolAttr:$en_psum,           // Enable psum (partial sum)
+    MemRefOrInt:$src_psum,       // psum addr in SPM
+    I32ArrayAttr:$pads,          // Pad in top, bottom, left, right order
+    I32ArrayAttr:$unpads,        // Unpad in top, bottom, left, right order
+    I32ArrayAttr:$strides,       // Kernel strides in Kx, Ky, Sx, Sy
+    I32ArrayAttr:$dilations,     // dilation d0, d1 for conv/backwardconv
+    BoolAttr:$en_leaky_relu,     // Enable LeakyRelu or normal Relu
+    I32ArrayAttr:$out_dims,      // dims of output in NHWC format
+    I64Attr:$src_fmt,            // Data format of src activation
+    I64Attr:$weight_fmt,         // Data format of weight
+    I64Attr:$out_fmt             // Data format of output
+    // The param of SetQuant() is unused
+  );
+
+  // Output tensor addr in SPM
+  let results = (outs I64:$dst);
+}
+
+// =============================================================================
+// 4.7. TsmGemm
+// =============================================================================
+
+def GemmOp : Tx81Op<"gemm", []> {
+  let summary = "Gemm engine intrinsic runtime API";
+
+  let description = [{
+    This op is not a 1-to-1 mapping to TsingMicro's TsmGemm intrinsic; it
+    wraps all APIs related to TsmGemm:
+    TsmNewGemm, TsmDeleteGemm, AddInput, ConfigMKN, AddOutput, SetPsum,
+    SetTransflag, SetQuant, ConfigBatch, EnableRelu, EnableLeakyRelu,
+    DisableRelu, DisableLeakyRelu, AddBias, SetNegativeAxisScale,
+    SetPositiveAxisScale.
+  }];
+
+  let arguments = (
+    ins
+    MemRefOrInt:$src_a,     // Input matrix A addr in SPM
+    MemRefOrInt:$src_b,     // Input matrix B addr in SPM
+    MemRefOrInt:$src_bias,  // The address of bias in SPM
+    // Zeroes buffer which can be used to fill $dst
+    // FIXME: Whether need add side effect to source operands?
+    Arg<MemRefOrInt, "", [MemRead]>:$zeroes,
+    I32ArrayAttr:$dims,     // The dimensions of M, K, N
+    BoolAttr:$en_psum,      // Enable psum (partial sum)
+    MemRefOrInt:$psum_addr, // The address of psum in SPM
+    BoolAttr:$trans_src_a,   // Should matrix A be transposed
+    BoolAttr:$trans_src_b,   // Should matrix B be transposed
+    I32Attr:$batch_src_a,    // The batch of matrix A
+    I32Attr:$batch_src_b,    // The batch of matrix B
+    BoolAttr:$en_leaky_relu, // Enable LeakyRelu or normal Relu
+    BoolAttr:$en_bias,       // Enable bias add
+    BoolAttr:$en_neg_scale,  // Enable negative axis scale
+    MemRefOrInt:$src_neg_scale, // The address of negative scale data in SPM
+    BoolAttr:$en_pos_scale,  // Enable positive axis scale
+    MemRefOrInt:$src_pos_scale, // The address of positive scale data in SPM
+    I32Attr:$src_fmt,        // Input matrix data format
+    I32Attr:$dst_fmt         // Output matrix data format
+    // The param of SetQuant() is unused
+  );
+
+  // Output matrix C addr in SPM
+  let results = (outs Variadic<I64>:$dst);
+}
+
+// =============================================================================
+// 4.10. TsmArith
+// =============================================================================
+
+def AbsVVOp : Tx81Op<"absvv", [Pure, Elementwise]> {}
+def RecipVVOp : Tx81Op<"recipvv", [Pure, Elementwise]> {}
+def SquareVVOp : Tx81Op<"squarevv", [Pure, Elementwise]> {}
+
+class BinaryVVOp<string mnemonic, list<Trait> traits = []> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$input0,     // First input vector address
+    MemRefOrInt:$input1,     // Second vector address
+    Arg<MemRefOrInt, "", [MemWrite]>:$out, // Out vector address
+    MemRefOrInt:$elem_count, // Number of input elements
+    I16Attr:$rnd_mode,       // round mode
+    I16Attr:$fmt             // The data format of src & dst
+  );
+  let results = (outs Variadic<I64>:$dst);
+}
+
+def AddVVOp : BinaryVVOp<"addvv"> {
+  let summary = "Add two vectors element-wise";
+}
+def SubVVOp : BinaryVVOp<"subvv">;
+def MulVVOp : BinaryVVOp<"mulvv">;
+def DivVVOp : BinaryVVOp<"divvv">;
+
+class BinaryVSOp<string mnemonic, list<Trait> traits = []> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$input0,     // First input vector address
+    I32:$value,              // Const value
+    Arg<MemRefOrInt, "", [MemWrite]>:$out, // Out vector address
+    MemRefOrInt:$elem_count, // Number of input elements
+    I16Attr:$rnd_mode,       // round mode
+    I16Attr:$fmt             // The data format of src & dst
+  );
+  let results = (outs Variadic<I64>:$dst);
+}
+
+def AddVSOp : BinaryVSOp<"addvs"> {
+  let summary = "Add input vector and constant value";
+}
+def SubVSOp : BinaryVSOp<"subvs">;
+def MulVSOp : BinaryVSOp<"mulvs">;
+def DivVSOp : BinaryVSOp<"divvs">;
+
+// ...
+
+// =============================================================================
+// 4.13. TsmTranscendental
+// =============================================================================
+
+class TranscendentalOp<string mnemonic, list<Trait> traits> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$src,    // Input vector address
+    I32Attr:$elem_count, // Number of input elements
+    I16Attr:$fmt         // The data format of src & dst
+  );
+  let results = (outs I64:$dst);
+}
+
+def Log2 : TranscendentalOp<"log2", []> {
+  let summary = "Base-2 logarithm";
+}
+def Ln : TranscendentalOp<"ln", []> {
+  let summary = "Natural logarithm (base e)";
+}
+def Pow2 : TranscendentalOp<"pow2", []> {
+  let summary = "2 ** x";
+}
+def Exp : TranscendentalOp<"exp", []> {
+  let summary = "Exponential with high precision";
+}
+def Explp : TranscendentalOp<"explp", []> {
+  let summary = "Exponential with low precision";
+}
+def Sin : TranscendentalOp<"sin", []> {
+  let summary = "Sine";
+}
+def Cos : TranscendentalOp<"cos", []> {
+  let summary = "Cosine";
+}
+
+// =============================================================================
+// 4.13. TsmActivation
+// =============================================================================
+
+class ActivationOp<string mnemonic, list<Trait> traits> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$src,     // Input vector address
+    UI32Attr:$elem_count, // Number of input elements
+    UI16Attr:$fmt         // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Tanh : ActivationOp<"tanh", []> {
+  let summary = "Hyperbolic tangent";
+}
+def Sigmoid : ActivationOp<"sigmoid", []> {
+  let summary = "Logistic sigmoid";
+}
+def Relu : ActivationOp<"relu", []> {
+  let summary = "Rectified linear unit";
+}
+def Satrelu : ActivationOp<"satrelu", []> {
+  let summary = "Saturated ReLU";
+}
+def Leakyrelu : ActivationOp<"leakyrelu", []> {
+  let summary = "Leaky rectified linear unit";
+}
+def Softplus : ActivationOp<"softplus", []> {
+  let summary = "Smooth approximation of ReLU";
+}
+
+// =============================================================================
+// 4.15. TsmReduce
+// =============================================================================
+
+class Reduce<string mnemonic> : Tx81Op<mnemonic> {
+  let summary = "Reduction engine intrinsic runtime API";
+
+  let description = [{
+    Includes ReduceSum, ReduceAvg, ReduceMin and ReduceMax interfaces.
+    Mapping between `dim` and NCHW:
+      Reduction on C:  dim=0
+      Reduction on W:  dim=1
+      Reduction on H:  dim=2
+      Reduction on HW: dim=4
+  }];
+
+  let arguments = (
+    ins
+    AnyType:$src,        // Input tensor address in SPM
+    Arg<MemRefOrInt, "", [MemWrite]>:$dst, // Output tensor address in SPM
+    UI32Attr:$dim,       // Which dimension to be reduced
+    I64ArrayAttr:$shape, // The shape info of src
+    I16Attr:$fmt         // The data format of src & dst
+  );
+
+  // Output tensor address in SPM
+  let results = (outs Variadic<I64>);
+}
+
+def ReduceSumOp : Reduce<"reduce_sum">;
+def ReduceAvgOp : Reduce<"reduce_avg">;
+def ReduceMaxOp : Reduce<"reduce_max">;
+def ReduceMinOp : Reduce<"reduce_min">;
+
+// =============================================================================
+// 4.15. TsmMaskDataMove
+// =============================================================================
+
+def MaskMoveOp : Tx81Op<"mask_move", []> {
+  let summary = "Mask data move engine intrinsic runtime API";
+
+  let description = [{
+    When mask is 1, extract the data from src and write it to dst.
+    When mask is 0, the corresponding elements of dst remain unchanged.
+  }];
+
+  let arguments = (
+    ins
+    MemRefOrInt:$source, // The source address in SPM
+    // The target address in SPM
+    Arg<MemRefOrInt, "", [MemWrite]>:$target,
+    AnySignlessIntegerOrIndex:$elem_count, // Number of elements to be copied
+    I32ArrayAttr:$mask, // 3 dim masks
+    I32Attr:$fmt
+  );
+
+  // The dst address is not used, use target in arguments instead.
+  let results = (outs Variadic<I64>:$dst);
+}
+
+// =============================================================================
+// 4.19. TsmConvert instructions
+// =============================================================================
+
+class ZeroPointConvertOp<string mnemonic, list<Trait> traits> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$src,
+    UI32Attr:$zp,
+    UI32Attr:$elem_count
+  );
+  let results = (outs UI64:$dst);
+}
+
+def INT8ToFP16Op : ZeroPointConvertOp<"int8_fp16", []> {
+  let summary = "Data format from int8 to fp16";
+}
+def INT8ToBF16Op : ZeroPointConvertOp<"int8_bf16", []> {
+  let summary = "Data format from int8 to bf16";
+}
+def INT8ToFP32Op : ZeroPointConvertOp<"int8_fp32", []> {
+  let summary = "Data format from int8 to fp32";
+}
+def INT8ToTF32Op : ZeroPointConvertOp<"int8_tf32", []> {
+  let summary = "Data format from int8 to tf32";
+}
+
+class RoundConvertOp<string mnemonic, list<Trait> traits> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$input,
+    Arg<MemRefOrInt, "", [MemWrite]>:$output,
+    AnySignlessIntegerOrIndex:$elem_count,
+    I16Attr:$rnd_mode
+  );
+  let results = (outs I64:$dst);
+}
+
+def INT16ToBF16Op : RoundConvertOp<"int16_bf16", []> {
+  let summary = "Data format from int16 to bf16";
+}
+def INT16ToFP32Op : RoundConvertOp<"int16_fp32", []> {
+  let summary = "Data format from int16 to fp32";
+}
+def INT16ToTF32Op : RoundConvertOp<"int16_tf32", []> {
+  let summary = "Data format from int16 to tf32";
+}
+def INT32ToFP16Op : RoundConvertOp<"int32_fp16", []> {
+  let summary = "Data format from int32 to fp16";
+}
+def INT32ToBF16Op : RoundConvertOp<"int32_bf16", []> {
+  let summary = "Data format from int32 to bf16";
+}
+def INT32ToFP32Op : RoundConvertOp<"int32_fp32", []> {
+  let summary = "Data format from int32 to fp32";
+}
+def INT32ToTF32Op : RoundConvertOp<"int32_tf32", []> {
+  let summary = "Data format from int32 to tf32";
+}
+def BF16ToINT16Op : RoundConvertOp<"bf16_int16", []> {
+  let summary = "Data format from bf16 to int16";
+}
+def BF16ToINT32Op : RoundConvertOp<"bf16_int32", []> {
+  let summary = "Data format from bf16 to int32";
+}
+def FP16ToINT8Op : RoundConvertOp<"fp16_int8", []> {
+  let summary = "Data format from fp16 to int8";
+}
+def FP16ToINT16Op : RoundConvertOp<"fp16_int16", []> {
+  let summary = "Data format from fp16 to int16";
+}
+def FP16ToINT32Op : RoundConvertOp<"fp16_int32", []> {
+  let summary = "Data format from fp16 to int32";
+}
+def FP16ToBF16Op : RoundConvertOp<"fp16_bf16", []> {
+  let summary = "Data format from fp16 to bf16";
+}
+def FP32ToINT8Op : RoundConvertOp<"fp32_int8", []> {
+  let summary = "Data format from fp32 to int8";
+}
+def FP32ToINT16Op : RoundConvertOp<"fp32_int16", []> {
+  let summary = "Data format from fp32 to int16";
+}
+def FP32ToINT32Op : RoundConvertOp<"fp32_int32", []> {
+  let summary = "Data format from fp32 to int32";
+}
+def FP32ToFP16Op : RoundConvertOp<"fp32_fp16", []> {
+  let summary = "Data format from fp32 to fp16";
+}
+def FP32ToBF16Op : RoundConvertOp<"fp32_bf16", []> {
+  let summary = "Data format from fp32 to bf16";
+}
+def FP32ToTF32Op : RoundConvertOp<"fp32_tf32", []> {
+  let summary = "Data format from fp32 to tf32";
+}
+def TF32ToINT8Op : RoundConvertOp<"tf32_int8", []> {
+  let summary = "Data format from tf32 to int8";
+}
+def TF32ToINT16Op : RoundConvertOp<"tf32_int16", []> {
+  let summary = "Data format from tf32 to int16";
+}
+def TF32ToINT32Op : RoundConvertOp<"tf32_int32", []> {
+  let summary = "Data format from tf32 to int32";
+}
+def TF32ToFP32Op : RoundConvertOp<"tf32_fp32", []> {
+  let summary = "Data format from tf32 to fp32";
+}
+
+class NormalConvertOp<string mnemonic, list<Trait> traits> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$input,
+    Arg<MemRefOrInt, "", [MemWrite]>:$output,
+    AnySignlessIntegerOrIndex:$elem_count
+  );
+  let results = (outs I64:$dst);
+}
+
+def INT16ToFP16Op : NormalConvertOp<"int16_fp16", []> {
+  let summary = "Data format from int16 to fp16";
+}
+def BF16ToINT8Op : NormalConvertOp<"bf16_int8", []> {
+  let summary = "Data format from bf16 to int8";
+}
+def BF16ToFP16Op : NormalConvertOp<"bf16_fp16", []> {
+  let summary = "Data format from bf16 to fp16";
+}
+def BF16ToFP32Op : NormalConvertOp<"bf16_fp32", []> {
+  let summary = "Data format from bf16 to fp32";
+}
+def BF16ToTF32Op : NormalConvertOp<"bf16_tf32", []> {
+  let summary = "Data format from bf16 to tf32";
+}
+def FP16ToFP32Op : NormalConvertOp<"fp16_fp32", []> {
+  let summary = "Data format from fp16 to fp32";
+}
+def FP16ToTF32Op : NormalConvertOp<"fp16_tf32", []> {
+  let summary = "Data format from fp16 to tf32";
+}
+def TF32ToFP16Op : NormalConvertOp<"tf32_fp16", []> {
+  let summary = "Data format from tf32 to fp16";
+}
+def TF32ToBF16Op : NormalConvertOp<"tf32_bf16", []> {
+  let summary = "Data format from tf32 to bf16";
+}
+
+// =============================================================================
+// 4.20. TsmPeripheral instructions
+// =============================================================================
+
+def CountOp : Tx81Op<"count", [Pure]> {
+  let summary = "Count the non-zero elements of a given tensor";
+
+  let arguments = (
+    ins
+    MemRefOrInt:$src,    // Input tensor address in SPM
+    I32Attr:$elem_count, // TODO: Ask TsingMicro for an explanation.
+    //I64Attr:$p_wb_data0, // TODO: Ask TsingMicro for an explanation.
+    //I64Attr:$p_wb_data1, // TODO: Ask TsingMicro for an explanation.
+    I16Attr:$fmt
+  );
+
+  // The output tensor address in SPM
+  let results = (outs MemRefOrInt:$dst);
+}
+
+def MemsetOp : Tx81Op<"memset", []> {
+  let summary = "Write the given `value` to a range of addresses in SPM (SRAM)";
+
+  let arguments = (
+    ins
+    MemRefOrInt:$src, // SPM address to be memset
+    I32:$value,       // Value to be written
+    AnySignlessIntegerOrIndex:$elem_count,
+    I32ArrayAttr:$strides,
+    I32ArrayAttr:$iterations,
+    I16Attr:$fmt
+  );
+
+  // The address updated by memset in SPM
+  let results = (outs MemRefOrInt:$dst);
+}
+
+def Bit2FpOp : Tx81Op<"bit2fp", []> {
+  let summary = "Convert a vector of bits into a floating-point vector";
+
+  let arguments = (ins
+    UI64:$src,           // Input tensor
+    I32Attr:$elem_count, // Number of input elements
+    I16Attr:$fmt         // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def ArgMaxOp : Tx81Op<"argmax", []> {
+  let summary = "Return the max value within a vector and its corresponding index";
+
+  let arguments = (ins
+    UI64:$src,           // Input vector
+    I32Attr:$elem_count, // Number of input elements
+    I16Attr:$fmt         // The data format of src & dst
+  );
+
+  let results = (outs
+    AnyType:$max, // Max value within the vector
+    UI64:$index   // Corresponding index
+  );
+}
+
+def ArgMinOp : Tx81Op<"argmin", []> {
+  let summary = "Return the min value within a vector and its corresponding index";
+
+  let arguments = (ins
+    UI64:$src,           // Input vector
+    I32Attr:$elem_count, // Number of input elements
+    I16Attr:$fmt         // The data format of src & dst
+  );
+
+  let results = (outs
+    AnyType:$min, // Min value within the vector
+    UI64:$index   // Corresponding index
+  );
+}
+
+def BilinearOp : Tx81Op<"bilinear", []> {
+  let summary = "Bilinear interpolation";
+
+  let arguments = (ins
+    UI64:$src,               // Input tensor with the NHWC format
+    I32ArrayAttr:$src_shape, // Input tensor shape
+    I32ArrayAttr:$dst_shape, // Output tensor shape
+    F32:$scale_w,            // Input tensor "w" divided by output tensor "w"
+    F32:$scale_h,            // Input tensor "h" divided by output tensor "h"
+    I16Attr:$fmt             // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Lut16Op : Tx81Op<"lut16", []> {
+  let summary = "16-bit lookup table";
+
+  let arguments = (ins
+    // FIXME: AnyVector is not defined
+    // AnyVector:$src, // Vector offset with respect to LUT
+    UI64:$lut16,
+    I32Attr:$src_elem_count, // Number of elements in vector offset
+    I32Attr:$lut_elem_count  // Number of elements in LUT
+  );
+
+  let results = (outs UI64:$dst);
+}
+
+def Lut32Op : Tx81Op<"lut32", []> {
+  let summary = "32-bit lookup table";
+
+  let arguments = (ins
+    // FIXME: AnyVector is not defined
+    // AnyVector:$src, // Vector offset with respect to LUT
+    UI64:$lut32,
+    I32Attr:$src_elem_count, // Number of elements in vector offset
+    I32Attr:$lut_elem_count  // Number of elements in LUT
+  );
+
+  let results = (outs UI64:$dst);
+}
+
+def RandGenOp : Tx81Op<"randgen", []> {
+  let summary = "Generate random numbers using two 64-bit seeds";
+
+  let arguments = (ins
+    UI64:$src0,        // The first random seed
+    UI64:$src1,        // The second random seed
+    UI64:$dst0,        // Store the first random seed
+    UI64:$dst1,        // Store the second random seed
+    UI64:$dst2,        // Random value
+    I32Attr:$elem_num, // Number of random values
+    I16Attr:$fmt       // The data format of random value
+  );
+}
+
+
+//
+// 4.21. TsmDataMove
+//
+
+class TransformOp<string mnemonic, list<Trait> traits = []> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    UI64:$src,               // Input matrix or tensor address
+    I32ArrayAttr:$src_shape, // Input shape
+    I32ArrayAttr:$dst_shape, // Output shape
+    I16Attr:$fmt             // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Mirror : TransformOp<"mirror", []> {
+  let summary = "Horizontally mirror a matrix";
+}
+def Transpose : TransformOp<"transpose", []> {
+  let summary = "Transpose a matrix";
+}
+def Rotate90 : TransformOp<"rotate90", []> {
+  let summary = "Rotate a matrix 90 degrees clockwise";
+}
+def Rotate180 : TransformOp<"rotate180", []> {
+  let summary = "Rotate a matrix 180 degrees clockwise";
+}
+def Rotate270 : TransformOp<"rotate270", []> {
+  let summary = "Rotate a matrix 270 degrees clockwise";
+}
+def Nchw2nhwc : TransformOp<"nchw2nhwc", []> {
+  let summary = "Transform a tensor from nchw to nhwc";
+}
+def Nhwc2nchw : TransformOp<"nhwc2nchw", []> {
+  let summary = "Transform a tensor from nhwc to nchw";
+}
+def TensorNorm : TransformOp<"tensornorm", []> {
+  let summary = "Align a contiguous tensor to the standard format in the channel direction";
+}
+
+def Concat : Tx81Op<"concat", []> {
+  let summary = "Concatenation based on the dim";
+
+  let arguments = (ins
+    UI64:$src1,               // The first input tensor
+    I32ArrayAttr:$src1_shape, // The first input tensor shape
+    UI64:$src2,               // The second input tensor
+    I32ArrayAttr:$src2_shape, // The second input tensor shape
+    I32ArrayAttr:$dst_shape,  // Output tensor shape
+    I16Attr:$dim, // Represent the concat direction, such as:
+                  // 0 is channel, 1 is width, and 2 is height
+    I16Attr:$fmt  // The data format of input & output tensor
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Pad : Tx81Op<"pad", []> {
+  let summary = "Tensor padding";
+
+  let arguments = (ins
+    UI64:$src,               // Input tensor
+    I32ArrayAttr:$src_shape, // Input tensor shape
+    I32ArrayAttr:$dst_shape, // Output tensor shape
+    I16Attr:$pad,            // Padding mode: top, bottom, left, and right
+    I16Attr:$fmt             // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def Img2col : Tx81Op<"img2col", []> {
+  let summary = "Transform a feature-map tensor into a matrix";
+
+  let arguments = (ins
+    UI64:$src,               // Input tensor
+    I32ArrayAttr:$src_shape, // Input tensor shape
+    I32ArrayAttr:$dst_shape, // Output tensor shape
+    I32Attr:$src_elem_num,   // Number of elements in input tensor
+    I32Attr:$dst_elem_num,   // Number of elements in output tensor
+    I32ArrayAttr:$swr,       // Horizontal stride of convolution
+    I32ArrayAttr:$pdr,       // Vertical stride of convolution
+    I16Attr:$fmt             // The data format of src & dst
+  );
+  let results = (outs UI64:$dst);
+}
+
+def GatherScatter : Tx81Op<"gatherscatter", []> {
+  let summary = "Transfer data in strides and iterations";
+
+  let arguments = (ins
+    UI64:$src,     // Input tensor
+    I32Attr:$size, // Transfer data size in bytes
+    I32ArrayAttr:$src_strides,    // 3 dim strides for input
+    I32ArrayAttr:$src_iterations, // 3 dim iterations for input
+    I32ArrayAttr:$dst_strides,    // 3 dim strides for output
+    I32ArrayAttr:$dst_iterations  // 3 dim iterations for output
+  );
+  let results = (outs UI64:$dst);
+}
+
+#endif // TSINGMICRO_TX81_OPS
\ No newline at end of file
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td
new file mode 100644
index 000000000..e4c98254e
--- /dev/null
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td
@@ -0,0 +1,107 @@
+//===-------------------------- Tx81Types.td ------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// TODO: Update this file to define the customized types used by the Tx81
+// dialect; it is currently copy-and-pasted from MagicKernelTypes.td.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TSINGMICRO_TX81_TYPES_TD
+#define TSINGMICRO_TX81_TYPES_TD
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.td"
+
+//
+// Types
+//
+class MKTypeDef<string name, string _mnemonic, list<Trait> traits = []>
+    : TypeDef<Tx81Dialect, name, traits> {
+  // Used by printer/parser
+  let mnemonic = _mnemonic;
+}
+
+// Floating-point Type
+def MKFloat : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
+def MKFloatTensor : RankedTensorOf<[MKFloat]>;
+def MKFloatLike : AnyTypeOf<[MKFloat, MKFloatTensor]>;
+
+// Boolean Type
+// TT_Bool -> I1
+def MKBoolTensor : RankedTensorOf<[I1]>;
+def MKBoolLike : AnyTypeOf<[I1, MKBoolTensor]>;
+
+// Integer Type
+def I4 : I<4>;
+def MKInt : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">;
+def MKIntTensor : RankedTensorOf<[MKInt]>;
+def MKIntLike : AnyTypeOf<[MKInt, MKIntTensor]>;
+
+// I32 Type
+// MKI32 -> I32
+// MKI32Tensor -> I32Tensor
+def MKI32Like : AnyTypeOf<[I32, I32Tensor]>;
+
+// I64 Type
+// MKI64 -> I64
+// MKI64Tensor -> I64Tensor
+def MKI64Like : AnyTypeOf<[I64, I64Tensor]>;
+
+// Pointer Type in TableGen
+class MKPtrOf<list<Type> pointeeTypes> :
+    DialectType<Tx81Dialect,
+        And<[CPred<"isa<::mlir::triton::PointerType>($_self)">,
+             Concat<"[](::mlir::Type pointeeType) { return ",
+                    SubstLeaves<"$_self", "pointeeType", AnyTypeOf<pointeeTypes>.predicate>,
+                    "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>,
+        "ptr", "::mlir::triton::PointerType">;
+
+// Pointer Type in C++ (corresponding to `MKPtrOf`)
+def MKPtrType : MKTypeDef<"Pointer", "ptr"> {
+  let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system";
+
+  let description = [{
+    Pointer type in Triton IR type system, which could be pointing to scalars or tensors.
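+
+    For example, assuming this dialect's `tx` mnemonic prefix, `!tx.ptr<f32>`
+    would denote a pointer to a scalar and `!tx.ptr<tensor<128xf32>>` a
+    pointer to a tensor (illustrative spellings, not a normative syntax).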
+  }];
+
+  let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "Type":$pointeeType,
+      "int":$addressSpace
+    ), [{
+      return $_get(pointeeType.getContext(), pointeeType, addressSpace);
+    }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+
+  let skipDefaultBuilders = 1;
+}
+
+// Scalar Pointer Type: `ptr<>`
+def MKPtr : MKPtrOf<[AnyType]>;
+
+// Tensor of Pointer Type: `tensor<ptr<>>`
+def MKPtrTensor : RankedTensorOf<[MKPtr]>;
+
+// Tensor of Pointer Type or Pointer type: `tensor<ptr<>>` or `ptr<>`
+def MKPtrLike : AnyTypeOf<[MKPtr, MKPtrTensor]>;
+
+// Tensor Type
+def MKFpIntTensor : RankedTensorOf<[MKFloat, MKInt]>;
+def MKTensor : RankedTensorOf<[MKFloat, MKInt, MKPtr]>;
+
+// Pointer Type to Tensor Type: `ptr<tensor<>>`
+def MKTensorPtr : MKPtrOf<[MKTensor]>;
+
+// Any Type in Magic Kernel IR
+def MKType : AnyTypeOf<[MKFloatLike, MKIntLike, MKPtrLike, MKTensorPtr]>;
+
+#endif // TSINGMICRO_TX81_TYPES_TD
\ No newline at end of file
diff --git a/third_party/tsingmicro/lib/Analysis/CMakeLists.txt b/third_party/tsingmicro/lib/Analysis/CMakeLists.txt
new file mode 100644
index 000000000..643c6834f
--- /dev/null
+++ b/third_party/tsingmicro/lib/Analysis/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_triton_library(ZTCAnalysis
+  MaskAnalysis.cpp
+  OpFoldResultUtils.cpp
+  PtrAnalysis.cpp
+  UseAnalysis.cpp
+
+  DEPENDS
+  TritonAnalysis
+  TritonTableGen
+  TritonGPUAttrDefsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAnalysis
+)
diff --git a/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp
new file mode 100644
index 000000000..dc8a27c45
--- /dev/null
+++ b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp
@@ -0,0 +1,559 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Analysis/MaskAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "triton-shared/Analysis/OpFoldResultUtils.h"
+
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cassert>
+
+namespace mlir {
+
+namespace triton {
+
+LogicalResult MaskState::parse(Value operand, const Location loc,
+                               OpBuilder &builder) {
+  if (auto op = operand.getDefiningOp<arith::ConstantOp>()) {
+    return this->parseConstant(op, loc, builder);
+  } else if (isa<IntegerType>(operand.getType())) {
+    return this->parseIntScalar(operand, loc, builder);
+  } else if (auto op = operand.getDefiningOp<arith::AddIOp>()) {
+    return this->parseAdd(op, loc, builder);
+  } else if (auto op = operand.getDefiningOp<arith::AndIOp>()) {
+    return this->parseAnd(op, loc, builder);
+  } else if (auto op = operand.getDefiningOp<arith::CmpIOp>()) {
+    return this->parseCmp(op, loc, builder);
+  } else if (auto op = operand.getDefiningOp<triton::MakeRangeOp>()) {
+    return this->parseMakeRange(op, loc, builder);
+  } else if (auto op = operand.getDefiningOp<triton::BroadcastOp>()) {
+    return this->parseBroadcast(op, loc, builder);
+  } else if (auto op = operand.getDefiningOp<triton::SplatOp>()) {
+    return this->parseSplat(op, loc, builder);
+  } else if (auto op = operand.getDefiningOp<triton::ExpandDimsOp>()) {
+    return this->parseExpandDims(op, loc, builder);
+  } else if (!operand.getDefiningOp()) {
+    return this->parseLoopIterArg(operand, loc, builder);
+  } else if (auto op = operand.getDefiningOp<arith::ExtSIOp>()) {
+    return this->parseExtSI(op, loc, builder);
+  } else {
+    return failure();
+  }
+}
+
+tensor::ExtractSliceOp MaskState::getExtractSlice(Value source,
+                                                  const Location loc,
+                                                  OpBuilder &builder) const {
+  auto sourceType = cast<RankedTensorType>(source.getType());
+  SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
+
+  auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets,
+                                                         dims, strides);
+
+  return builder.create<tensor::ExtractSliceOp>(loc, dstType, source, offsets,
+                                                dims, strides);
+}
+
+memref::SubViewOp MaskState::getSubview(Value source, const Location loc,
+                                        OpBuilder &builder) const {
+  auto sourceType = cast<MemRefType>(source.getType());
+  SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
+  auto dstType =
+      memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides);
+
+  return builder.create<memref::SubViewOp>(loc, cast<MemRefType>(dstType),
+                                           source, offsets, dims, strides);
+}
+
+static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b,
+                                       ArrayRef<OpFoldResult> offsets,
+                                       ArrayRef<OpFoldResult> sizes,
+                                       ArrayRef<OpFoldResult> strides) {
+  auto srcType = cast<MemRefType>(src.getType());
+  auto dstType =
+      memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides);
+  return b.create<memref::SubViewOp>(loc, cast<MemRefType>(dstType), src,
+                                     offsets, sizes, strides);
+}
+
+// Assume block1 wraps around and the remainder is block2.
+//
+// |----------------------|
+// |          |           |
+// |  block2  |  block1   |
+// |          |           |
+// |----------------------|
+//
+// Once we copy the chunks in order, the end result is block1 followed by
+// block2.
+//
+// buffer_tmp:
+//
+// |----------------------|
+// |          |           |
+// |  block1  |  block2   |
+// |          |           |
+// |----------------------|
+//
+// Assume we have the following subview:
+//
+// +++++++++++++++++-------
+// +               +      |
+// +    subview    +      |
+// +               +      |
+// +++++++++++++++++-------
+//
+// If we simply take the subview of `buffer_tmp`, this requires an extra
+// buffer to just hold the temporary result.
+//
+// So we can subview into block1 and block2 directly. There are 2 cases:
+//   + subview only spans block1
+//   + subview spans both block1 and block2, creating sv1 and sv2 (illustrated
+//     below for case when we wrap around side-by-side)
+//
+// |----------------------------------------|
+// |                                        |
+// |  col2                    col1          |
+// |++++++--------|           |+++++++++++++++
+// | sv2 + block2 |           | block1 & sv1 +
+// |++++++--------|           |+++++++++++++++
+// |                                        |
+// |----------------------------------------|
+//
+// For simplicity, assume we only wrap around side-by-side.
+//
+// Let (row, col1) and (row, col2) be the dimensions of block1 and block2,
+// respectively.
+//
+// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be
+// the dimensions of the full subview, sv1, and sv2, respectively.
+//
+//   + colView1 = min(colFull, col1)
+//   + colView2 = colFull - colView1
+//   + rowView1 = rowView2 = row = rowFull
+std::pair<memref::SubViewOp, memref::SubViewOp>
+MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc,
+                                 OpBuilder &builder) const {
+  OpFoldResult subviewRowFull = dims[0];
+  OpFoldResult subviewColFull = dims[1];
+  OpFoldResult col1 =
+      builder.create<memref::DimOp>(loc, block1, 1).getResult();
+  OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, builder);
+  OpFoldResult subviewCol2 = subOFRs(subviewColFull, subviewCol1, loc, builder);
+
+  SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
+  auto sv1 = createSubview(block1, loc, builder, offsets,
+                           {subviewRowFull, subviewCol1}, strides);
+  auto sv2 = createSubview(block2, loc, builder, offsets,
+                           {subviewRowFull, subviewCol2}, strides);
+
+  return {sv1, sv2};
+}
+
+std::pair<memref::SubViewOp, memref::SubViewOp>
+MaskState::getStackedSubviews(Value block1, Value block2, const Location loc,
+                              OpBuilder &builder) const {
+  OpFoldResult subviewRowFull = dims[0];
+  OpFoldResult subviewColFull = dims[1];
+  OpFoldResult row1 =
+      builder.create<memref::DimOp>(loc, block1, 0).getResult();
+  OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, builder);
+  OpFoldResult subviewRow2 = subOFRs(subviewRowFull, subviewRow1, loc, builder);
+
+  SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
+  auto sv1 = createSubview(block1, loc, builder, offsets,
+                           {subviewRow1, subviewColFull}, strides);
+  auto sv2 = createSubview(block2, loc, builder, offsets,
+                           {subviewRow2, subviewColFull}, strides);
+  return {sv1, sv2};
+}
+
+LogicalResult MaskState::addStateScalar(const MaskState &state,
+                                        const OpFoldResult scalar, Location loc,
+                                        OpBuilder &builder) {
+  start = addOFRs(state.start, scalar, loc, builder);
+  end = addOFRs(state.end, scalar, loc, builder);
+  dims = state.dims;
+  return success();
+}
+
+LogicalResult MaskState::addStates(const MaskState &lhsState,
+                                   const MaskState &rhsState, Location loc,
+                                   OpBuilder &builder) {
+  if (lhsState.scalar && rhsState.scalar) {
+    InFlightDiagnostic diag =
+        emitError(loc) << "Unexpected case where both lhs and rhs are scalars";
+    return failure();
+  }
+
+  if (!lhsState.scalar && !rhsState.scalar) {
+    InFlightDiagnostic diag =
+        emitError(loc)
+        << "Unsupported scenario where neither lhs nor rhs is a scalar";
+    return failure();
+  }
+
+  if (lhsState.scalar)
+    return addStateScalar(rhsState, lhsState.scalar, loc, builder);
+  else
+    return addStateScalar(lhsState, rhsState.scalar, loc, builder);
+}
+
+LogicalResult MaskState::minStates(const MaskState &lhsState,
+                                   const MaskState &rhsState, Location loc,
+                                   OpBuilder &builder) {
+  if (lhsState.getRank() != rhsState.getRank()) {
+    InFlightDiagnostic diag =
+        emitError(loc)
+        << "Unexpected case where lhs and rhs have different ranks";
+    return failure();
+  }
+
+  for (uint32_t i = 0; i < lhsState.getRank(); i++) {
+    auto lhsDim = lhsState.dims[i];
+    auto rhsDim = rhsState.dims[i];
+    dims.push_back(minOFRs(lhsDim, rhsDim, loc, builder));
+  }
+  return success();
+}
+
+LogicalResult MaskState::parseConstant(arith::ConstantOp constOp,
+                                       const Location loc, OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  if (isa<DenseElementsAttr>(constOp.getValue())) {
+    auto attr = cast<DenseElementsAttr>(constOp.getValue());
+    auto elementType = attr.getElementType();
+    assert(attr.isSplat() && isa<IntegerType>(elementType) &&
+           "All elements must share a single integer constant value");
+    auto values = attr.getValues<IntegerAttr>();
+    auto value = values[0].getValue();
+    auto constAttr = builder.getIndexAttr(value.getSExtValue());
+    auto op = arith::ConstantOp::materialize(builder, constAttr,
+                                             builder.getIndexType(), loc);
+    this->scalar = op.getValue();
+  } else {
+    auto value = cast<IntegerAttr>(constOp.getValue()).getInt();
+    this->scalar = builder.getIndexAttr(value);
+  }
+
+  return success();
+}
+
+LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc,
+                                        OpBuilder &builder) {
+  assert(this->isEmpty());
+  auto castOp =
+      builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), scalar);
+  this->scalar = castOp.getResult();
+  return success();
+}
+
+void MaskState::dump() const {
+  llvm::dbgs() << "start: " << start << "\n";
+  llvm::dbgs() << "end: " << end << "\n";
+  llvm::dbgs() << "scalar: " << scalar << "\n";
+  llvm::dbgs() << "useUnsafeMask: " << useUnsafeMask << "\n";
+  llvm::dbgs() << "dims: ";
+  for (auto dim : dims)
+    llvm::dbgs() << "\t" << dim << "\n";
+  llvm::dbgs() << "\n";
+}
+
+LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc,
+                                  OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  MaskState lhsState;
+  if (failed(lhsState.parse(addOp.getLhs(), loc, builder)))
+    return failure();
+
+  MaskState rhsState;
+  if (failed(rhsState.parse(addOp.getRhs(), loc, builder)))
+    return failure();
+
+  return this->addStates(lhsState, rhsState, loc, builder);
+}
+
+LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc,
+                                  OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  MaskState lhsState;
+  if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) ||
+      !lhsState.isMask())
+    return failure();
+
+  MaskState rhsState;
+  if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) ||
+      !rhsState.isMask())
+    return failure();
+
+  return this->minStates(lhsState, rhsState, loc, builder);
+}
+
+LogicalResult MaskState::parseExtSI(arith::ExtSIOp op, const Location loc,
+                                    OpBuilder &builder) {
+  assert(this->isEmpty());
+  return parse(op.getIn(), loc, builder);
+}
+
+LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
+                                  OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  if (cmpOp.getPredicate() != arith::CmpIPredicate::slt &&
+      cmpOp.getPredicate() != arith::CmpIPredicate::ult &&
+      cmpOp.getPredicate() != arith::CmpIPredicate::sge) {
+    InFlightDiagnostic diag = emitError(loc) << "Unsupported cmpi";
+    return failure();
+  }
+
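+  // Parse both operands first; the dimension handling below depends on which
+  // side, if either, folds to a scalar.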
+  MaskState lhsState;
+  if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder)))
+    return failure();
+
+  MaskState rhsState;
+  if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder)))
+    return failure();
+
+  // We only support sge against 0 for lower bounds. Dims already has an
+  // implicit assumption that the lower bound is 0, so if we see this, assume
+  // the comparison evaluates to true.
+  if (cmpOp.getPredicate() == arith::CmpIPredicate::sge &&
+      !(rhsState.scalar && hasConstZero(rhsState.scalar))) {
+    InFlightDiagnostic diag = emitError(loc)
+                              << "Unsupported cmpi with rhs not equal to 0";
+    return failure();
+  }
+
+  int32_t cmpDim = lhsState.scalar && rhsState.scalar ? 0 : -1;
+  for (int32_t i = 0; i < lhsState.getRank(); i++) {
+    auto dimIntAttr = getIntAttr(lhsState.dims[i]);
+    if (!dimIntAttr || dimIntAttr.value() != 1) {
+      if (cmpDim != -1) {
+        InFlightDiagnostic diag = emitError(loc)
+                                  << "Unsupported cmpi with more than one "
+                                     "dimension with size larger than 1";
+        return failure();
+      }
+      cmpDim = i;
+    }
+  }
+  assert(cmpDim != -1 &&
+         "Unexpected case where no dimension has size larger than 1");
+
+  OpFoldResult newDim;
+  if (lhsState.scalar) {
+    assert(rhsState.scalar && "Unexpected case where rhs is not a scalar");
+    // If both lhs and rhs are scalars, we can't just derive the dimension of
+    // the mask as the minimum value: lhs/rhs could be 0 and then we don't
+    // load/store anything.
+    //
+    // Instead treat the comparison as a scalar that determines if anything
+    // should be loaded/stored by inserting a comparison + select:
+    //   dim = lhs < rhs ? lhs.dim : 0
+    newDim = compareOFRs(lhsState.scalar, rhsState.scalar, cmpOp.getPredicate(),
+                         lhsState.dims[cmpDim], builder.getIndexAttr(0),
+                         loc, builder);
+  } else if (cmpOp.getPredicate() == arith::CmpIPredicate::slt ||
+             cmpOp.getPredicate() == arith::CmpIPredicate::ult) {
+    // Important:
+    // In the case where the values we are loading are entirely masked off like
+    // the following:
+    //
+    // ---|-------|-----------|
+    //    ^       ^           ^
+    //  scalar  start        end
+    //
+    // newEnd = min(end, scalar) = scalar
+    // Now scalar < start, so simply doing dim = newEnd - start is incorrect.
+    //
+    // The correct formula is to optionally move `newDim` back to `start` using
+    // max(newEnd, start).
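+    //
+    // For example, with scalar = 2, start = 4 and end = 8 (fully masked off):
+    //   newEnd = max(min(8, 2), 4) = 4, so dim = 4 - 4 = 0 and nothing is
+    //   loaded or stored, as expected.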
+    auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder);
+    newEnd = maxOFRs(newEnd, lhsState.start, loc, builder);
+    newDim = subOFRs(newEnd, lhsState.start, loc, builder);
+  } else {
+    assert(cmpOp.getPredicate() == arith::CmpIPredicate::sge &&
+           rhsState.scalar && hasConstZero(rhsState.scalar));
+    newDim = lhsState.dims[cmpDim];
+  }
+
+  for (int32_t i = 0; i < lhsState.getRank(); i++) {
+    if (i == cmpDim)
+      this->dims.push_back(newDim);
+    else
+      this->dims.push_back(lhsState.dims[i]);
+  }
+
+  return success();
+}
+
+LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc,
+                                          OpBuilder &builder) {
+  assert(!v.getDefiningOp());
+
+  auto forOp = llvm::dyn_cast<scf::ForOp>(v.getParentRegion()->getParentOp());
+
+  if (!forOp) {
+    return failure();
+  }
+
+  // TODO: This implementation does not work with nested loops
+  if (forOp->getParentOfType<scf::ForOp>()) {
+    return failure();
+  }
+
+  auto it = llvm::find(forOp.getRegionIterArgs(), v);
+  if (it == forOp.getRegionIterArgs().end()) {
+    return failure();
+  }
+
+  auto argIndex = std::distance(forOp.getRegionIterArgs().begin(), it);
+  auto initArg = forOp.getInitArgs()[argIndex];
+  if (auto getStateOp = initArg.getDefiningOp<tts::GetStructuredStateOp>()) {
+    auto tritonValue = getStateOp->getOperand(0);
+    MaskState lhsState;
+    if (failed(lhsState.parse(tritonValue, loc, builder))) {
+      return failure();
+    }
+
+    // This is a bit of a hack!!
+    //
+    // The offsets and dimensions of a MaskState can now depend on a loop's
+    // iter-arg.
+    //
+    // Because the PtrAnalysis's pre-pass already sets up the offsets,
+    // we can create a new MaskState for each loop iteration by adding the
+    // original MaskState with the current iter-arg, which is at `argIndex +
+    // 1`.
+    //
+    // This will not work for nested loop scenarios, which would need a
+    // more robust implementation.
+    if (failed(this->addStateScalar(
+            lhsState, forOp.getRegionIterArgs()[argIndex + 1], loc, builder))) {
+      return failure();
+    }
+
+    return success();
+  }
+
+  return failure();
+}
+
+LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp,
+                                        const Location loc,
+                                        OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  auto shape = cast<ShapedType>(rangeOp.getType()).getShape();
+  auto start = rangeOp.getStart();
+  auto end = rangeOp.getEnd();
+  auto stride = (end - start + shape[0] - 1) / shape[0];
+
+  if (stride != 1) {
+    InFlightDiagnostic diag =
+        emitError(loc)
+        << "stride must be 1 for make_range whose result is used "
+           "as load or store masks";
+    return failure();
+  }
+
+  this->start = builder.getIndexAttr(start);
+  this->end = builder.getIndexAttr(end);
+  this->dims.push_back(builder.getIndexAttr(shape[0]));
+
+  return success();
+}
+
+LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp,
+                                        const Location loc,
+                                        OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  auto src = broadcastOp.getSrc();
+  auto dst = broadcastOp.getResult();
+  assert(isa<ShapedType>(src.getType()) &&
+         "input to tt.broadcast should be a tensor");
+
+  auto srcShape = cast<ShapedType>(src.getType()).getShape();
+  auto dstShape = cast<ShapedType>(dst.getType()).getShape();
+  assert(srcShape.size() == dstShape.size() &&
+         "rank of source and destination should match");
+
+  if (failed(parse(src, loc, builder)))
+    return failure();
+
+  for (size_t i = 0; i < srcShape.size(); i++) {
+    if (srcShape[i] == dstShape[i])
+      continue;
+    else if (srcShape[i] < dstShape[i])
+      this->dims[i] = builder.getIndexAttr(dstShape[i]);
+    else
+      llvm_unreachable("unexpected dimensions used in broadcast");
+  }
+
+  return success();
+}
+
+LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
+                                    OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  auto src = splatOp.getSrc();
+  auto dst = splatOp.getResult();
+  auto dstShape = cast<ShapedType>(dst.getType()).getShape();
+
+  if (!isa<IntegerType>(src.getType())) {
+    InFlightDiagnostic diag =
+        emitError(loc)
+        << "splat source must be an integer scalar for load/store masks";
+    return failure();
+  }
+
+  if (failed(this->parse(src, loc, builder)))
+    return failure();
+
+  for (auto s : dstShape)
+    this->dims.push_back(builder.getIndexAttr(s));
+
+  return success();
+}
+
+LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp,
+                                         const Location loc,
+                                         OpBuilder &builder) {
+  assert(this->isEmpty());
+
+  if (failed(this->parse(expandDimsOp.getSrc(), loc, builder)))
+    return failure();
+
+  auto dstShape =
+      cast<ShapedType>(expandDimsOp.getResult().getType()).getShape();
+  auto axis = expandDimsOp.getAxis();
+  assert(dstShape[axis] == 1 &&
+         "expect changed dimension to be 1 in expand_dims");
+  this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1));
+
+  return success();
+}
+
+} // namespace triton
+} // namespace mlir
diff --git a/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
new file mode 100644
index 000000000..62aa57ff9
--- /dev/null
+++ b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
@@ -0,0 +1,292 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Analysis/OpFoldResultUtils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
+  // Check if ofr is an Attribute
+  if (auto attr = dyn_cast<Attribute>(ofr)) {
+    // Check if it's specifically an IntegerAttr
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+      return intAttr.getInt();
+    }
+  }
+
+  return std::nullopt;
+}
+
+bool hasConstZero(const OpFoldResult ofr) {
+  auto intAttr = getIntAttr(ofr);
+  if (intAttr.has_value()) {
+    if (intAttr.value() == 0) {
+      return true;
+    }
+    return false;
+  }
+
+  auto val = dyn_cast<Value>(ofr);
+  assert(val);
+  auto constOp = val.getDefiningOp<arith::ConstantOp>();
+  if (!constOp)
+    return false;
+
+  intAttr = getIntAttr(constOp.getValue());
+  if (intAttr.has_value()) {
+    if (intAttr.value() == 0) {
+      return true;
+    }
+    return false;
+  }
+
+  return false;
+}
+
+Value ofrToIndexValue(const OpFoldResult ofr, const Location loc,
+                      OpBuilder &b) {
+  if (Value val = dyn_cast<Value>(ofr)) {
+    assert(val.getType().isIntOrIndex());
+    if (!val.getType().isIndex()) {
+      val = b.create<arith::IndexCastOp>(loc, b.getIndexType(), val);
+    }
+    return val;
+  }
+
+  auto intVal = getIntAttr(ofr);
+  if (intVal.has_value()) {
+    return b.create<arith::ConstantOp>(loc, b.getIndexAttr(intVal.value()));
+  }
+  llvm_unreachable("Unexpected OpFoldResult state");
+  return nullptr;
+}
+
+SmallVector<Value> ofrsToIndexValues(ArrayRef<OpFoldResult> ofrs,
+                                     const Location loc, OpBuilder &b) {
+  return llvm::to_vector<4>(
+      llvm::map_range(ofrs, [&](OpFoldResult ofr) -> Value {
+        return ofrToIndexValue(ofr, loc, b);
+      }));
+}
+
+OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b) {
+  auto lhsIntAttr = getIntAttr(lhs);
+  auto rhsIntAttr = getIntAttr(rhs);
+
+  // shortcut for special cases
+  if (!lhsIntAttr && rhsIntAttr && rhsIntAttr.value() == 0)
+    return lhs;
+  if (!rhsIntAttr && lhsIntAttr && lhsIntAttr.value() == 0)
+    return rhs;
+
+  // both lhs and rhs are constants, return result directly
+  if (lhsIntAttr && rhsIntAttr)
+    return b.getIndexAttr(lhsIntAttr.value() + rhsIntAttr.value());
+
+  // otherwise, need to create instructions to calculate new attribute value
+  auto lhsValue = dyn_cast<Value>(lhs);
+  if (lhsIntAttr) {
+    auto lhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
+    lhsValue = lhsOp.getResult();
+  } else {
+    assert(isa<IndexType>(lhsValue.getType()));
+  }
+
+  auto rhsValue = dyn_cast<Value>(rhs);
+  if (rhsIntAttr) {
+    auto rhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(rhsIntAttr.value()));
+    rhsValue = rhsOp.getResult();
+  } else {
+    assert(isa<IndexType>(rhsValue.getType()));
+  }
+
+  return b.create<arith::AddIOp>(loc, lhsValue, rhsValue).getResult();
+}
+
+OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b) {
+  auto lhsIntAttr = getIntAttr(lhs);
+  auto rhsIntAttr = getIntAttr(rhs);
+
+  // shortcut for special cases
+  if (!lhsIntAttr && rhsIntAttr && rhsIntAttr.value() == 0)
+    return lhs;
+
+  // both lhs and rhs are constants, return result directly
+  if (lhsIntAttr && rhsIntAttr)
+    return b.getIndexAttr(lhsIntAttr.value() - rhsIntAttr.value());
+
+  // otherwise, need to create instructions to calculate new attribute value
+  auto lhsValue = dyn_cast<Value>(lhs);
+  if (lhsIntAttr) {
+    auto lhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
+    lhsValue = lhsOp.getResult();
+  }
+
+  auto rhsValue = dyn_cast<Value>(rhs);
+  if (rhsIntAttr) {
+    auto rhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(rhsIntAttr.value()));
+    rhsValue = rhsOp.getResult();
+  }
+
+  auto subOp = b.create<arith::SubIOp>(loc, lhsValue, rhsValue);
+  return subOp.getResult();
+}
+
+OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
+                         const Location loc, OpBuilder &b) {
+  auto lhsIntAttr = getIntAttr(lhs);
+
+  auto rhsIsConst = false;
+  // if rhs is not a const, use max value since min is used to represent
+  // dynamic size or stride
+  auto rhsConstValue = std::numeric_limits<int64_t>::max();
+  auto rhsOp = rhs.getDefiningOp<arith::ConstantOp>();
+  if (rhsOp) {
+    rhsIsConst = true;
+    rhsConstValue = cast<IntegerAttr>(rhsOp.getValue()).getInt();
+  }
+
+  // shortcuts for special cases
+  if (lhsIntAttr) {
+    if (lhsIntAttr.value() == 0)
+      return lhs;
+    if (lhsIntAttr.value() == 1)
+      return rhs;
+  }
+  if (rhsIsConst) {
+    if (rhsConstValue == 0)
+      return rhsOp.getResult();
+    if (rhsConstValue == 1)
+      return lhs;
+  }
+
+  // 0. both lhs and rhs are constants
+  if (lhsIntAttr && rhsIsConst)
+    return b.getIndexAttr(lhsIntAttr.value() * rhsConstValue);
+
+  // 1. if lhs is constant but rhs is not
+  if (lhsIntAttr && !rhsIsConst) {
+    auto lhsConstOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
+    auto mulOp = b.create<arith::MulIOp>(loc, lhsConstOp.getResult(), rhs);
+    return mulOp.getResult();
+  }
+
+  // 2. if lhs is not constant
+  assert(!lhsIntAttr);
+  auto mulOp = b.create<arith::MulIOp>(loc, cast<Value>(lhs), rhs);
+  return mulOp.getResult();
+}
+
+OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b) {
+  auto lhsIntAttr = getIntAttr(lhs);
+  auto rhsIntAttr = getIntAttr(rhs);
+
+  // both lhs and rhs are constants, return result directly
+  if (lhsIntAttr && rhsIntAttr)
+    return b.getIndexAttr(std::min(lhsIntAttr.value(), rhsIntAttr.value()));
+
+  // otherwise, need to create instructions to calculate new attribute value
+  auto lhsValue = dyn_cast<Value>(lhs);
+  if (lhsIntAttr) {
+    auto lhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
+    lhsValue = lhsOp.getResult();
+  }
+
+  auto rhsValue = dyn_cast<Value>(rhs);
+  if (rhsIntAttr) {
+    auto rhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(rhsIntAttr.value()));
+    rhsValue = rhsOp.getResult();
+  }
+
+  auto minOp = b.create<arith::MinSIOp>(loc, lhsValue, rhsValue);
+  return minOp.getResult();
+}
+
+OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b) {
+  auto lhsIntAttr = getIntAttr(lhs);
+  auto rhsIntAttr = getIntAttr(rhs);
+
+  // both lhs and rhs are constants, return result directly
+  if (lhsIntAttr && rhsIntAttr)
+    return b.getIndexAttr(std::max(lhsIntAttr.value(), rhsIntAttr.value()));
+
+  // otherwise, need to create instructions to calculate new attribute value
+  auto lhsValue = dyn_cast<Value>(lhs);
+  if (lhsIntAttr) {
+    auto lhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
+    lhsValue = lhsOp.getResult();
+  }
+
+  auto rhsValue = dyn_cast<Value>(rhs);
+  if (rhsIntAttr) {
+    auto rhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(rhsIntAttr.value()));
+    rhsValue = rhsOp.getResult();
+  }
+
+  auto maxOp = b.create<arith::MaxSIOp>(loc, lhsValue, rhsValue);
+  return maxOp.getResult();
+}
+
+OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                         const arith::CmpIPredicate pred,
+                         const OpFoldResult trueOFR,
+                         const OpFoldResult falseOFR, const Location loc,
+                         OpBuilder &b) {
+  auto lhsIntAttr = getIntAttr(lhs);
+  auto rhsIntAttr = getIntAttr(rhs);
+
+  // both lhs and rhs are constants, return the result directly
+  if (lhsIntAttr && rhsIntAttr) {
+    switch (pred) {
+    case arith::CmpIPredicate::eq:
+      return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::ne:
+      return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::slt:
+    case arith::CmpIPredicate::ult:
+      return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sle:
+    case arith::CmpIPredicate::ule:
+      return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sgt:
+    case arith::CmpIPredicate::ugt:
+      return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sge:
+    case arith::CmpIPredicate::uge:
+      return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
+    default:
+      llvm_unreachable("Unsupported predicate");
+    }
+  }
+
+  auto lhsValue = ofrToIndexValue(lhs, loc, b);
+  auto rhsValue = ofrToIndexValue(rhs, loc, b);
+  auto trueValue = ofrToIndexValue(trueOFR, loc, b);
+  auto falseValue = ofrToIndexValue(falseOFR, loc, b);
+
+  auto cmpOp = b.create<arith::CmpIOp>(loc, pred, lhsValue, rhsValue);
+  auto selectOp = b.create<arith::SelectOp>(loc, cmpOp, trueValue, falseValue);
+  return selectOp.getResult();
+}
+} // namespace mlir
diff --git a/third_party/tsingmicro/lib/Analysis/PtrAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/PtrAnalysis.cpp
new file mode 100644
index 000000000..00715a9d3
--- /dev/null
+++ b/third_party/tsingmicro/lib/Analysis/PtrAnalysis.cpp
@@ -0,0 +1,1375 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Analysis/PtrAnalysis.h"
+#include "triton-shared/Analysis/OpFoldResultUtils.h"
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/Support/Debug.h"
+#include <cassert>
+
+#define DEBUG_TYPE "triton-ptr-analysis"
+
+namespace mlir {
+
+namespace triton {
+
+static void assertValidUnrealizedCast(UnrealizedConversionCastOp op) {
+  assert(op && op->hasAttr(ModuloState::WraparoundAttr) &&
+         op.getInputs().size() == 3 &&
+         op.getInputs()[0].getDefiningOp<memref::ReinterpretCastOp>() &&
+         op.getInputs()[1].getDefiningOp<memref::ReinterpretCastOp>() &&
+         op.getInputs()[2].getDefiningOp<memref::ReinterpretCastOp>());
+}
+
+MemRefType PtrState::getResultMemrefType(MLIRContext *context, int64_t offset,
+                                         ArrayRef<int64_t> resultShape,
+                                         bool useDynamicStrides) const {
+
+  SmallVector<int64_t> staticStrides;
+  if (useDynamicStrides) {
+    staticStrides.append(strides.size(), ShapedType::kDynamic);
+  } else {
+    SmallVector<Value> dynamicStrides;
+    dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  }
+
+  auto elementType = cast<BaseMemRefType>(source.getType()).getElementType();
+  auto layout =
+      StridedLayoutAttr::get(source.getContext(), offset, staticStrides);
+
+  return MemRefType::get(resultShape, elementType, layout);
+}
+
+OpFoldResult
+PtrState::accumulateTargetOffset(Location loc,
+                                 ConversionPatternRewriter &rewriter) const {
+  OpFoldResult targetOffset = rewriter.getIndexAttr(0);
+  for (auto o : offsets) {
+    targetOffset = addOFRs(targetOffset, o, loc, rewriter);
+  }
+  return targetOffset;
+}
+
+int64_t PtrState::getRank() const {
+  assert(offsets.size() == sizes.size() && offsets.size() == strides.size() &&
+         modulos.size() == offsets.size());
+  return offsets.size();
+}
+
+bool PtrState::isEmpty() const {
+  return (getRank() == 0 && !source && !scalar);
+}
+
+bool PtrState::hasModulo() const {
+  return llvm::any_of(modulos, [](auto mod) { return mod.has_value(); });
+}
+
+void PtrState::addState(const PtrState &lhsState, const PtrState &rhsState,
+                        Location
loc, ConversionPatternRewriter &rewriter) {
+  assert(isEmpty() && lhsState.getRank() == rhsState.getRank());
+
+  // at most one of lhs and rhs should have a valid source, since otherwise we
+  // would be losing information
+  assert(!(lhsState.source && rhsState.source));
+  source = lhsState.source ? lhsState.source : rhsState.source;
+
+  if (lhsState.scalar && rhsState.scalar) {
+    auto addOp =
+        rewriter.create<arith::AddIOp>(loc, lhsState.scalar, rhsState.scalar);
+    scalar = addOp.getResult();
+  } else if (lhsState.getRank() == 0) {
+    // both lhs and rhs are scalars
+    scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar;
+  }
+
+  for (uint64_t i = 0; i < lhsState.sizes.size(); i++) {
+    auto newOffset =
+        addOFRs(lhsState.offsets[i], rhsState.offsets[i], loc, rewriter);
+    offsets.push_back(newOffset);
+
+    auto newStride =
+        addOFRs(lhsState.strides[i], rhsState.strides[i], loc, rewriter);
+    strides.push_back(newStride);
+
+    sizes.push_back(lhsState.sizes[i]);
+
+    assert((!lhsState.hasModulo() || !rhsState.hasModulo()) &&
+           "AddPtr where both lhs and rhs contain modulo operators is not "
+           "supported");
+
+    modulos.push_back(lhsState.modulos[i].has_value() ? lhsState.modulos[i]
+                                                      : rhsState.modulos[i]);
+  }
+}
+
+void PtrState::mulState(const PtrState &lhsState, const PtrState &rhsState,
+                        const Location loc,
+                        ConversionPatternRewriter &rewriter) {
+  assert(isEmpty() && lhsState.getRank() == rhsState.getRank());
+
+  // neither lhs nor rhs should have a source, since multiplying a base
+  // pointer does not make sense
+  assert(!(lhsState.source && rhsState.source));
+
+  assert((lhsState.scalar || rhsState.scalar) &&
+         !(lhsState.scalar && rhsState.scalar) &&
+         "multiplication where both operands are non-scalar tensors is "
+         "currently not supported");
+
+  PtrState const *lhs = &lhsState;
+  PtrState const *rhs = &rhsState;
+
+  if (!rhs->scalar && lhs->scalar) {
+    std::swap(lhs, rhs);
+  }
+
+  for (uint64_t i = 0; i < lhs->sizes.size(); i++) {
+    OpFoldResult newOffset =
+        mulOFRValue(lhs->offsets[i], rhs->scalar, loc, rewriter);
+    OpFoldResult newStride =
+        mulOFRValue(lhs->strides[i], rhs->scalar, loc, rewriter);
+    offsets.push_back(newOffset);
+    strides.push_back(newStride);
+    sizes.push_back(lhs->sizes[i]);
+  }
+
+  assert(llvm::all_of(rhsState.modulos,
+                      [](auto rhs) { return !rhs.has_value(); }));
+
+  modulos = lhs->modulos;
+}
+
+SmallVector<memref::ReinterpretCastOp>
+PtrState::createStackedCastOps(ArrayRef<int64_t> resultShape,
+                               const Location loc,
+                               ConversionPatternRewriter &rewriter) const {
+
+  assert(resultShape.size() == 2);
+  assert(getRank() == 2);
+  assert(modulos[0].has_value() && !modulos[1].has_value());
+
+  Value targetOffset =
+      ofrToIndexValue(accumulateTargetOffset(loc, rewriter), loc, rewriter);
+
+  //////////////////////////////////////////////////////////////////////////////
+  //
+  // Handling stacked wraparound
+  //
+  // We do not support cases where the target offset has already overflown the
+  // number of rows. See side-by-side wraparound for details.
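+  //
+  // A worked example (hypothetical numbers, not taken from any particular
+  // kernel): with rows = 4, cols = 8, strideRow = 8, strideCol = 1, a tile of
+  // rowSize = 3 rows, and targetOffset = 26, we get
+  // wrappedAroundOff = 26 % 8 = 2, clampedOff = 4 * 8 + 2 = 34, and
+  // d1 = (34 - 26) / 8 = 1, so one row is read starting at offset 26 and the
+  // remaining d2 = 3 - 1 = 2 rows wrap back around to offset 2.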
+ // + ////////////////////////////////////////////////////////////////////////////// + // We're loading a tensor of dim (rowSize, colSize) + // d1 + d2 = rowSize + // d2 is the number of rows that overflow + // + // cols + // + // wrappedAroundOff + // --------------*------------*-------- + // | d2 | | | + // | |------------| | + // rows| | + // | | + // | targetOffset | + // | *------------| | + // | | | | + // | d1 | | | + // | | clampedOff | | + // --------------*--------------------- + // | overflow | + // *------------- + // nextOff + // + // wrappedAroundOff = targetOffset % cols + // clampedOff = (rows * strideRows) + wrappedAroundOff + // + // clampedOff - targetOffset + // d1 = -------------------- + // strideRows + + auto resultType = getResultMemrefType( + rewriter.getContext(), /* offset */ ShapedType::kDynamic, + /* result shape */ + SmallVector{ + ShapedType::kDynamic, // Row is dynamic, in most cases, this should be + // the same as the original row. The last chunk + // may be smaller due to wrapping around. + resultShape[1], // Col stays the same. + }, + true /*useDynamicStrides*/); + + Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); + Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); + + Value strideRow = ofrToIndexValue(strides[0], loc, rewriter); + Value strideCol = ofrToIndexValue(strides[1], loc, rewriter); + + Value modRow = rewriter.create( + loc, rewriter.getIndexType(), modulos[0]->size); + + // First chunk + Value wrappedAroundOff = + rewriter.create(loc, targetOffset, strideRow); + Value clampedOff = rewriter.create(loc, modRow, strideRow); + clampedOff = + rewriter.create(loc, clampedOff, wrappedAroundOff); + Value d1 = rewriter.create(loc, clampedOff, targetOffset); + d1 = rewriter.create(loc, d1, strideRow); + + SmallVector sizes1{d1, colSize}; + memref::ReinterpretCastOp cast1 = rewriter.create( + loc, resultType, source, targetOffset, sizes1, + ValueRange{strideRow, strideCol}); + + // Second chunk + Value d2 = rewriter.create(loc, rowSize, d1); + SmallVector sizes2{d2, colSize}; + memref::ReinterpretCastOp cast2 = rewriter.create( + loc, resultType, source, wrappedAroundOff, sizes2, + ValueRange{strideRow, strideCol}); + + return {cast1, cast2}; +} + +SmallVector +PtrState::createSideBySideCastOps(ArrayRef resultShape, + const Location loc, + ConversionPatternRewriter &rewriter) const { + + assert(resultShape.size() == 2); + assert(getRank() == 2 && !modulos[0].has_value() && modulos[1].has_value()); + + // Accumulate final offset + Value targetOffset = + ofrToIndexValue(accumulateTargetOffset(loc, rewriter), loc, rewriter); + + ////////////////////////////////////////////////////////////////////////////// + // + // Handling side-by-side wraparound + // + // Note: We do not support cases where the target has already overflown the + // number of columns! This is because in PtrAnalysis, the offset has already + // been collapsed into a single dimension, so it is ambiguous to determine + // whether the offset actually overflows or just refers to an element on the + // subsequent rows. + // + // Same limitations apply to the stacked wraparound case. 
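+  //
+  // A worked example (hypothetical numbers): with N = 8 columns, a tile of
+  // colSize = 6, and targetOffset = 14, we get x = 14 % 8 = 6,
+  // nextOffset = 6 + 6 = 12, and clampedOffset = min(12, 8) = 8, so the first
+  // chunk covers d1 = 8 - 6 = 2 columns starting at offset 14, and the second
+  // chunk covers d2 = 6 - 2 = 4 columns starting at y = 14 - 6 = 8.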
+ // + ////////////////////////////////////////////////////////////////////////////// + // + // nextOffset - targetOffset = colSize + // d1 + d2 = colSize + // N + // x clampedOffset + // --------------------------*----------------*-----* + // | | nextOffset (might + // | targetOffset | overflow) + // y *----- *----------------| + // | | | | + // M |----- -----------------| + // | d2 d1 | + // -------------------------------------------- + // + // x = targetOffset % N + // nextOffset = x + colSize + // clampedOffset = min(nextOffset, N) + // d1 = clampedOffset - x + // + ////////////////////////////////////////////////////////////////////////////// + + SmallVector casts; + + auto resultType = getResultMemrefType( + rewriter.getContext(), /* offset */ ShapedType::kDynamic, + /* result shape */ + SmallVector{ + resultShape[0], // Row stays the same + ShapedType::kDynamic // Column is dynamic, in most cases, this should + // be the same as the original column. The last + // chunk may be smaller due to wrapping around. + }, + true /*useDynamicStrides*/); + + Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); + Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); + + Value modN = rewriter.create(loc, rewriter.getIndexType(), + modulos[1]->size); + + Value x = rewriter.create(loc, targetOffset, modN); + Value y = rewriter.create(loc, targetOffset, x); + + SmallVector strideVals = ofrsToIndexValues(strides, loc, rewriter); + + // First chunk + Value nextOffset = rewriter.create(loc, x, colSize); + Value clampedOffset = rewriter.create(loc, nextOffset, modN); + Value d1 = rewriter.create(loc, clampedOffset, x); + SmallVector sizes1{rowSize, d1}; + + auto cast1 = rewriter.create( + loc, resultType, source, targetOffset, sizes1, strideVals); + + // Second chunk + Value d2 = rewriter.create(loc, colSize, d1); + SmallVector sizes2{rowSize, d2}; + + auto cast2 = rewriter.create( + loc, resultType, source, y, sizes2, strideVals); + + return {cast1, cast2}; +} + +memref::ReinterpretCastOp +PtrState::createCastOp(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const { + // Accumulate final offset + OpFoldResult targetOffset = accumulateTargetOffset(loc, rewriter); + + // Create result MemRefType + SmallVector staticOffset; + SmallVector dynamicOffset; + dispatchIndexOpFoldResult(targetOffset, dynamicOffset, staticOffset); + + auto resultType = + getResultMemrefType(rewriter.getContext(), staticOffset[0], resultShape); + + // Create reinterpret cast + return rewriter.create( + loc, resultType, source, targetOffset, sizes, strides); +} + +void PtrAnalysis::visitOperandAdd( + arith::AddIOp addOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(addOp.getLhs(), lhsState, loc, rewriter, knownPtrs); + + PtrState rhsState; + visitOperand(addOp.getRhs(), rhsState, loc, rewriter, knownPtrs); + + if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || + (rhsState.getRank() == 1 && rhsState.hasModulo())) { + assert(0 && "Current do not support this pattern: a + arange(0, K) % M"); + } + + state.addState(lhsState, rhsState, loc, rewriter); +} + +void PtrAnalysis::visitOperandMul( + arith::MulIOp mulOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(mulOp.getLhs(), lhsState, loc, rewriter, knownPtrs); + + PtrState rhsState; + visitOperand(mulOp.getRhs(), rhsState, 
loc, rewriter, knownPtrs); + + state.mulState(lhsState, rhsState, loc, rewriter); +} + +void PtrAnalysis::visitOperandRem( + arith::RemSIOp remOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + PtrState rhsState; + visitOperand(remOp.getRhs(), rhsState, loc, rewriter, knownPtrs); + assert(rhsState.scalar); + + visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs); + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. + assert(llvm::all_of(state.modulos, + [](auto modState) { return !modState.has_value(); }) && + "No support for multiple modulo within an expression"); + + if (state.getRank() == 1) { + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.modulos.back() = ModuloState{rhsState.scalar}; + } else if (state.getRank() == 2) { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. + auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.modulos[1] = ModuloState{rhsState.scalar}; + } else if (shape[1] == 1) { + state.modulos[0] = ModuloState{rhsState.scalar}; + } else { + assert(false && "Taking modulo on a 2D tensor with no singleton " + "dimension not supported"); + } + } else { + assert(false && "Unsupported modulo pattern"); + } +} + +void PtrAnalysis::visitOperandMakeRange( + triton::MakeRangeOp rangeOp, PtrState &state, Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + auto shape = cast(rangeOp.getType()).getShape(); + + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + assert(stride == 1 && + "Expect make_range op to always return tensor of stride 1"); + + state.offsets.push_back(rewriter.getIndexAttr(start)); + state.sizes.push_back(rewriter.getIndexAttr(shape[0])); + state.strides.push_back(rewriter.getIndexAttr(stride)); + state.modulos.push_back(std::nullopt); +} + +void PtrAnalysis::visitOperandExpandDims( + triton::ExpandDimsOp expandDimsOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + // `getSrc` now returns a TypedValue of RankedTensorType. We modify these + // operands in-place and turn them into memrefs in loops, so we have to bypass + // the cast by using getSrcMutable. These are temporary fix only since + // we will be moving over to StructuredPtrAnalysis soon which separate out the + // memref conversion. 
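+  // As a sketch of the bookkeeping below (hypothetical values): expanding a
+  // rank-1 state with offsets [o0], sizes [128], strides [s0] at axis 0
+  // yields offsets [0, o0], sizes [1, 128], strides [0, s0], plus a
+  // std::nullopt modulo entry for the new unit dimension.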
+ visitOperand(expandDimsOp.getSrcMutable().get(), state, loc, rewriter, + knownPtrs); + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + + // insert dimension info + state.offsets.insert(state.offsets.begin() + axis, rewriter.getIndexAttr(0)); + state.sizes.insert(state.sizes.begin() + axis, rewriter.getIndexAttr(1)); + state.strides.insert(state.strides.begin() + axis, rewriter.getIndexAttr(0)); + state.modulos.insert(state.modulos.begin() + axis, std::nullopt); +} + +void PtrAnalysis::visitOperandBroadcast( + triton::BroadcastOp broadcastOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + // `getSrc` now returns a TypedValue of RankedTensorType. We modify these + // operands in-place and turn them into memrefs in loops, so we have to bypass + // the cast by using getSrcMutable. These are temporary fix only since + // we will be moving over to StructuredPtrAnalysis soon which separate out the + // memref conversion. + auto src = broadcastOp.getSrcMutable().get(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + visitOperand(src, state, loc, rewriter, knownPtrs); + + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + state.sizes[i] = rewriter.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast"); + } +} + +void PtrAnalysis::visitOperandSplat( + triton::SplatOp splatOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + visitOperand(src, state, loc, rewriter, knownPtrs); + + if (isa(src.getType())) { + for (auto s : dstShape) { + state.offsets.push_back(rewriter.getIndexAttr(0)); + state.sizes.push_back(rewriter.getIndexAttr(s)); + state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); + } + } else { + // src is a memref that represent a scalar pointer; it should have + // one dimension of size 1. This happens inside a for loop that + // originally has an init arg that is a tensor of pointers; this arg + // would have been replaced by rewriteForOp. + auto srcType = cast(src.getType()); + assert(srcType.getRank() == 1 && state.getRank() == 1 && + "splat MemRef source should have rank 1"); + assert(srcType.getShape()[0] == 1 && + getIntAttr(state.sizes[0]).value() == 1 && + "splat MemRef source should have size 1"); + + // Stride[0] will have value of 1 set in visitOperandAddPtr. This + // value will be represented by a constOp. Clear this value. 
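+    // E.g. (hypothetical shapes): splatting a rank-1 scalar memref to a
+    // 128-element tensor turns sizes[0] into 128 while strides[0] is reset
+    // to 0 below, since every element aliases the same scalar address.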
state.strides[0] = rewriter.getIndexAttr(0);
+
+  for (auto [i, s] : llvm::enumerate(dstShape)) {
+    if (i == 0) {
+      state.sizes[i] = rewriter.getIndexAttr(s);
+      continue;
+    }
+    state.offsets.push_back(rewriter.getIndexAttr(0));
+    state.sizes.push_back(rewriter.getIndexAttr(s));
+    state.strides.push_back(rewriter.getIndexAttr(0));
+    state.modulos.push_back(std::nullopt);
+  }
+
+  // If we splat an integer value, the scalar should become the offset of the
+  // outermost dimension.
+  if (state.scalar)
+    state.offsets[0] = state.scalar;
+}
+
+void PtrAnalysis::visitOperandMakeTensorPtr(
+    triton::MakeTensorPtrOp makeTensorPtrOp, PtrState &state,
+    const Location loc, ConversionPatternRewriter &rewriter,
+    const llvm::SmallDenseMap<Value, PtrState> &knownPtrs) {
+  assert(state.isEmpty());
+  auto remappedValue = rewriter.getRemappedValue(makeTensorPtrOp);
+  if (auto castOp = remappedValue.getDefiningOp<memref::ReinterpretCastOp>()) {
+    visitOperandReintCast(castOp, state, loc, rewriter, knownPtrs);
+  } else {
+    llvm_unreachable("Expect value to be mapped to a memref.reinterpret_cast");
+  }
+}
+
+void PtrAnalysis::visitOperandAddptr(
+    triton::AddPtrOp addptrOp, PtrState &state, const Location loc,
+    ConversionPatternRewriter &rewriter,
+    const llvm::SmallDenseMap<Value, PtrState> &knownPtrs) {
+  assert(state.isEmpty());
+
+  PtrState ptrState;
+  visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), rewriter,
+               knownPtrs);
+
+  PtrState offsetState;
+  visitOperand(addptrOp.getOffset(), offsetState, addptrOp.getLoc(), rewriter,
+               knownPtrs);
+
+  assert(ptrState.source && "ptr field should provide source / base pointer");
+
+  // Handle the special case when we are in a for loop: ptr is originally a
+  // scalar pointer but has been replaced with a memref. In this case, ptrState
+  // will have rank 1 and offsetState will have rank 0.
+  // TODO:
+  // Passing a block argument pointer directly into a for loop is not
+  // supported.
+  if (ptrState.getRank() == 1 && offsetState.getRank() == 0) {
+    offsetState.sizes.push_back(rewriter.getIndexAttr(1));
+    offsetState.offsets.push_back(offsetState.scalar);
+    offsetState.strides.push_back(rewriter.getIndexAttr(0));
+    offsetState.modulos.push_back(std::nullopt);
+  }
+
+  assert(ptrState.getRank() == offsetState.getRank() &&
+         "ptr and offset field should have the same rank");
+
+  state.addState(ptrState, offsetState, addptrOp.getLoc(), rewriter);
+}
+
+void PtrAnalysis::visitOperandReintCast(
+    memref::ReinterpretCastOp reintCastOp, PtrState &state, const Location loc,
+    ConversionPatternRewriter &rewriter,
+    const llvm::SmallDenseMap<Value, PtrState> &knownPtrs) {
+  assert(state.isEmpty());
+
+  state.offsets = reintCastOp.getMixedOffsets();
+  state.sizes = reintCastOp.getMixedSizes();
+  state.strides = reintCastOp.getMixedStrides();
+  state.source = reintCastOp.getSource();
+  state.modulos.append(state.sizes.size(), std::nullopt);
+
+  // getMixedOffsets produces staticOffsets (which is the result of collapsing
+  // multiple dimensions). Populate the rest of the dimensions with zeroes.
+  assert(state.offsets.size() == 1);
+  for (size_t i = 1; i < state.sizes.size(); i++) {
+    state.offsets.push_back(rewriter.getIndexAttr(0));
+  }
+
+  // Regular Triton programs cannot express patterns of size 1 and non-zero
+  // stride; we only set it that way to make memrefs work. Set the stride back
+  // to zero if this scenario is detected.
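+  // E.g. (hypothetical): a reinterpret_cast with sizes [1, 128] and strides
+  // [128, 1] is normalized to strides [0, 1]; the unit dimension's stride is
+  // never observable from a Triton program.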
+ for (size_t i = 0; i < state.strides.size(); i++) { + auto strideIntAttr = getIntAttr(state.strides[i]); + auto sizeIntAttr = getIntAttr(state.sizes[i]); + + assert(sizeIntAttr); + if (sizeIntAttr.value() == 1 && strideIntAttr) { + state.strides[i] = rewriter.getIndexAttr(0); + } + } +} + +void PtrAnalysis::visitOperand( + Value operand, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + + if (knownPtrs.find(operand) != knownPtrs.end()) { + state = knownPtrs.lookup(operand); + return; + } + + if (isa(operand.getType())) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), operand); + state.scalar = castOp.getResult(); + return; + } + + if (isa(operand.getType())) { + auto remappedPtr = rewriter.getRemappedValue(operand); + assert(remappedPtr); + + // A scalar pointer can either be produced by AddPtrOp or a block + // argument + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + visitOperandAddptr(cast(op), state, loc, rewriter, + knownPtrs); + } else if (auto makeTensorOp = dyn_cast(op)) { + visitOperandMakeTensorPtr(makeTensorOp, state, loc, rewriter, + knownPtrs); + } else { + llvm_unreachable("Unexpected operand defining operation"); + } + } else { + state.source = remappedPtr; + } + return; + } + + if (auto op = operand.getDefiningOp()) { + visitOperandAdd(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandMul(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandMakeRange(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandBroadcast(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandSplat(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandExpandDims(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandAddptr(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandConstSplat(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandRem(op, state, loc, rewriter, knownPtrs); + } else { + operand.dump(); + llvm_unreachable("encountered addptr operand produced by an " + "unsupported operation"); + } +} + +void PtrAnalysis::visitOperandConstSplat( + arith::ConstantOp op, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + // this condition is to handle cases where tt.broadcast and tt.splat are + // folded + auto attr = cast(op.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType)); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = rewriter.getIndexAttr(value.getSExtValue()); + auto constOp = arith::ConstantOp::materialize(rewriter, constAttr, + rewriter.getIndexType(), loc); + + state.scalar = constOp; + + auto resultType = cast(op.getResult().getType()); + for (size_t i = 0; i < resultType.getShape().size(); i++) { + if (i == 0) { + state.offsets.push_back(constOp.getResult()); + } else { + state.offsets.push_back(rewriter.getIndexAttr(0)); + } + + state.sizes.push_back(rewriter.getIndexAttr(resultType.getShape()[i])); + state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); + } 
+} + +void PtrAnalysis::rewriteAddptrOp( + triton::AddPtrOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs) { + // any inserted instruction should be before this addptr + auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + PtrState state; + visitOperandAddptr(op, state, op.getLoc(), rewriter, knownPtrs); + + // If the result is a scalar pointer, visitOperandAddptr will not populate + // sizes, strides, and offsets. We need to do it here. + if (state.sizes.size() == 0) { + state.sizes.push_back(rewriter.getIndexAttr(1)); + state.strides.push_back(rewriter.getIndexAttr(0)); + state.offsets.push_back(state.scalar); + state.modulos.push_back(std::nullopt); + } + + SmallVector scalarShape(1, 1); + ArrayRef resultShape; + if (auto shapedType = dyn_cast(op.getResult().getType())) { + resultShape = shapedType.getShape(); + } else { + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(state.getRank() == 1); + } + + knownPtrs[op.getResult()] = state; + + // If there are dimensions with size 1 and stride 0, replace 0 stride with the + // product of sizes of all lower dimensions. This avoids creating memref with + // zero stride. Note that we store the unmodified state into knownPtrs, since + // any following pointer arithmetic operations should use the original 0 + // stride. + auto accum_size = 1; + for (int i = state.sizes.size() - 1; i >= 0; i--) { + auto strideIntAttr = getIntAttr(state.strides[i]); + auto sizeIntAttr = getIntAttr(state.sizes[i]); + + assert(sizeIntAttr); + if (sizeIntAttr.value() == 1 && strideIntAttr && strideIntAttr.value() == 0) + state.strides[i] = rewriter.getIndexAttr(accum_size); + + accum_size *= sizeIntAttr.value(); + } + + Value src; + + if (llvm::any_of(state.modulos, [](auto mod) { return mod.has_value(); })) { + assert(state.modulos.size() == 2); + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); + + SmallVector casts; + StringRef type; + + if (!state.modulos[0].has_value() && state.modulos[1].has_value()) { + casts = state.createSideBySideCastOps(resultShape, op.getLoc(), rewriter); + type = ModuloState::WraparoundSideBySide; + } else if (state.modulos[0].has_value() && !state.modulos[1].has_value()) { + casts = state.createStackedCastOps(resultShape, op.getLoc(), rewriter); + type = ModuloState::WraparoundStacked; + } else { + assert(false && "not supported"); + } + + auto resultType = state.getResultMemrefType( + rewriter.getContext(), ShapedType::kDynamic, resultShape); + + UnrealizedConversionCastOp combinedCast = + rewriter.create( + op.getLoc(), resultType, + ValueRange{casts[0].getResult(), casts[1].getResult(), + op.getResult()}); + + combinedCast->setAttr(ModuloState::WraparoundAttr, + rewriter.getStringAttr(type)); + + src = combinedCast.getResult(0); + + LLVM_DEBUG({ + llvm::dbgs() << "combine cast for split pointers:\n"; + combinedCast.getOperation()->print( + llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + } else { + memref::ReinterpretCastOp castOp = + state.createCastOp(resultShape, op.getLoc(), rewriter); + + src = castOp.getResult(); + + LLVM_DEBUG({ + llvm::dbgs() << "cast MemRefType:\n"; + castOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + } + + state.source = src; + rewriter.replaceOp(op, src); + rewriter.restoreInsertionPoint(origIp); +} + +void PtrAnalysis::rewriteAdvanceOp( + 
triton::AdvanceOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs) { + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + auto loc = op.getLoc(); + + PtrState ptrState; + visitOperand(op.getOperand(0), ptrState, loc, rewriter, knownPtrs); + + auto incrementOffsets = op.getOffsets(); + + SmallVector newOffsets; + for (auto [increment, offset, stride] : + llvm::zip(incrementOffsets, ptrState.offsets, ptrState.strides)) { + Value offsetValue; + if (auto offsetIntAttr = getIntAttr(offset)) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + offsetValue = constOp.getResult(); + } else { + offsetValue = cast(offset); + } + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), increment); + auto mulOp = rewriter.create(loc, castOp.getResult(), + cast(stride)); + auto addOp = + rewriter.create(loc, mulOp.getResult(), offsetValue); + newOffsets.push_back(addOp.getResult()); + } + + ptrState.offsets.clear(); + + for (auto offset : newOffsets) { + ptrState.offsets.push_back(offset); + } + + SmallVector scalarShape(1, 1); + ArrayRef resultShape; + auto pointerType = cast(op.getResult().getType()); + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + } else { + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(ptrState.getRank() == 1); + } + + auto newOp = ptrState.createCastOp(resultShape, loc, rewriter); + + rewriter.replaceOp(op, newOp.getResult()); + + knownPtrs[newOp.getResult()] = ptrState; +} + +void PtrAnalysis::rewriteYieldOp( + scf::YieldOp op, ConversionPatternRewriter &rewriter, + const IndexMapSet &levelToBlockArgIndex, const int level, + const llvm::SmallDenseMap &knownPtrs) { + // any inserted instruction should be before this yield + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + + auto adaptor = scf::YieldOp::Adaptor(op); + + SmallVector initArgState; + SmallVector operands(adaptor.getOperands()); + // Track the second chunks of modulo pointers so that we can append them to + // the yield results + SmallVector moduloSecondChunks; + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below gathers + // PtrState for those values. + for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { + if (auto mappedV = rewriter.getRemappedValue(v)) { + // If this value is a tensor of pointers produced by AddPtrOp, + // we should have already converted to a ReinterpretCastOp without + // layout information for the normal cases, or to an + // UnrealizedConversionCastOp for the split pointer case. + if (v.getDefiningOp() || + v.getDefiningOp() || + v.getDefiningOp()) { + if (auto castOp = mappedV.getDefiningOp()) { + assertValidUnrealizedCast(castOp); + auto castInputs = castOp.getInputs(); + v = castOp.getResult(0); + operands[i] = castInputs[0]; + moduloSecondChunks.push_back(castInputs[1]); + } else if (auto castOp = + mappedV.getDefiningOp()) { + v = castOp; + } else { + llvm_unreachable("mapped value defined by an unexpected op"); + } + } else { + // If this value is not a tensor of pointers, we will use the + // mapped value, and rely on the conversion will happen later + // automatically when we legalize loop body. 
+ + // TODO: + // The scenario where a value is a tensor of pointers but not + // produced by AddPtrOp is not supported + if (isa(mappedV.getType()) && + isa( + dyn_cast(mappedV.getType()).getElementType())) + llvm_unreachable("unsupported scenario where a value is a tensor of " + "pointers but not produced by AddPtrOp"); + v = mappedV; + } + } + + if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) + continue; + auto thisSet = levelToBlockArgIndex.find(level)->second; + if (thisSet.find(i) == thisSet.end()) + continue; + + auto reintCastOp = v.getDefiningOp(); + auto unrealizedCastOp = v.getDefiningOp(); + + assert( + reintCastOp || + (unrealizedCastOp && + unrealizedCastOp->hasAttr(ModuloState::WraparoundAttr)) || + (isa(v.getType()) && + isa(dyn_cast(v.getType()).getElementType()))); + + PtrState state; + if (reintCastOp) { + visitOperandReintCast(reintCastOp, state, op.getLoc(), rewriter, + knownPtrs); + } else if (unrealizedCastOp) { + assertValidUnrealizedCast(unrealizedCastOp); + visitOperandUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, + knownPtrs); + } else { + visitOperand(v, state, op.getLoc(), rewriter, knownPtrs); + } + initArgState.push_back(state); + } + + // For each of the PtrState recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. + for (auto state : initArgState) { + for (auto s : state.offsets) { + // offsets can be IntAttr zeroes, since reinterpret_cast collapses + // them for the input memref, and the for loop may not update + // offsets other than offsets[0]. Create constants Values for those + // zeroes. + if (auto sIntAttr = getIntAttr(s)) { + assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + operands.push_back(constOp.getResult()); + } else { + operands.push_back(cast(s)); + } + } + + for (auto s : state.strides) { + assert(!getIntAttr(s) && "PtrState strides for yield within for " + "loop not expected to be " + "attribute."); + operands.push_back(cast(s)); + } + } + + for (auto chunk : moduloSecondChunks) { + operands.push_back(chunk); + } + + // Yield is a terminator op that must be at the end of the function + rewriter.setInsertionPointAfter(op); + auto newOp = rewriter.replaceOpWithNewOp(op, operands); + assert(op->getNumResults() == 0); + + LLVM_DEBUG({ + llvm::dbgs() << "new yield:"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +// From an unrealized_conversion_cast which takes in two reinterpret_casts +// representing two chunks, we need to get back the full pointer state. We +// cannot rebuild the original state from the two reinterpret_casts similarly to +// the normal case. To solve this, we attach the original addptr as the third +// operand to the unrealized_cast so that we can manually rebuild the state. 
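+// For illustration, a sketch of the expected pattern (types and attribute
+// spelling abbreviated, not verbatim IR):
+//
+//   %full = builtin.unrealized_conversion_cast %chunk0, %chunk1, %addptr
+//           {wraparound_type = "side_by_side"} : (...) -> memref<?x?xf32, ...>
+//
+// where %chunk0 and %chunk1 are the two memref.reinterpret_cast chunks and
+// %addptr is the result of the original tt.addptr whose state we rebuild.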
+void PtrAnalysis::visitOperandUnrealizedCast( + UnrealizedConversionCastOp op, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assertValidUnrealizedCast(op); + + auto origPtr = op.getInputs()[2]; + if (knownPtrs.contains(origPtr)) { + state = knownPtrs.at(origPtr); + } else { + visitOperandAddptr(origPtr.getDefiningOp(), state, loc, + rewriter, knownPtrs); + } +} + +struct ModuloChunkInitArg { + Value reinterpretCast = nullptr; + // where in the init args is the first chunk placed + size_t initArgIndex = -1; +}; + +void PtrAnalysis::rewriteForOp( + scf::ForOp op, ConversionPatternRewriter &rewriter, + IndexMapSet &levelToBlockArgIndex, const int level, + llvm::SmallDenseMap &knownPtrs) { + SmallVector newInitArgs; + + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; + + // If we have a load op that uses a modulo pointer, we need to insert both of + // the memref chunks to the init args. We reuse the sizes from the original + // memrefs. This data structure keeps track of where these additional init + // args should be inserted. + // + // As an example, if we have a 2D memrefs being split, we first put the first + // chunk in the order as it appears. Then, once all of the original init args + // are processed, we insert their offsets and strides, and finally the second + // chunk. + SmallVector, PtrState>, + 6> + moduloStates; + + // Amongst the init args, track the indices that map to the first chunk of a + // modulo pair. This is used to distinguish between the normal + // reinterpret_casts whose return types need to be rewritten to match what the + // for loop is yielding. + DenseSet moduloInitArgIndices; + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto mappedV = rewriter.getRemappedValue(arg); + memref::ReinterpretCastOp reintCastOp; + UnrealizedConversionCastOp unrealizedCastOp; + + // If this init arg is supposed to be remapped, use the remapped + // value instead. In addition, if this init arg is a memref created + // by a reinterpret_cast or a tensor of index, there is a chance that + // it will be used in addptr. Create PtrState for each such init arg. + if (mappedV) { + // TODO: + // Passing a block argument pointer directly into a for loop not + // supported. 
+ assert(!(dyn_cast(mappedV) && + isa(mappedV.getType())) && + "cannot take pointer block argument as init arg for for loop"); + if (auto op = mappedV.getDefiningOp()) { + reintCastOp = op; + newInitArgs.push_back(mappedV); + } else if (auto op = + mappedV.getDefiningOp()) { + assertValidUnrealizedCast(op); + unrealizedCastOp = op; + auto inputs = unrealizedCastOp.getInputs(); + + SmallVector initArgData{ + ModuloChunkInitArg{inputs[0], i}, + ModuloChunkInitArg{inputs[1]}, + }; + + moduloInitArgIndices.insert(i); + moduloStates.push_back( + std::make_tuple(unrealizedCastOp, initArgData, PtrState{})); + + newInitArgs.push_back(inputs[0]); + } else { + newInitArgs.push_back(mappedV); + } + + } else { + newInitArgs.push_back(arg); + } + + auto indexTensor = + isa(arg.getType()) && + isa(dyn_cast(arg.getType()).getElementType()); + + if (!unrealizedCastOp && !reintCastOp && !indexTensor) + continue; + + PtrState state; + if (reintCastOp) { + visitOperandReintCast(reintCastOp, state, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } else if (unrealizedCastOp) { + visitOperandUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + std::get<2>(moduloStates.back()) = state; + } else { + visitOperand(arg, state, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } + + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + } + + // Set insertion point to be before the for loop for new variables passed + // into the new loop. + auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + // For each of the PtrState recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto [i, state] : initArgIndexState) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list. + for (auto [j, s] : llvm::enumerate(state.offsets)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.offsets[j] = constOp.getResult(); + } else { + newInitArgs.push_back(cast(s)); + } + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.strides[j] = constOp.getResult(); + } else { + newInitArgs.push_back(cast(s)); + } + } + + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the state we record is the init + // arg, but want to to use newly created block arg. These block args + // are not created yet. We will translate this mapping later. + knownPtrsTmp.push_back(std::make_pair(i, state)); + levelToBlockArgIndex[level].insert(i); + + // If the original init arg is a memref produced by reinterpret_cast, + // create a new memref using new strides and offsets created above. + // This produces a canonicalized memref, which will match what the + // for loop generates if it modifies the memref. 
E.g., original + // reinterpret_cast can produce a memref with const stride: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + + // s0 + d1 + // * s1)>> + // The new reinterpret_cast will always have dynamic stride and + // offset: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + // + s0 + d1 * s2)>> + // + // For init args that are the first chunk of a modulo pair, there is + // no need for the type to be rewritten because the strides and + // offsets are already dynamic. + if (!moduloInitArgIndices.contains(i) && + newInitArgs[i].getDefiningOp()) { + SmallVector resultShape; + for (auto s : state.sizes) { + auto sIntAttr = getIntAttr(s); + assert(sIntAttr && "expected constant size"); + resultShape.push_back(sIntAttr.value()); + } + auto castOp = state.createCastOp(resultShape, op.getLoc(), rewriter); + + LLVM_DEBUG({ + llvm::dbgs() << "new reinterpret_cast with dynamic sizes " + "and offsets:"; + castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + newInitArgs[i] = castOp.getResult(); + } + } + + // Pass in the second chunk of each modulo pair + for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { + chunkData[1].initArgIndex = newInitArgs.size(); + newInitArgs.push_back(chunkData[1].reinterpretCast); + } + + rewriter.restoreInsertionPoint(origIp); + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = rewriter.create( + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + IRMapping mapping; + mapping.map(op.getInductionVar(), iv); + mapping.map(op.getInitArgs(), newInitArgs); + mapping.map(op.getRegionIterArgs(), args); + + for (auto &bodyOp : op.getRegion().getOps()) { + b.clone(bodyOp, mapping); + } + + // Load op is lowered independent of the pointer, if we have a split + // pointer due to modulo, we need to "logically combine" these two + // memrefs into a single one using unrealized_cast_op. This way, when + // lowering the load, the pattern can detect if additional copies are + // inserted. When we are in a loop, it is more complicated because we + // have to insert a new unrealized_cast_op that combines the two memrefs + // in the init arg list. In addition, because init args hold no offset + // and size information, we have to manually insert two additional + // reinterpret_cast ops as input to this unrealized_cast_op so that the + // load have enough information to generate the corresponding copy. + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(b.getBlock()); + + Value zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { + SmallVector newReinterpretCasts; + for (auto &chunk : chunkData) { + newReinterpretCasts.push_back(args[chunk.initArgIndex]); + } + + auto combinedCast = b.create( + loc, unrealizedCastOp.getResult(0).getType(), newReinterpretCasts, + unrealizedCastOp->getAttrs()); + + args[chunkData[0].initArgIndex].replaceUsesWithIf( + combinedCast.getResult(0), [](OpOperand &operand) { + assert(!isa(operand.getOwner()) && + "Storing to split pointers not supported"); + return isa(operand.getOwner()); + }); + } + }); + + // Convert the book-keeping data structure to use the correct key and value. 
+ // Key is converted from init arg index to newly created block arg, and + // Value's PtrState fields are converted from init arg to newly created block + // arg + int cnt = op.getRegionIterArgs().size(); + for (auto [i, state] : knownPtrsTmp) { + for (auto it = state.offsets.begin(); it != state.offsets.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = state.strides.begin(); it != state.strides.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + auto key = newOp.getRegionIterArgs()[i]; + knownPtrs.insert(std::make_pair(key, state)); + } + assert(static_cast(cnt + moduloStates.size()) == + newOp.getRegionIterArgs().size() && + "expect to remap all new block args"); + + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange( + newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + rewriter.replaceOp(op, resultsToReplaceWith); + + // Update the loop body. Manually invoke the rewrite logic on addptr and yield + // in the loop body, so we can take advantage of the states we built up + for (auto &bodyOp : newOp.getRegion().getOps()) { + if (auto addptrOp = dyn_cast(bodyOp)) { + rewriteAddptrOp(addptrOp, rewriter, knownPtrs); + } else if (auto advanceOp = dyn_cast(bodyOp)) { + rewriteAdvanceOp(advanceOp, rewriter, knownPtrs); + } else if (auto forOp = dyn_cast(bodyOp)) { + // TODO: + // Nested for loops are not supported at the moment + assert(0 && "nested loops currently not supported"); + // rewriteForOp(forOp, rewriter, levelToBlockArgIndex, level+1, + // knownPtrs); levelToBlockArgIndex.erase(level+1); + } + } + + if (op.getNumRegionIterArgs()) { + auto yieldOp = cast(newOp.getBody()->getTerminator()); + rewriteYieldOp(yieldOp, rewriter, levelToBlockArgIndex, level, knownPtrs); + } + + LLVM_DEBUG({ + llvm::dbgs() << "new for\n"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +Value PtrAnalysis::getScalarMemRef(Value ptr, Value memRef, const Location loc, + ConversionPatternRewriter &rewriter) { + assert(cast(ptr.getType()) && "expected scalar pointer"); + + // If the pointer is generated by tt.addptr, we will have already inserted an + // ReinterpretCastOp to cast its type from tt.ptr to unranked memref. Return + // the result. + if (ptr.getDefiningOp()) { + if (auto castOp = memRef.getDefiningOp()) { + return castOp.getResult(); + } else { + llvm_unreachable("pointer value is defined by an unexpected op"); + } + } + + assert(isa(ptr) && + "pointer is neither produced by addptr nor a block argument"); + PtrState state; + state.source = memRef; + state.offsets.push_back(rewriter.getIndexAttr(0)); + state.sizes.push_back(rewriter.getIndexAttr(1)); + state.strides.push_back(rewriter.getIndexAttr(1)); + state.modulos.push_back(std::nullopt); + auto castOp = state.createCastOp(SmallVector(1, 1), loc, rewriter); + return castOp.getResult(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/tsingmicro/lib/Analysis/UseAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/UseAnalysis.cpp new file mode 100644 index 000000000..62e450808 --- /dev/null +++ b/third_party/tsingmicro/lib/Analysis/UseAnalysis.cpp @@ -0,0 +1,220 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Analysis/UseAnalysis.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace triton; +using namespace dataflow; + +#define DEBUG_TYPE "triton-use-analysis" + +//===----------------------------------------------------------------------===// +// Use Analysis +// Note that logic below should evolve with triton-to-affine pass +//===----------------------------------------------------------------------===// +LogicalResult +triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) { + // If an op only produces pointer, all its operands are used as meta data. + // This accounts for scenarios such as addptr in a loop whose result is + // yielded. In this case, if the loop returns data tensors, addptr will be + // marked correctly as meta use. + if (op->getResults().size() == 1) { + auto resultType = dyn_cast(op->getResult(0).getType()); + if (resultType && isa(resultType.getElementType())) { + for (auto opnd : operands) + propagateUse(opnd, UseType::MetaUse); + } + } + + TypeSwitch(op) + .Case([&](auto load) { + propagateUse(operands[0], UseType::MetaUse); + auto mask = load.getMask(); + auto other = load.getOther(); + if (mask) { + assert(mask != other && "mask and other cannot be the same"); + propagateUse(operands[1], UseType::MetaUse); + } + if (other) { + // TODO: + // More complicated patterns that generate other is unsupported. + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto store) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = store.getValue(); + auto mask = store.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto dot) { + propagateResults(operands[0], results); + propagateResults(operands[1], results); + + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) + splat = opc.template getDefiningOp(); + + if (opc && splat && splat.getSrc().getDefiningOp()) + propagateUse(operands[2], UseType::MetaUse); + else + propagateUse(operands[2], UseType::DataUse); + }) + .Default([&](Operation *op) { + // this condition account for tt.addptr + for (auto operand : operands) { + propagateResults(operand, results); + } + }); + return success(); +} + +LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) { + MLIRContext *context = funcOp.getContext(); + SymbolTableCollection symbolTable; + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(symbolTable); + if (failed(solver.initializeAndRun(funcOp))) + return failure(); + + // Walk the func op, convert tags on operands to tags on operations + funcOp.walk([&](Operation *op) { + UseType useType = UseType::Undefined; + for (auto result : op->getResults()) { + auto use = solver.lookupState(result); + assert(use && "Lattice value not found"); + auto thisUseType = use->type; + if (thisUseType == UseType::Undefined) + continue; + if (useType == UseType::Undefined) + useType = thisUseType; + if (thisUseType == UseType::MixUse || thisUseType != useType) { + useType = UseType::MixUse; + break; + } + } + + if (useType == UseType::Undefined) { + LLVM_DEBUG({ 
op->setAttr("Undefined", UnitAttr::get(context)); }); + return; + } else if (useType == UseType::MetaUse) { + assert(op->getNumResults() == 1 && + "Ops used for meta computation are expected to have one result"); + // Only set the tag if the operation uses tensors + if (isa(op->getResult(0).getType())) { + // Setting tag for erasing op later + op->setAttr("MetaUse", UnitAttr::get(context)); + } + return; + } else if (useType == UseType::DataUse) { + LLVM_DEBUG({ op->setAttr("DataUse", UnitAttr::get(context)); }); + return; + } + + assert(useType == UseType::MixUse); + + // If the operation only produces scalars, no need to clone it + bool shapedResult = true; + for (auto result : op->getResults()) + shapedResult &= isa(result.getType()); + if (!shapedResult) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + // Value has MixUse. However, the operation may or may not have direct + // MetaUse. E.g., it may only have MixUse, or only have MixUse and + // DataUse. + // - If the operation has direct MetaUse, clone it, tag the clone as + // MetaUse only and point meta users to use the clone. + // - If not, do nothing; this operation will still be materlized. + llvm::SetVector metaUsers; + for (auto result : op->getResults()) { + for (auto user : result.getUsers()) { + TypeSwitch(user) + .Case([&](auto load) { + auto ptr = load.getPtr(); + auto mask = load.getMask(); + auto other = load.getOther(); + if (result == ptr || result == mask || result == other) + metaUsers.insert(user); + }) + .Case([&](auto store) { + auto ptr = store.getPtr(); + auto mask = store.getMask(); + if (result == ptr || result == mask) + metaUsers.insert(user); + }) + .Case([&](auto dot) { + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) + splat = opc.template getDefiningOp(); + + if (opc && splat && + splat.getSrc().getDefiningOp()) + metaUsers.insert(user); + }) + .Default([&](Operation *op) { + // if all output of user are used as meta data, user is a meta + // user. 
This condition account for addptr, or an addi whose + // output only feeds into addptr + bool allMeta = true; + for (auto res : op->getResults()) { + auto resUse = solver.lookupState(res); + if (resUse->type != UseType::MetaUse) { + allMeta = false; + break; + } + } + if (allMeta) + metaUsers.insert(user); + }); + } + } + + // If the operation doesn't have direct meta users, no need to clone it + if (metaUsers.empty()) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + // Clone the operation; switch all meta users to use the clone + OpBuilder builder(op); + auto clone = builder.clone(*op); + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + + // Setting tag for erasing op later + clone->setAttr("MetaUse", UnitAttr::get(context)); + + for (auto [res_i, result] : llvm::enumerate(op->getResults())) + for (auto user : metaUsers) + for (auto &operand : user->getOpOperands()) + if (operand.get() == result) + operand.set(clone->getResult(res_i)); + }); + + return success(); +} diff --git a/third_party/tsingmicro/lib/AnalysisStructured/CMakeLists.txt b/third_party/tsingmicro/lib/AnalysisStructured/CMakeLists.txt new file mode 100644 index 000000000..0683754ca --- /dev/null +++ b/third_party/tsingmicro/lib/AnalysisStructured/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(ZTCAnalysisStructured + PtrAnalysis.cpp + + DEPENDS + TritonAnalysis + TritonTableGen + TritonStructuredTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonStructuredIR + MLIRAnalysis +) diff --git a/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp new file mode 100644 index 000000000..d813ffdb2 --- /dev/null +++ b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp @@ -0,0 +1,1395 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/AnalysisStructured/PtrAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "triton-shared/Analysis/MaskAnalysis.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" + +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "triton-ptr-analysis" + +namespace mlir { + +// Extract a scalar value from v. +// If v is a scalar, return that directly. Otherwise, parse through operations +// (currently only support splat, sitofp, and truncf) that produce it to +// extract the underlying scalar value. We then reconstruct the chain of +// operations that can produce this constant with the original type. 
If no +// scalar value can be extracted, a nullptr is returned. +static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + builder, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +namespace tts { + +int32_t PtrState::getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size() && + shape.size() == offsets.size()); + return offsets.size(); +} + +bool PtrState::isEmpty() const { + return (getRank() == 0 && !source && !scalar); +} + +bool PtrState::hasModulo() const { + for (int32_t i = 0; i < getRank(); i++) { + if (dimHasModulo(i)) { + return true; + } + } + return false; +} + +bool PtrState::dimHasModulo(uint32_t dim) const { + assert( + !isBlockPtr() && + "Analysis should not check modulo if PtrState describes block pointer"); + + assert(dim < getRank()); + + auto intAttr = getIntAttr(shape[dim]); + if (!intAttr.has_value()) { + return true; + } + + return intAttr.value() != 0; +} + +bool PtrState::isBlockPtr() const { return !order.empty(); } + +LogicalResult PtrState::addState(const PtrState &lhsState, + const PtrState &rhsState, Operation *op, + OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + auto loc = op->getLoc(); + + if (lhsState.source && rhsState.source) { + op->emitRemark( + "PtrAnalysis: do not support adding two pointer states that both " + "have base pointers"); + return failure(); + } + + source = lhsState.source ? lhsState.source : rhsState.source; + + if (lhsState.scalar && rhsState.scalar) { + auto addOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = addOp.getResult(); + } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars + scalar = lhsState.scalar ? 
lhsState.scalar : rhsState.scalar; + } + + for (uint64_t i = 0; i < lhsState.getRank(); i++) { + auto newOffset = + addOFRs(lhsState.offsets[i], rhsState.offsets[i], loc, builder); + offsets.push_back(newOffset); + + auto newStride = + addOFRs(lhsState.strides[i], rhsState.strides[i], loc, builder); + strides.push_back(newStride); + + sizes.push_back(lhsState.sizes[i]); + } + + // AddPtr where both lhs and rhs containing modulo operators not supported + if (lhsState.hasModulo() && rhsState.hasModulo()) { + op->emitRemark("PtrAnalysis: do not support adding two pointer states " + "that both have modulo"); + return failure(); + } + + if (lhsState.hasModulo() || rhsState.hasModulo()) { + // visitOperandSplat and visitOperandExpandDims should enforce below + assert(lhsState.getRank() <= 2); + } + + // dealing with modulo: + // - If lhs has no modulo, skip + // - If rhs has zero offset on dim i, we can just use lhs's modulo + // - If i == 0 and rhs is the result of a splat, we will allow the add. This + // is because the user may be trying to express adding a constant offset to + // increment dim1, but pointer analysis cannot differentiate dim1 vs dim0 in + // this case. + // - Else, the analysis fails + + // An example for the 3rd condition above can look like: + // %0 = tt.splat %scalar + // %1 = tt.splat %ptr + // %2 = tt.arange + // %3 = arith.remsi %2, %size + // %4 = tt.addptr %1, %3 + // %5 = tt.addptr %4, %0 + // %5 may also occur in a loop to increment %4 every iteration. + + // Note that this is not bullet-proof. E.g., broken IR can actually increment + // dim0 while dim0 already has modulo, since Triton offsets are element-wise + // and not in unit of lower dimensions. However, this is highly unlikely but + // the analysis will provide wrong result. Hence we provide a warning in this + // case. + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; + + if (rhs->hasModulo()) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->getRank(); i++) { + if (!lhs->dimHasModulo(i)) { + shape.push_back(lhs->shape[i]); + } else if (hasConstZero(rhs->offsets[i])) { + shape.push_back(lhs->shape[i]); + } else if (i == 0 && lhs->getRank() == 2 && rhs->scalar) { + shape.push_back(lhs->shape[1]); + shape.push_back(lhs->shape[0]); + op->emitWarning( + "PtrAnalysis: allowing adding pointer state with modulo in dim 0 to " + "another pointer state with offset in dim 0.\nPlease verify the " + "operand that contains a scalar is meant to increment pointers in " + "dim1. 
If that is not the case it WILL LEAD TO WRONG COMPILATION " + "RESULTS.\n\nTo avoid this warning, use expand_dims (instead of " + "splat) to explicitly specify which dimension contains the scalar."); + break; + } else { + op->emitRemark( + "PtrAnalysis: do not support adding to operand with modulo"); + return failure(); + } + } + + return success(); +} + +void PtrState::dump() const { + llvm::dbgs() << "PtrState: "; + if (source) { + llvm::dbgs() << "source: " << source << "\n"; + } + if (scalar) { + llvm::dbgs() << "scalar: " << scalar << "\n"; + } + + llvm::dbgs() << "offsets: "; + llvm::interleave(offsets, llvm::dbgs(), "\n"); + llvm::dbgs() << "\nstrides: "; + llvm::interleave(strides, llvm::dbgs(), "\n"); + llvm::dbgs() << "\nsizes: "; + llvm::interleave(sizes, llvm::dbgs(), "\n"); + llvm::dbgs() << "\nshape: "; + llvm::interleave(shape, llvm::dbgs(), "\n"); + llvm::dbgs() << "\norder: "; + llvm::interleave(order, llvm::dbgs(), "\n"); + llvm::dbgs() << "\n"; +} + +LogicalResult PtrState::mulState(const PtrState &lhsState, + const PtrState &rhsState, Operation *op, + OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + auto loc = op->getLoc(); + + // neither lhs nor rhs should have source, since multiplying base pointer + // does not make sense + if (lhsState.source && rhsState.source) { + op->emitRemark("PtrAnalysis: do not support multiplying base pointers"); + return failure(); + } + + // currently do not support both tensors are effectively non-scalar + if (!lhsState.scalar && !rhsState.scalar) { + op->emitRemark( + "PtrAnalysis: only support multiplying pointer states when one of " + "them represent a scalar"); + return failure(); + } + + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; + + if (!rhs->scalar && lhs->scalar) { + std::swap(lhs, rhs); + } + + if (lhsState.scalar && rhsState.scalar) { + scalar = builder.create( + loc, lhsState.scalar, rhsState.scalar); + } + + for (uint64_t i = 0; i < lhs->sizes.size(); i++) { + OpFoldResult newOffset = + mulOFRValue(lhs->offsets[i], rhs->scalar, loc, builder); + OpFoldResult newStride = + mulOFRValue(lhs->strides[i], rhs->scalar, loc, builder); + OpFoldResult newShape = + mulOFRValue(lhs->shape[i], rhs->scalar, loc, builder); + offsets.push_back(newOffset); + strides.push_back(newStride); + shape.push_back(newShape); + sizes.push_back(lhs->sizes[i]); + } + + if (rhs->hasModulo()) { + op->emitRemark( + "PtrAnalysis: do not support multiplying pointer states that has " + "modulos"); + return failure(); + } + + return success(); +} + +tts::MakeTensorPtrOp PtrState::createTTSMakeTensorPtrOp(OpBuilder &builder, + Location loc) { + SmallVector staticSizes; + for (size_t i = 0; i < getRank(); i++) { + auto s = getIntAttr(sizes[i]); + assert(s.has_value()); + staticSizes.push_back(s.value()); + } + + auto op = builder.create( + loc, source, staticSizes, strides, offsets, shape, order); + LLVM_DEBUG({ + llvm::dbgs() << "creating tts::make_tensor_ptr:\n"; + op->dump(); + }); + + return op; +} + +LogicalResult PtrAnalysis::visitOperandAdd(arith::AddIOp addOp, PtrState &state, + const Location loc, + OpBuilder &builder) { + PtrState lhsState; + if (visitOperand(addOp.getLhs(), lhsState, loc, builder).failed()) { + return failure(); + } + + PtrState rhsState; + if (visitOperand(addOp.getRhs(), rhsState, loc, builder).failed()) + return failure(); + + // Checking for higher dimension is done in addState below + if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || + (rhsState.getRank() == 
1 && rhsState.hasModulo())) { + addOp->emitRemark( + "PtrAnalysis: do not support this pattern: a + arange(0, K) % M"); + return failure(); + } + + return state.addState(lhsState, rhsState, addOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandMul(arith::MulIOp mulOp, PtrState &state, + const Location loc, + OpBuilder &builder) { + PtrState lhsState; + if (visitOperand(mulOp.getLhs(), lhsState, loc, builder).failed()) { + return failure(); + } + + PtrState rhsState; + if (visitOperand(mulOp.getRhs(), rhsState, loc, builder).failed()) { + return failure(); + } + + return state.mulState(lhsState, rhsState, mulOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + PtrState rhsState; + if (visitOperand(remOp.getRhs(), rhsState, loc, builder).failed()) { + return failure(); + } + + if (!rhsState.scalar) { + remOp->emitRemark("PtrAnalysis: only support cases when rhs of remainder " + "contains scalar"); + return failure(); + } + + if (visitOperand(remOp.getLhs(), state, loc, builder).failed()) { + return failure(); + } + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. + if (state.hasModulo()) { + remOp->emitRemark( + "PtrAnalysis: do not support multiple modulo within an expression"); + return failure(); + } + + if (state.getRank() == 1) { + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.shape.back() = rhsState.scalar; + } else if (state.getRank() == 2) { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. 
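The dimension choice described above reduces to a small decision table. A minimal standalone sketch of that rule (a hypothetical helper, not part of the patch, using plain integers in place of the MLIR types):

#include <cstdint>
#include <vector>

// Returns the dimension on which the divisor of a remainder op should be
// recorded, or -1 if the pattern is unsupported. Rank 1 applies the modulo
// to the only dimension; rank 2 requires a singleton dimension and applies
// the modulo to the other one.
static int moduloDim(const std::vector<int64_t> &resultShape) {
  if (resultShape.size() == 1)
    return 0;
  if (resultShape.size() == 2) {
    if (resultShape[0] == 1)
      return 1;
    if (resultShape[1] == 1)
      return 0;
  }
  return -1; // e.g. a full 2-D tensor with no singleton dimension
}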
+ auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.shape[1] = rhsState.scalar; + } else if (shape[1] == 1) { + state.shape[0] = rhsState.scalar; + } else { + remOp->emitRemark( + "PtrAnalysis: taking modulo on a 2D tensor with no singleton " + "dimension not supported"); + return failure(); + } + } else { + remOp->emitRemark("PtrAnalysis: unsupported modulo pattern"); + return failure(); + } + return success(); +} + +LogicalResult PtrAnalysis::visitOperandExtSI(arith::ExtSIOp extOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + return visitOperand(extOp.getIn(), state, loc, builder); +} + +LogicalResult PtrAnalysis::visitOperandMakeRange(triton::MakeRangeOp rangeOp, + PtrState &state, Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto shape = cast(rangeOp.getType()).getShape(); + + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + assert(stride == 1 && + "Expect make_range op to always return tensor of stride 1"); + + state.offsets.push_back(builder.getIndexAttr(start)); + state.sizes.push_back(builder.getIndexAttr(shape[0])); + state.strides.push_back(builder.getIndexAttr(stride)); + state.shape.push_back(builder.getIndexAttr(0)); + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + if (visitOperand(expandDimsOp.getSrc(), state, loc, builder).failed()) { + return failure(); + } + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + + // insert dimension info + state.offsets.insert(state.offsets.begin() + axis, builder.getIndexAttr(0)); + state.sizes.insert(state.sizes.begin() + axis, builder.getIndexAttr(1)); + state.strides.insert(state.strides.begin() + axis, builder.getIndexAttr(0)); + state.shape.insert(state.shape.begin() + axis, builder.getIndexAttr(0)); + + if (state.hasModulo() && state.getRank() > 2) { + expandDimsOp->emitRemark( + "PtrAnalysis: unsupported scenario where expand_dims result " + "has modulo and rank > 2"); + return failure(); + } + + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandBroadcast(triton::BroadcastOp broadcastOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + + if (!isa(src.getType())) { + broadcastOp->emitRemark("PtrAnalysis: Unsupported broadcast source type"); + return failure(); + } + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (visitOperand(src, state, loc, builder).failed()) { + return failure(); + } + + for (size_t i = 0; i < dstShape.size(); i++) { + if (srcShape[i] == dstShape[i]) { + continue; + } else if (srcShape[i] < dstShape[i]) { + state.sizes[i] = builder.getIndexAttr(dstShape[i]); + } else { + llvm_unreachable("unexpected dimensions used in broadcast"); + } + } + return success(); +} + +LogicalResult PtrAnalysis::visitOperandSplat(triton::SplatOp splatOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto src = 
splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + if (visitOperand(src, state, loc, builder).failed()) { + return failure(); + } + + if (isa(src.getType())) { + for (auto s : dstShape) { + state.offsets.push_back(builder.getIndexAttr(0)); + state.sizes.push_back(builder.getIndexAttr(s)); + state.strides.push_back(builder.getIndexAttr(0)); + state.shape.push_back(builder.getIndexAttr(0)); + } + } else { + splatOp->emitRemark("PtrAnalysis: unsupported splat pattern"); + return failure(); + } + + // If we splat a integer value, scalar should become the offset of the outer + // most dimension + if (state.scalar) + state.offsets[0] = state.scalar; + + if (state.hasModulo() && state.getRank() > 2) { + splatOp->emitRemark("PtrAnalysis: unsupported scenario where splat result " + "has modulo and rank > 2"); + return failure(); + } + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandAddptr(triton::AddPtrOp addptrOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + PtrState ptrState; + if (visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), builder) + .failed()) { + // assert(0); + return failure(); + } + + PtrState offsetState; + if (visitOperand(addptrOp.getOffset(), offsetState, addptrOp.getLoc(), + builder) + .failed()) { + return failure(); + } + + assert(ptrState.source && "ptr field should provide source / base pointer"); + + assert(ptrState.getRank() == offsetState.getRank() && + "ptr and offset field should have the same rank"); + + return state.addState(ptrState, offsetState, addptrOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandConstSplat(arith::ConstantOp op, + PtrState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + // this condition is to handle cases where tt.broadcast and tt.splat are + // folded + auto attr = cast(op.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType)); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = builder.getIndexAttr(value.getSExtValue()); + auto constOp = arith::ConstantOp::materialize(builder, constAttr, + builder.getIndexType(), loc); + + state.scalar = constOp; + + auto resultType = cast(op.getResult().getType()); + for (size_t i = 0; i < resultType.getShape().size(); i++) { + if (i == 0) { + state.offsets.push_back(constOp.getResult()); + } else { + state.offsets.push_back(builder.getIndexAttr(0)); + } + + state.sizes.push_back(builder.getIndexAttr(resultType.getShape()[i])); + state.strides.push_back(builder.getIndexAttr(0)); + state.shape.push_back(builder.getIndexAttr(0)); + } + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, + PtrState &state, + const Location loc, + OpBuilder &builder) { + + assert(state.isEmpty()); + state.source = makeTPtrOp.getBase(); + state.offsets = makeTPtrOp.getMixedOffsets(); + state.sizes = makeTPtrOp.getMixedSizes(); + state.strides = makeTPtrOp.getMixedStrides(); + state.shape = makeTPtrOp.getMixedShape(); + state.order = SmallVector(makeTPtrOp.getOrder()); + + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, + PtrState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + state.source = makeTPtrOp.getBase(); + + if (makeTPtrOp.getOrder().empty()) { + makeTPtrOp->emitRemark( + "PtrAnalysis: expect tt.make_tensor_ptr 
to have order field set"); + return failure(); + } + + auto resType = cast(makeTPtrOp.getResult().getType()); + auto pointeeType = cast(resType.getPointeeType()); + auto shape = pointeeType.getShape(); + + for (int64_t i = 0; i < pointeeType.getRank(); i++) { + state.sizes.push_back(builder.getIndexAttr(shape[i])); + + auto strideCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); + state.strides.push_back(strideCst.getResult()); + + auto offsetCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); + + auto scaledOffset = builder.create( + loc, offsetCst.getResult(), strideCst.getResult()); + state.offsets.push_back(scaledOffset.getResult()); + + auto shapeCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getShape()[i]); + state.shape.push_back(shapeCst.getResult()); + } + state.order = SmallVector(makeTPtrOp.getOrder()); + assert(state.isBlockPtr() && + "tt.make_tensor_ptr pointer state should describe a block pointer"); + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandForOp(scf::ForOp forOp, Value operand, + PtrState &state, + const Location loc, + OpBuilder &builder) { + + auto it = llvm::find(forOp->getResults(), operand); + auto index = std::distance(forOp->getResults().begin(), it); + + auto newState = getLoopResultPtrState(forOp, index); + if (failed(newState)) { + forOp.emitError( + "Rewrite for-op failed. Could not find PtrState returned by " + "the loop."); + return failure(); + } + + state = newState.value(); + return success(); +} + +LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, + const Location loc, + OpBuilder &builder) { + + if (knownPtrs.find(operand) != knownPtrs.end()) { + state = knownPtrs.lookup(operand); + return success(); + } + + if (isa(operand.getType())) { + OpBuilder::InsertionGuard guard(builder); + if (!isa(operand) && operand.getDefiningOp()) { + builder.setInsertionPointAfter(operand.getDefiningOp()); + } + auto castOp = builder.create( + loc, builder.getIndexType(), operand); + state.scalar = castOp.getResult(); + return success(); + } else if (isa(operand.getType())) { + state.scalar = operand; + return success(); + } + + if (isa(operand.getType())) { + // A scalar pointer can either be produced by AddPtrOp or a block + // argument + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + return visitOperandAddptr(cast(op), state, loc, + builder); + } else if (auto makeTensorOp = dyn_cast(op)) { + llvm_unreachable("Unexpected operand defining operation tts.make_tptr"); + } else { + llvm_unreachable("Unexpected operand defining operation"); + } + } else { + state.source = operand; + return success(); + } + } + + if (auto op = operand.getDefiningOp()) { + return visitOperandAdd(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandMul(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandMakeRange(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandBroadcast(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandSplat(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandExpandDims(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandAddptr(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandConstSplat(op, state, loc, 
builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandRem(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandExtSI(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandForOp(op, operand, state, loc, builder); + } else if (!operand.getDefiningOp()) { + if (!knownPtrs.contains(operand)) { + return failure(); + } + + // This operand must be an iter-arg of an inner-loop in a multiple-level + // nested loop, which means its PtrState must have already been populated + // during rewriteForOp of the parent loop. + state = knownPtrs[operand]; + return success(); + } else { + llvm::dbgs() << "PtrAnalysis: encountered addptr operand produced by an " + "unsupported operation\n"; + operand.dump(); + return failure(); + } +} + +LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) { + OpBuilder builder(op); + + PtrState state; + if (visitOperandAddptr(op, state, op.getLoc(), builder).failed()) { + return failure(); + } + + knownPtrs[op.getResult()] = state; + + if (isa(op.getPtr().getType())) { + auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(op.getResult(), maketptrOp.getResult()); + } else { + // record the ptr as we have visited and built up the state for this scalar + // pointer, which may be used by rewriteForOp later. + ptrMap.map(op.getResult(), op.getResult()); + } + return success(); +} + +LogicalResult PtrAnalysis::rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op) { + OpBuilder builder(op); + + PtrState state; + if (visitOperandMakeTensorPtr(op, state, op.getLoc(), builder).failed()) { + return failure(); + } + + auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); + knownPtrs[op.getResult()] = state; + ptrMap.map(op.getResult(), maketptrOp.getResult()); + return success(); +} + +LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) { + OpBuilder builder(op); + auto loc = op.getLoc(); + + PtrState state; + if (visitOperand(op->getOperand(0), state, loc, builder).failed()) { + op->emitRemark("PtrAnalysis: Failed to analyze ptr of tt.advance"); + return failure(); + } + assert(state.isBlockPtr() && + "tt.advance pointer state should describe a block pointer"); + + auto incrementOffsets = op.getOffsets(); + + SmallVector newOffsets; + for (auto [increment, offset, stride] : + llvm::zip(incrementOffsets, state.offsets, state.strides)) { + Value offsetValue; + if (auto offsetIntAttr = getIntAttr(offset)) { + auto constOp = builder.create( + loc, builder.getIndexAttr(offsetIntAttr.value())); + offsetValue = constOp.getResult(); + } else { + offsetValue = cast(offset); + } + auto castOp = builder.create( + loc, builder.getIndexType(), increment); + auto mulOp = builder.create(loc, castOp.getResult(), + cast(stride)); + auto addOp = + builder.create(loc, mulOp.getResult(), offsetValue); + newOffsets.push_back(addOp.getResult()); + } + + state.offsets = SmallVector(newOffsets); + + auto newOp = state.createTTSMakeTensorPtrOp(builder, loc); + knownPtrs[op.getResult()] = state; + ptrMap.map(op.getResult(), newOp.getResult()); + return success(); +} + +static bool isPointerType(Type t) { + if (auto tensor = llvm::dyn_cast(t)) { + return isa(tensor.getElementType()); + } + return isa(t); +} + +FailureOr PtrAnalysis::getLoopInitArgPtrState(scf::ForOp forOp, + size_t index) { + auto ptr = forOp.getInitArgs()[index]; + + // If the pointer into the scf.for was defined by tts.get_structured_state, + // we can get the 
pointer state from the original pointer (the op's input): + // + // %ptr, %offset_1, %offset_2,..., %stride_1, %stride_2,... = + // tts.get_structured_state %original + // scf.for ... (%ptr) {...} + if (auto getStateOp = ptr.getDefiningOp()) { + auto originalPtr = getStateOp->getOperand(0); + if (knownPtrs.count(originalPtr)) { + return knownPtrs[originalPtr]; + } + } + + // For nested loops scenarios, a pointer in init-args can be returned from + // another loop of the same level: + // e.g.: + // clang-format off + // %22:2 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %11, %arg6 = %15) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + // %23 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %arg5) -> (tensor<2x2x!tt.ptr>) : i32 { + // %26 = tt.addptr %arg8, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + // scf.yield %26 : tensor<2x2x!tt.ptr> + // } + // %24:2 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %23, %arg9 = %arg6) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + // %26 = tt.load %arg8 : tensor<2x2x!tt.ptr> + // %27 = tt.addptr %arg8, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + // ... + // } + // ... + // } + // clang-format on + // Notice %arg8 = %23 comes from the return value of the first loop. + if (auto forOp = ptr.getDefiningOp()) { + return getLoopResultPtrState(forOp, index); + } + + // If the pointer isn't defined by tts.get_structured_state nor another loop, + // it means the current pointer is an iterarg of the outer loop. + // In such cases, the outer loops would have already set up the PtrState for + // us already. + // + // scf.for iterargs(%ptr = %init_arg) { + // scf.for iterargs(%ptr1 = %ptr) { <--- we're dealing with `%ptr1` here. + // ... + // } + // } + if (knownPtrs.count(ptr)) { + assert(!ptr.getDefiningOp() && "Expect the ptr to be an iterarg"); + return knownPtrs[ptr]; + } + + return failure(); +} + +PtrState PtrAnalysis::reconcileLoopPtrState( + scf::ForOp forOp, size_t iterArgIndex, const PtrState &state, + llvm::function_ref getReplacementVal) { + PtrState newState = state; + int cnt = iterArgIndex + 1; + if (newState.getRank() == 0) { + assert(newState.scalar); + // for scalar pointers, the scalar contains the offset and is the only + // relevant newState that could be updated by the loop. 
+ newState.scalar = getReplacementVal(forOp, cnt); + } else { + for (auto &offset : newState.offsets) { + offset = getReplacementVal(forOp, cnt++); + } + + for (auto &stride : newState.strides) { + stride = getReplacementVal(forOp, cnt++); + } + } + + return newState; +} + +FailureOr PtrAnalysis::getLoopIterArgPtrState(scf::ForOp forOp, + size_t index) { + auto state = getLoopInitArgPtrState(forOp, index); + if (failed(state)) { + return failure(); + } + + return reconcileLoopPtrState( + forOp, index, state.value(), + [](scf::ForOp op, size_t index) { return op.getRegionIterArg(index); }); +} + +FailureOr PtrAnalysis::getLoopResultPtrState(scf::ForOp forOp, + size_t index) { + auto state = getLoopInitArgPtrState(forOp, index); + if (failed(state)) { + return failure(); + } + + return reconcileLoopPtrState( + forOp, index, state.value(), + [](scf::ForOp op, size_t index) { return op->getResult(index); }); +} + +LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { + for (auto [i, arg] : llvm::enumerate(op.getRegionIterArgs())) { + if (!maybeStructuredArgs.contains(arg)) { + continue; + } + + auto state = getLoopIterArgPtrState(op, i); + if (failed(state)) { + // Because the maybeStructuredArgs may contain values that are not + // considered structured by PtrAnalysis, failing to retrieve the PtrState + // should not fail the rewrite process. + // We emit an error for diagnostics and debugging purposes. + op->emitWarning( + "Rewrite for-op failed. Could not find PtrState for iter-arg index " + + std::to_string(i)); + continue; + } + + // Save the current init arg's PtrState + knownPtrs[arg] = state.value(); + + // For tensors of pointers, create a tts.make_tptr at the beginning of the + // loop body that correspond to this region iter arg. In case it is used + // by tt.load/tt.store in the loop body before pointer updates, this will + // make sure rewriteLoadOp/rewriteStoreOp can use the analysis result. + // E.g., given the following input (%tensor_of_ptr is a block arg): + // scf.for (%tensor_of_ptr) { + // %data = tt.load %tensor_of_ptr + // // more operations to update %tensor_of_ptr + // } + // We may produce the following output: + // scf.for (%base_ptr, %stride, %offset) { + // %tensor_of_ptr = tts.make_tptr(%base_ptr, %stride, %offset) + // %data = tts.load %tensor_of_ptr + // // more operations to update %offset + // } + // If %tensor_of_ptr is not used (i.e., %tensor_of_ptr is updated before + // used in the original IR), it will simply be removed by + // canonicalization. + + // For scalar pointers, there is no need to create a tts.addptr at the + // beginning of the loop body. We don't lower tt.load and tt.store on + // scalars in this pass; pointer arithmetics can also just use the + // original pointer. + // Note that there can be tensor of indices in iter-arg, so we only create + // the make_tensor_ptr op when the arg is of pointer type. 
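The iter-arg unpacking in reconcileLoopPtrState above relies on a fixed layout of the expanded loop arguments: the structured value itself, then its offsets, then its strides. A sketch of that layout with hypothetical helper names (the scalar case uses a single slot in place of the offset/stride runs):

#include <cstddef>

// Index layout for an iter-arg pack produced by tts.get_structured_state:
// the structured value sits at `base`, followed by `rank` offsets and then
// `rank` strides.
struct StatePackIndices {
  size_t value;       // the pointer / tensor-of-indices itself
  size_t firstOffset; // base + 1
  size_t firstStride; // base + 1 + rank
};

static StatePackIndices packIndices(size_t base, size_t rank) {
  return {base, base + 1, base + 1 + rank};
}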
+ if (isPointerType(arg.getType())) { + if (state->getRank() != 0) { + OpBuilder builder(op.getRegion()); + auto maketptrOp = state->createTTSMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(arg, maketptrOp.getResult()); + } + } + } + + // Recursively rewrite the inner ops + if (rewriteOp(op).failed()) { + op->emitRemark( + "PtrAnalysis: update loop body failed when rewriting for op"); + return failure(); + } + + return success(); +} + +LogicalResult +PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { + auto tritonValue = op->getOperand(0); + + // If this triton value isn't known, it means PtrAnalysis has failed to + // analyze this pointer. In such cases, simply remap all uses of the + // structured value back to its original triton value. + if (!knownPtrs.contains(tritonValue)) { + op.emitRemark( + "Rewrite GetStructuredStateOp failed. Could not find PtrState."); + op.getResult(0).replaceAllUsesWith(tritonValue); + return failure(); + } + + tts::PtrState state = knownPtrs[tritonValue]; + Value remappedValue = + ptrMap.contains(tritonValue) ? ptrMap.lookup(tritonValue) : tritonValue; + + SmallVector replacements{remappedValue}; + OpBuilder builder(op); + + if (state.getRank() == 0) { + // For scalar pointers, the scalar contains the offset and is the only + // relevant state that could be updated by the loop. + if (state.scalar) { + replacements.push_back(state.scalar); + } else { + // This operand is a pointer directly from the kernel arguments. + // Use offset 0. + assert(!tritonValue.getDefiningOp()); + replacements.push_back(builder.create( + op.getLoc(), builder.getIndexAttr(0))); + } + } else { + for (auto [j, s] : llvm::enumerate(state.offsets)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = builder.create( + op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + replacements.push_back(constOp.getResult()); + } else { + replacements.push_back(cast(s)); + } + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = builder.create( + op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + replacements.push_back(constOp.getResult()); + } else { + replacements.push_back(cast(s)); + } + } + } + + op->replaceAllUsesWith(replacements); + op->erase(); + return success(); +} + +LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op, + bool useUnsafeMask) { + auto ptr = ptrMap.lookupOrNull(op.getPtr()); + auto mask = op.getMask(); + auto other = op.getOther(); + auto loc = op.getLoc(); + + if (!ptr) { + op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so " + "loadOp cannot be rewritten"); + return failure(); + } + + auto ptrType = dyn_cast(ptr.getType()); + if (ptrType && !isa(ptrType.getPointeeType())) { + op->emitRemark("PtrAnalysis: scalar loadOp will not be rewritten"); + return failure(); + } + + ArrayRef dims; + mlir::triton::MaskState mstate(useUnsafeMask); + Value scalarOther; + + OpBuilder builder(op); + // Analyze the mask operand to determine at runtime the size of the data we + // are moving. 
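For the canonical `offset + arange < N` comparisons that MaskAnalysis accepts, the net effect of the resulting dims is a per-dimension clamp of the static tile size against the runtime bound. A scalar model of that computation (a hypothetical helper; assumes one upper-bound comparison per dimension and non-negative extents):

#include <algorithm>
#include <cstdint>
#include <vector>

// Clamp each tile dimension so the load/store only touches the valid region.
static std::vector<int64_t> clampDims(const std::vector<int64_t> &tileSizes,
                                      const std::vector<int64_t> &upperBounds,
                                      const std::vector<int64_t> &offsets) {
  std::vector<int64_t> dims(tileSizes.size());
  for (size_t i = 0; i < tileSizes.size(); ++i)
    dims[i] = std::min(tileSizes[i],
                       std::max<int64_t>(0, upperBounds[i] - offsets[i]));
  return dims;
}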
+ if (mask) { + if (mstate.parse(mask, loc, builder).failed()) { + op->emitRemark("MaskAnalysis failed"); + return failure(); + } + dims = mstate.dims; + } + + if (other) { + assert(mask && "other value used while no masks are specified"); + + scalarOther = getScalarValue(other, loc, builder); + if (!scalarOther) { + op->emitRemark("other value used in masked load produced by " + "unsupported instruction"); + return failure(); + } + } + + auto loadOp = builder.create(loc, ptr, dims, scalarOther); + + LLVM_DEBUG({ + llvm::dbgs() << "creating tts::load:\n"; + loadOp->dump(); + }); + + op.replaceAllUsesWith(loadOp.getResult()); + op->erase(); + return success(); +} + +// Structured values from the TritonStructuredDialect have offsets and strides +// that might change in each loop iteration and hence will appear in an scf.for +// iter-args like so: +// +// %structured, %offsets, %strides = tts.get_structured_state +// scf.for (%arg0 = %structured, %arg1 = %offsets, %arg2 = %strides) { +// %a = %arg0 + 1 +// %b = %b + 2 +// scf.for (%arg1 = %b) { +// ... +// } +// } +// +// In `rewriteForOp`, we have to recognize such structured values in order to +// rewrite their PtrState accordingly. Previously, only values of Pointer-like +// type (e.g.: tensor> or tt.ptr>), so detecting these values +// is as easy as checking the type. +// +// Now, tensor of indices could also appear in a loop's iter-arg. To reliably +// detect all such cases, we perform a BFS-like traversal of the IR where the +// sources are the results of `tts.get_structured_state`. All values that +// originate from the results of `tts.get_structured_state` are consider +// "maybeStructured". If a loop's iter-arg is considered "maybeStructured", we +// must set up their PtrState during `rewriteForOp`. +void PtrAnalysis::initializeMaybeStructuredArgs(Operation *op) { + std::queue q; + DenseSet visited; + + op->walk([&q, &visited](tts::GetStructuredStateOp getStateOp) { + Value value = getStateOp->getResult(0); + visited.insert(value); + q.push(value); + }); + + while (!q.empty()) { + auto v = q.front(); + q.pop(); + for (auto user : v.getUsers()) { + // scf.for is a special case. We have 2 set of values to consider: + // - iter-args + // - loop results + // for every init arg that originates from a `tts.get_structured_state` + // op, its corresponding iter-arg and loop result will also be considered + // "maybeStructured". 
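The traversal described above is a plain forward flood over the def-use graph, with scf.for edges fanning out to both the region iter-arg and the tied loop result. A generic sketch of its shape (not the MLIR implementation; `ValueT` stands for any hashable handle):

#include <queue>
#include <unordered_set>
#include <vector>

// Flood forward from the seed values, visiting each reachable value once.
template <typename ValueT, typename SuccFn>
static std::unordered_set<ValueT> floodForward(std::vector<ValueT> seeds,
                                               SuccFn successors) {
  std::unordered_set<ValueT> visited(seeds.begin(), seeds.end());
  std::queue<ValueT> q;
  for (auto &s : seeds)
    q.push(s);
  while (!q.empty()) {
    ValueT v = q.front();
    q.pop();
    for (ValueT n : successors(v)) // users, plus iter-arg/result pairs
      if (visited.insert(n).second)
        q.push(n);
  }
  return visited;
}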
+ if (auto forOp = dyn_cast(user)) { + auto it = llvm::find(forOp.getInitArgs(), v); + + if (it == forOp.getInitArgs().end()) { + continue; + } + + auto argIndex = std::distance(forOp.getInitArgs().begin(), it); + auto iterArg = forOp.getRegionIterArg(argIndex); + auto tiedLoopRes = forOp.getTiedLoopResult(iterArg); + + SmallVector neighbors{iterArg, tiedLoopRes}; + for (auto neighbor : neighbors) { + maybeStructuredArgs.insert(neighbor); + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + q.push(neighbor); + } + } + + } else { + for (auto res : user->getResults()) { + if (res.getType() != v.getType()) { + continue; + } + maybeStructuredArgs.insert(res); + if (!visited.contains(res)) { + visited.insert(res); + q.push(res); + } + } + } + } + } +} + +LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op, + bool useUnsafeMask) { + auto ptr = ptrMap.lookupOrNull(op.getPtr()); + auto val = op.getValue(); + auto mask = op.getMask(); + auto loc = op.getLoc(); + + if (!ptr) { + op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so " + "storeOp cannot be rewritten"); + return failure(); + } + + auto ptrType = dyn_cast(ptr.getType()); + if (ptrType && !isa(ptrType.getPointeeType())) { + op->emitRemark("PtrAnalysis: scalar storeOp will not be rewritten"); + return failure(); + } + + ArrayRef dims; + mlir::triton::MaskState mstate(useUnsafeMask); + + OpBuilder builder(op); + + // Analyze the mask operand to determine at runtime the size of the data + // are moving. + if (mask) { + if (mstate.parse(mask, loc, builder).failed()) { + op->emitRemark("MaskAnalysis failed"); + return failure(); + } + dims = mstate.dims; + } + + auto storeOp = builder.create(loc, ptr, val, dims); + + LLVM_DEBUG({ + llvm::dbgs() << "creating tts::store:\n"; + storeOp->dump(); + }); + + op->erase(); + return success(); +} + +LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) { + LLVM_DEBUG({ + llvm::dbgs() << "rewriting rootOp\n"; + rootOp->dump(); + }); + + rootOp->walk([&](Operation *op) { + if (op == rootOp) { + return WalkResult::advance(); + } + return TypeSwitch(op) + .Case([&](auto addptr) { + if (rewriteAddptrOp(addptr).failed()) { + addptr->emitRemark("PtrAnalysis: Failed to rewrite AddPtrOp"); + } + return WalkResult::advance(); + }) + .Case([&](auto maketptr) { + if (rewriteMakeTensorPtrOp(maketptr).failed()) { + maketptr->emitRemark( + "PtrAnalysis: Failed to rewrite MakeTensorPtrOp"); + } + return WalkResult::advance(); + }) + .Case([&](auto advance) { + if (rewriteAdvanceOp(advance).failed()) { + advance->emitRemark("PtrAnalysis: Failed to rewrite AdvanceOp"); + } + return WalkResult::advance(); + }) + .Case([&](auto load) { + if (rewriteLoadOp(load, useUnsafeMask).failed()) { + load->emitRemark("PtrAnalysis: Failed to rewrite LoadOp"); + return WalkResult::advance(); + } + return WalkResult::skip(); + }) + .Case([&](auto store) { + if (rewriteStoreOp(store, useUnsafeMask).failed()) { + store->emitRemark("PtrAnalysis: Failed to rewrite StoreOp"); + return WalkResult::advance(); + } + return WalkResult::skip(); + }) + .Case([&](auto forOp) { + // `rewriteForOp` recursively visits its children, so regardless + // whether the rewrite succeeds or not, we need to return "skip" so + // that the the walk does not visit the for-op's child operations + // the second time. 
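A minimal model of that contract, assuming the pre-order walk that makes `skip` meaningful (MLIR's default walk order is post-order, where skipping has no effect on children):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"

// Returning skip() from a typed callback stops the walker from descending
// into regions the callee has already rewritten itself.
static void rewriteLoopsOnce(mlir::Operation *root) {
  root->walk<mlir::WalkOrder::PreOrder>(
      [](mlir::scf::ForOp forOp) -> mlir::WalkResult {
        // ... rewrite the loop and recurse into its body here ...
        return mlir::WalkResult::skip(); // don't visit the body twice
      });
}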
+          if (rewriteForOp(forOp).failed()) {
+            forOp->emitRemark("PtrAnalysis: Failed to rewrite ForOp");
+          }
+          return WalkResult::skip();
+        })
+        .Case<tts::GetStructuredStateOp>(
+            [&](tts::GetStructuredStateOp getStateOp) {
+              // For a tensor of indices potentially being used in a pointer
+              // arithmetic sequence, we need to manually populate the state
+              // if none already exists.
+              // This process is necessary because, unlike triton pointers in
+              // a loop, which always have a `tt.addptr` that triggers the
+              // rewrite process (including generating the ops that update
+              // offsets and strides), tensors of indices only have a simple
+              // `arith.addi` (or other arith ops).
+              // Without visiting these ops manually, the ops to update the
+              // offsets and strides would not be generated.
+              auto tritonValue = getStateOp->getOperand(0);
+              if (!knownPtrs.contains(tritonValue)) {
+                PtrState state;
+                OpBuilder b(getStateOp);
+                if (succeeded(visitOperand(tritonValue, state,
+                                           getStateOp->getLoc(), b))) {
+                  knownPtrs[tritonValue] = state;
+                } else {
+                  getStateOp->emitRemark("PtrAnalysis: Failed to populate ptr "
+                                         "state for tensor of indices");
+                }
+              }
+
+              return WalkResult::skip();
+            })
+        .Default([&](auto) { return WalkResult::advance(); });
+  });
+
+  return success();
+}
+
+} // namespace tts
+} // namespace mlir
diff --git a/third_party/tsingmicro/lib/CMakeLists.txt b/third_party/tsingmicro/lib/CMakeLists.txt
new file mode 100644
index 000000000..eff85b208
--- /dev/null
+++ b/third_party/tsingmicro/lib/CMakeLists.txt
@@ -0,0 +1,4 @@
+add_subdirectory(Analysis)
+add_subdirectory(AnalysisStructured)
+add_subdirectory(Conversion)
+add_subdirectory(Dialect)
diff --git a/third_party/tsingmicro/lib/Conversion/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/CMakeLists.txt
new file mode 100644
index 000000000..6177901f5
--- /dev/null
+++ b/third_party/tsingmicro/lib/Conversion/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_subdirectory(TritonToLinalg)
+add_subdirectory(TritonToStructured)
+add_subdirectory(TritonArithToLinalg)
+add_subdirectory(StructuredToMemref)
+add_subdirectory(Tx81MemrefToLLVM)
+add_subdirectory(LinalgToMK)
+add_subdirectory(MKToTx81)
+add_subdirectory(Tx81ToLLVM)
+add_subdirectory(TritonToCoreDialects)
+add_subdirectory(CoreDialectsToMK)
diff --git a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt
new file mode 100644
index 000000000..c4d27c2b3
--- /dev/null
+++ b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CMakeLists.txt
@@ -0,0 +1,23 @@
+#===------------------------------------------------------------------------===#
+#
+# Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+# All rights reserved.
+# +#===------------------------------------------------------------------------===# + +add_triton_library(CoreDialectsToMK + CoreDialectsToMKPass.cpp + + DEPENDS + CoreDialectsToMKConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport +) diff --git a/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp new file mode 100644 index 000000000..a9ffab977 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/CoreDialectsToMK/CoreDialectsToMKPass.cpp @@ -0,0 +1,60 @@ +//===------------------- CoreDialectsToMKPass.cpp -------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Lowering core dialects to backend dialects +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Conversion/CoreDialectsToMK/CoreDialectsToMK.h" +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h.inc" + +namespace { + +class CoreDialectsToMKPass : public CoreDialectsToMKBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + PassManager pm(&getContext(), moduleOp.getOperationName()); + + pm.addPass(createLinalgToMKPass()); + + // Erase dead code and fold constants created during lowering + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> triton::createCoreDialectsToMKPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/LinalgToMK/CMakeLists.txt new file mode 100644 index 000000000..409f42d88 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(LinalgToMagicKernel + LinalgToMK.cpp + LinalgToMKPass.cpp + + DEPENDS + MagicKernelTableGen + LinalgToMKConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp new file mode 100644 index 000000000..b088a56ea --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp @@ -0,0 +1,58 @@ +//===------------------- LinalgToMK.cpp -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" + +#define DEBUG_TYPE "linalg-to-mk" + +using namespace mlir; +using namespace mk; + +#define GEN_PASS_CLASSES +#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc" + + +namespace { + +// Convert tensor.empty + linalg.fill + linalg.matmul to mk.matmul +struct MatmulConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(linalg::MatmulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); +#if 0 + auto tensorType = *op->getResultTypes().begin(); + + Value output = op->getResult(0); + auto fillOp = output.getDefiningOp(); + Value emptyTensor = fillOp->getResult(0); + auto tensorEmptyOp = emptyTensor.getDefiningOp(); + + auto dotOp = rewriter.create(loc, tensorType, op->getOperand(0), + op->getOperand(1), + op.getNumOperands() == 3 ? op->getOperand(2) : nullptr); + rewriter.replaceOp(op, dotOp); +#endif + return success(); + } +}; + +} // namespace + +void mlir::triton::populateLinalgToMKCanonicalizationPatterns( + RewritePatternSet &patterns) { +} + +void mlir::triton::populateLinalgToMKConversionPatterns( + RewritePatternSet &patterns) { + // patterns.add(patterns.getContext()); +} \ No newline at end of file diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp new file mode 100644 index 000000000..094a51b31 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp @@ -0,0 +1,72 @@ +//===------------------- LinalgToMKPass.cpp -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/Debug.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include +#include +#include + +#define DEBUG_TYPE "linalg-to-mk" + +using namespace mlir; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_LINALGTOMK +#include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class LinalgToMKPass : public triton::impl::LinalgToMKBase { + using LinalgToMKBase::LinalgToMKBase; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + // TODO: Enable this when all conversion pattern has been implemented. 
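What enabling the commented-out `addIllegalDialect` call (just below) buys, sketched with the real MLIR conversion APIs: once every linalg op has an mk equivalent, `applyPartialConversion` fails loudly on anything left unconverted instead of silently keeping it. This assumes the mk dialect registration shown elsewhere in this patch.

#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Transforms/DialectConversion.h"

// Strict legality: linalg must be fully lowered, mk is the legal result.
static void makeLinalgIllegal(mlir::ConversionTarget &target) {
  target.addIllegalDialect<mlir::linalg::LinalgDialect>();
  target.addLegalDialect<mlir::mk::MagicKernelDialect>();
}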
+ //target.addIllegalDialect(); + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + mk::MagicKernelDialect>(); + + target.addLegalOp(); + + triton::populateLinalgToMKConversionPatterns(patterns); + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +triton::createLinalgToMKPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt new file mode 100644 index 000000000..b7e951a04 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(MKToTx81 + MKToTx81.cpp + MKToTx81Pass.cpp + + DEPENDS + Tx81TableGen + MKToTx81ConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp new file mode 100644 index 000000000..356d16cf6 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp @@ -0,0 +1,642 @@ +//===--------------------- MKToTx81.cpp -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This file implements the patterns to convert operations from mk dialect to +// tx81 dialect. It converts memory operations to RdmaOp/WdmaOp and converts +// mk.dot to tx.gemm etc. 
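Every such conversion starts by flattening a memref operand into a raw 64-bit SPM address (see createAddressFromMemref below). A scalar model of that computation, assuming the offset reported by extract_strided_metadata is already in the units the DMA engine expects:

#include <cstdint>

// Base buffer plus metadata offset; the result is the plain address that
// the Tx81 DMA and compute ops consume.
static uint64_t spmAddress(uint64_t baseBuffer, int64_t offset) {
  return baseBuffer + static_cast<uint64_t>(offset);
}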
+// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h" +#include "Tx81/instr_def.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "mk-to-tx81" + +using namespace mlir; +using namespace tx; + +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Type Conversion +//===----------------------------------------------------------------------===// + +class MKToTx81TypeConverter : public TypeConverter { +public: + MKToTx81TypeConverter() { + // Add conversions for MemRef types to UI64 (representing SPM addresses) + addConversion([](MemRefType type) -> Type { + return IntegerType::get(type.getContext(), 64, IntegerType::Unsigned); + }); + + // Add conversions for Tensor types to UI64 (representing SPM addresses) + addConversion([](TensorType type) -> Type { + return IntegerType::get(type.getContext(), 64, IntegerType::Unsigned); + }); + + // Keep other types as is + addConversion([](Type type) -> Type { return type; }); + } + +private: + MLIRContext *context; +}; + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// Get format code for tensor element type +// This maps MLIR types to Tx81 format codes +Data_Format getFormatCode(Type type) { + if (type.isF32()) { + return Fmt_FP32; + } else if (type.isF16()) { + return Fmt_FP16; + } else if (type.isBF16()) { + return Fmt_BF16; + } else if (type.isInteger(8)) { + return Fmt_INT8; + } + + // Default to F32 format + return Fmt_FP32; +} + +// Get element count from shape +int32_t getElementCount(ArrayRef shape) { + int32_t elementCount = 1; + for (auto dim : shape) { + if (dim > 0) { + elementCount *= dim; + } + } + return elementCount; +} + +// Helper function to extract shape from tensor type +SmallVector getShapeFromTensorType(TensorType type) { + SmallVector shape; + for (auto dim : type.getShape()) + shape.push_back(static_cast(dim)); + return shape; +} + +// Helper function to extract dimensions from memref or tensor type +SmallVector getDimsFromType(Type type) { + SmallVector dims; + if (auto memrefType = dyn_cast(type)) { + for (auto dim : memrefType.getShape()) + dims.push_back(static_cast(dim)); + } else if (auto tensorType = dyn_cast(type)) { + for (auto dim : tensorType.getShape()) + dims.push_back(static_cast(dim)); + } + return dims; +} + +Value createAddressFromMemref(ConversionPatternRewriter &rewriter, Location loc, + Value memref) { + auto stridedMetadata = + rewriter.create(loc, memref); + Value indexBasePtr = rewriter.create( + loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer()); + Value offset = stridedMetadata.getOffset(); + Value offsetPtr = rewriter.create(loc, 
indexBasePtr.getType(), + indexBasePtr, offset); + Value i64SPMPtr = rewriter.create( + loc, rewriter.getI64Type(), offsetPtr); + return i64SPMPtr; +} + +static std::tuple +createMetadata(ConversionPatternRewriter &rewriter, Location loc, + Value operand) { + auto stridedMetadata = + rewriter.create(loc, operand); + Value indexBasePtr = rewriter.create( + loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer()); + Value offset = stridedMetadata.getOffset(); + Value offsetPtr = rewriter.create(loc, indexBasePtr.getType(), + indexBasePtr, offset); + Value i64SPMPtr = rewriter.create( + loc, rewriter.getI64Type(), offsetPtr); + + // FIXME: For multi-dimensional(rank > 2), strides need to be multiplied. + return {i64SPMPtr, stridedMetadata.getSizes(), stridedMetadata.getStrides()}; +} + +static SmallVector padSizesToNHWC(ConversionPatternRewriter &rewriter, + Location loc, ValueRange sizes) { + Value one = rewriter.create(loc, 1); + int numPad = 4 - sizes.size(); + SmallVector nhwcShape; + while (numPad--) { + nhwcShape.push_back(one); + } + for (auto dim : sizes) { + nhwcShape.push_back(dim); + } + return nhwcShape; +} + +// The last stride is always 1, skip it, nhwcStrides.size() will be 3. +static SmallVector padStridesToNHWC(ConversionPatternRewriter &rewriter, + Location loc, ValueRange strides) { + Value one = rewriter.create(loc, 1); + int numPad = 4 - strides.size(); + SmallVector nhwcStrides; + while (numPad--) { + nhwcStrides.push_back(one); + } + for (auto dim : strides) { + nhwcStrides.push_back(dim); + } + nhwcStrides.pop_back(); + return nhwcStrides; +} + +static Value calculateElemCount(ConversionPatternRewriter &rewriter, + Location loc, ValueRange sizes) { + Value elemCount = sizes[0]; + for (int i = 1; i < sizes.size(); i++) { + elemCount = rewriter.create(loc, elemCount.getType(), + elemCount, sizes[i]); + } + return elemCount; +} + +// Extract the operations from a linalg op region +template llvm::SmallVector getRegionOps(T linalgOp) { + auto regionBlock = linalgOp.getBody(); + return llvm::map_to_vector(regionBlock->without_terminator(), + [](Operation &op) { return &op; }); +} + +class MemoryCopyConvertPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + // Workaround: Avoid analyzing control flow as much as possible。 + bool isOperandMemorySpaceSPM(Value operand) const { + + while (auto op = operand.getDefiningOp()) { + if (isa(op)) + return true; + operand = op->getOperand(0); + } + return false; + } + + LogicalResult + matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool isSrcSPM = isOperandMemorySpaceSPM(adaptor.getSource()); + bool isDstSPM = isOperandMemorySpaceSPM(adaptor.getTarget()); + + // DDR to DDR + if (!isSrcSPM && !isDstSPM) + return rewriter.notifyMatchFailure( + op, "Can not copy memory from DDR to DDR.\n"); + + auto [srcPtr, srcSizes, srcStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getSource()); + auto [dstPtr, dstSizes, dstStrides] = + createMetadata(rewriter, op->getLoc(), adaptor.getTarget()); + + auto inputType = dyn_cast(op.getSource().getType()); + // SPM to SPM + if (isSrcSPM && isDstSPM) { + // FIXME: Only support 1d for now, take sizes[0] as elemCount. + auto elemCount = calculateElemCount(rewriter, op->getLoc(), srcSizes); + + // WORKAROUND: Assume no mask. 
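The branch structure of this pattern reduces to a dispatch on the two memory spaces: the SPM-to-SPM case lowers to a move, DDR-to-SPM to RDMA, SPM-to-DDR to WDMA, and DDR-to-DDR is rejected. A sketch of that dispatch with hypothetical names:

#include <cstdint>

enum class CopyKind { Move /*SPM->SPM*/, Rdma /*DDR->SPM*/,
                      Wdma /*SPM->DDR*/, Invalid /*DDR->DDR*/ };

// Pick the DMA primitive from the source/destination memory spaces.
static CopyKind classifyCopy(bool srcIsSpm, bool dstIsSpm) {
  if (srcIsSpm && dstIsSpm)
    return CopyKind::Move;
  if (!srcIsSpm && dstIsSpm)
    return CopyKind::Rdma;
  if (srcIsSpm && !dstIsSpm)
    return CopyKind::Wdma;
  return CopyKind::Invalid; // no engine for DDR -> DDR
}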
+      auto constValue = rewriter.create<arith::ConstantIntOp>(
+          op.getLoc(), 0, rewriter.getI32Type());
+
+      rewriter.create<tx::MaskMoveOp>(
+          op->getLoc(), rewriter.getI64Type(), srcPtr, constValue, dstPtr,
+          elemCount,                     // Element count
+          rewriter.getI16IntegerAttr(0), // Round mode
+          rewriter.getI16IntegerAttr(
+              getFormatCode(inputType.getElementType())) // Format code
+      );
+    } else if (isDstSPM) {
+      auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), srcSizes);
+      auto nhwcStrides = padStridesToNHWC(rewriter, op->getLoc(), srcStrides);
+
+      rewriter.create<tx::RdmaOp>(
+          op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr,
+          nhwcShape,   // NHWC shape
+          nhwcStrides, // NHWC stride
+          rewriter.getI32IntegerAttr(
+              getFormatCode(inputType.getElementType())) // Format code
+      );
+    } else {
+      auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), dstSizes);
+      // Note: strides must come from dstStrides, not dstSizes.
+      auto nhwcStrides = padStridesToNHWC(rewriter, op->getLoc(), dstStrides);
+
+      rewriter.create<tx::WdmaOp>(
+          op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr,
+          nhwcShape,   // NHWC shape
+          nhwcStrides, // NHWC stride
+          rewriter.getI32IntegerAttr(
+              getFormatCode(inputType.getElementType())) // Format code
+      );
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Convert linalg.fill to MemsetOp
+class LinalgFillOpConversion : public OpConversionPattern<linalg::FillOp> {
+public:
+  using OpConversionPattern<linalg::FillOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(linalg::FillOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get the value to fill with
+    Value fillValue = op.getInputs()[0]; // adaptor.getValue();
+
+    if (op.getOutputs().size() != 1)
+      return rewriter.notifyMatchFailure(op, "Only support single output");
+
+    // Convert the fill value to an i32 bit pattern
+    if (fillValue.getType().isF32()) {
+      // If it's a float constant, bitcast it to int
+      fillValue = rewriter.create<arith::BitcastOp>(
+          op.getLoc(), rewriter.getI32Type(), fillValue);
+    } else if (fillValue.getType().isF16()) {
+      auto extf = rewriter.create<arith::ExtFOp>(
+          op.getLoc(), rewriter.getF32Type(), fillValue);
+      fillValue = rewriter.create<arith::BitcastOp>(
+          op.getLoc(), rewriter.getI32Type(), extf);
+    }
+
+    auto [srcPtr, srcSizes, srcStrides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), srcSizes);
+
+    // Create a MemsetOp to fill the SPM buffer
+    // TODO: Support format codes for different element types
+    rewriter.create<tx::MemsetOp>(
+        op.getLoc(), rewriter.getI64Type(), srcPtr, fillValue, elemCount,
+        rewriter.getI32ArrayAttr({}), // Strides (empty for simple fill)
+        rewriter.getI32ArrayAttr({}), // Iterations (empty for simple fill)
+        rewriter.getI16IntegerAttr(5) // Format (5 = f32, assuming f32 for now)
+    );
+
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// mk.dot to tx.gemm Conversion Pattern
+//===----------------------------------------------------------------------===//
+
+class MKDotToTx81GemmOpConversion
+    : public OpConversionPattern<mlir::mk::DotOp> {
+public:
+  using OpConversionPattern<mlir::mk::DotOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(mlir::mk::DotOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Extract dimensions from the operand memref types
+    MemRefType aTensorType = cast<MemRefType>(op.getA().getType());
+    MemRefType bTensorType = cast<MemRefType>(op.getB().getType());
+
+    // Get converted operands
+    auto loc = op.getLoc();
+
+    auto aShape = aTensorType.getShape();
+    auto bShape = bTensorType.getShape();
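+
+    // Example (sketch, assuming f32 operands): for %a : memref<16x32xf32> and
+    // %b : memref<32x64xf32>, aShape = [16, 32] and bShape = [32, 64], so the
+    // dims attribute built below is [M, K, N] = [16, 32, 64].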
+
+    // Matrix dimensions M, K, N for GEMM
+    int32_t M = aShape[0];
+    int32_t K = aShape[1];
+    int32_t N = bShape[1];
+
+    // Create dimensions array attribute [M, K, N]
+    auto dims = rewriter.getI32ArrayAttr({M, K, N});
+
+    // Get operand pointers
+    auto a = createAddressFromMemref(rewriter, loc, adaptor.getA());
+    auto b = createAddressFromMemref(rewriter, loc, adaptor.getB());
+    auto c = createAddressFromMemref(rewriter, loc, adaptor.getC());
+    auto zeros = createAddressFromMemref(rewriter, loc, adaptor.getZeroes());
+
+    // Create GemmOp
+    rewriter.create<tx::GemmOp>(
+        op.getLoc(), rewriter.getI64Type(),
+        a,     // src_a (Matrix A in SPM)
+        b,     // src_b (Matrix B in SPM)
+        c,     // src_bias (optional accumulation)
+        zeros, // zeroes
+        dims,  // dimensions [M,K,N]
+        rewriter.getBoolAttr(false), // en_psum
+        zeros, // WORKAROUND: psum_addr (using zeroes buffer)
+        rewriter.getBoolAttr(false),   // trans_src_a
+        rewriter.getBoolAttr(false),   // trans_src_b
+        rewriter.getI32IntegerAttr(1), // batch_src_a
+        rewriter.getI32IntegerAttr(1), // batch_src_b
+        rewriter.getBoolAttr(false),   // en_leaky_relu
+        rewriter.getBoolAttr(op.getC() != nullptr), // en_bias
+        rewriter.getBoolAttr(false),                // en_neg_scale
+        rewriter
+            .create<arith::ConstantIntOp>(op.getLoc(), 0, rewriter.getI64Type())
+            .getResult(),              // src_neg_scale
+        rewriter.getBoolAttr(false),   // en_pos_scale
+        rewriter
+            .create<arith::ConstantIntOp>(op.getLoc(), 0, rewriter.getI64Type())
+            .getResult(),              // src_pos_scale
+        rewriter.getI32IntegerAttr(3), // src_fmt (3 = f32)
+        rewriter.getI32IntegerAttr(3)  // dst_fmt (3 = f32)
+    );
+    // Op has no result value
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
+struct ElementwiseConversion : public OpConversionPattern<linalg::GenericOp> {
+  using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
+
+  template <typename TxOp>
+  LogicalResult convertBinaryOp(linalg::GenericOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    auto input0 =
+        createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]);
+    auto input1 =
+        createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]);
+    auto [output, sizes, strides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes);
+
+    auto inputType = dyn_cast<MemRefType>(op.getInputs()[0].getType());
+    // Create the elementwise operation
+    // TODO: Fix attribute
+    rewriter.create<TxOp>(
+        loc, rewriter.getI64Type(), input0, input1, output, elemCount,
+        rewriter.getI16IntegerAttr(0), // Round mode
+        rewriter.getI16IntegerAttr(
+            getFormatCode(inputType.getElementType())) // Format code
+    );
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  template <typename TxOp>
+  LogicalResult NormalConvertOp(linalg::GenericOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    auto input =
+        createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]);
+    auto [output, sizes, strides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes);
+
+    rewriter.create<TxOp>(loc, rewriter.getI64Type(), input, output,
+                          elemCount);
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  template <typename TxOp>
+  LogicalResult RoundConvertOp(linalg::GenericOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    auto input =
+        createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]);
+    auto [output, sizes, strides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+    auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes);
+    // TODO: Fix attribute
+    rewriter.create<TxOp>(loc,
+                          rewriter.getI64Type(),        // Result type
+                          input,                        // Input
+                          output,                       // Output
+                          elemCount,                    // Element count
+                          rewriter.getI16IntegerAttr(0) // Round mode
+    );
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto regionOps = getRegionOps(op);
+
+    // Check that the operation is elementwise
+    if (op.getIteratorTypesArray().front() != utils::IteratorType::parallel)
+      return rewriter.notifyMatchFailure(op, "Only support elementwise op.");
+
+    auto elemWiseOp = regionOps[0];
+    return llvm::TypeSwitch<Operation *, LogicalResult>(elemWiseOp)
+        .Case([&](auto elemWiseOp) {
+          return convertBinaryOp(op, adaptor, rewriter);
+        })
+        .Case([&](auto elemWiseOp) {
+          return convertBinaryOp(op, adaptor, rewriter);
+        })
+        .Case([&](auto elemWiseOp) {
+          return convertBinaryOp(op, adaptor, rewriter);
+        })
+        .Case(
+            [&](auto elemWiseOp) {
+              return convertBinaryOp(op, adaptor, rewriter);
+            })
+        .Case([&](auto elemWiseOp) {
+          return NormalConvertOp(op, adaptor, rewriter);
+        })
+        .Case<arith::SIToFPOp>([&](auto elemWiseOp) {
+          // TODO: Add more integer-to-FP conversions.
+          auto inputType =
+              cast<MemRefType>(op.getInputs()[0].getType()).getElementType();
+          auto outputType =
+              cast<MemRefType>(op.getOutputs()[0].getType()).getElementType();
+          if (inputType.isInteger(16) && outputType.isF32()) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isInteger(16) && outputType.isF16()) {
+            return NormalConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isInteger(32) && outputType.isF16()) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isInteger(32) && outputType.isF32()) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else {
+            return rewriter.notifyMatchFailure(
+                op, "Unsupported input/output type combination for integer to "
+                    "FP conversion");
+          }
+        })
+        .Case<arith::FPToSIOp>([&](auto elemWiseOp) {
+          // TODO: Add more FP-to-integer conversions.
+          auto inputType =
+              cast<MemRefType>(op.getInputs()[0].getType()).getElementType();
+          auto outputType =
+              cast<MemRefType>(op.getOutputs()[0].getType()).getElementType();
+          if (inputType.isF16() && outputType.isInteger(8)) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isF16() && outputType.isInteger(16)) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isF16() && outputType.isInteger(32)) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isF32() && outputType.isInteger(8)) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isF32() && outputType.isInteger(16)) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else if (inputType.isF32() && outputType.isInteger(32)) {
+            return RoundConvertOp(op, adaptor, rewriter);
+          } else {
+            return rewriter.notifyMatchFailure(
+                op, "Unsupported input/output type combination for FP to "
+                    "integer conversion");
+          }
+        })
+        .Default([&](auto elemWiseOp) {
+          // WORKAROUND: Handles tl.arange(0, BLOCK_SIZE), which lowers to
+          // linalg.generic + linalg.index + arith.index_cast, and other
+          // currently unsupported cases (e.g. arith::extf).
+          // TODO: Lower these ops to tx81 where supported.
+          if (failed(linalg::linalgOpToAffineLoops(rewriter, op)))
+            return rewriter.notifyMatchFailure(
+                op, "Element-wise op not yet supported");
+          rewriter.eraseOp(op);
+          return success();
+        });
+  }
+};
+
+struct ReduceConversion : public OpConversionPattern<linalg::ReduceOp> {
+  using OpConversionPattern<linalg::ReduceOp>::OpConversionPattern;
+
+private:
+  bool isReductionOpSupported(Operation *redOp) const {
+    return isa(
+        redOp);
+  }
+
+  template <typename TxOp>
+  LogicalResult convertToReduceOp(linalg::ReduceOp op,
+                                  typename linalg::ReduceOp::Adaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) const {
+    auto dims = op.getDimensions();
+    if (dims.size() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "Only one reduction dim is supported.");
+    auto dim = dims[0];
+    auto input =
+        createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInputs()[0]);
+    auto output =
+        createAddressFromMemref(rewriter, op->getLoc(), adaptor.getInits()[0]);
+    auto inputType = dyn_cast<MemRefType>(op.getInputs()[0].getType());
+    auto inputShape = inputType.getShape();
+    // TODO: Support any rank
+    if (inputShape.size() > 1)
+      return rewriter.notifyMatchFailure(op, "Rank > 1 not supported yet.");
+
+    if (dim && dim >= inputShape.size())
+      return rewriter.notifyMatchFailure(
+          op, "Dimension attribute exceeds input rank!");
+
+    int64_t inputSize = inputShape.empty() ? 1 : inputShape[0];
+
+    SmallVector<int64_t> reduceShape = {1, 1, 1, inputSize};
+    auto format = getFormatCode(inputType.getElementType());
+    auto reduceOp = rewriter.create<TxOp>(
+        op->getLoc(), TypeRange{}, input, output,
+        rewriter.getUI32IntegerAttr(dim), rewriter.getI64ArrayAttr(reduceShape),
+        rewriter.getI16IntegerAttr(format));
+    rewriter.replaceOp(op, reduceOp);
+    return success();
+  }
+
+public:
+  LogicalResult
+  matchAndRewrite(linalg::ReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionOps = getRegionOps(op);
+
+    if (reductionOps.size() != 1 ||
+        !isReductionOpSupported(reductionOps.front())) {
+      return rewriter.notifyMatchFailure(
+          op, "Only support lowering reductions whose body "
+              "contains a single max(i/f) or addf.");
+    }
+    auto redOp = reductionOps[0];
+
+    return llvm::TypeSwitch<Operation *, LogicalResult>(redOp)
+        .Case([&](auto redOp) {
+          return convertToReduceOp(op, adaptor, rewriter);
+        })
+        .Case([&](auto redOp) {
+          return convertToReduceOp(op, adaptor, rewriter);
+        })
+        .Case([&](auto redOp) {
+          return convertToReduceOp(op, adaptor, rewriter);
+        })
+        .Default([](Operation *op) {
+          op->dump();
+          llvm_unreachable("Reduction op not yet supported");
+          return failure();
+        });
+  }
+};
+
+} // namespace
+
+void mlir::triton::populateMKToTx81CanonicalizationPatterns(
+    RewritePatternSet &patterns) {}
+
+void mlir::triton::populateMKToTx81ConversionPatterns(
+    RewritePatternSet &patterns) {
+
+  MKToTx81TypeConverter typeConverter;
+
+  // Add type conversion patterns
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+      patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+
+  // clang-format off
+  patterns.add<MemoryCopyConvertPattern,
+               LinalgFillOpConversion,
+               MKDotToTx81GemmOpConversion,
+               ElementwiseConversion,
+               ReduceConversion>(
+      patterns.getContext());
+  // clang-format on
+}
diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp
new file mode 100644
index 000000000..611ba7be7
--- /dev/null
+++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp
@@ -0,0 +1,75 @@
+//===--------------------- MKToTx81Pass.cpp -------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "mk-to-tx81" + +using namespace mlir; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_MKTOTX81 +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class MKToTx81Pass : public triton::impl::MKToTx81Base { + using MKToTx81Base::MKToTx81Base; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + // Register illegal ops for Dialect Conversion + target.addIllegalDialect< linalg::LinalgDialect, + bufferization::BufferizationDialect, mk::MagicKernelDialect>(); + + target.addLegalDialect(); + + target.addIllegalOp(); + target.addLegalOp(); + + triton::populateMKToTx81ConversionPatterns(patterns); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> triton::createMKToTx81Pass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/CMakeLists.txt new file mode 100644 index 000000000..0883bca4f --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(StructuredToMemref + StructuredToMemref.cpp + StructuredToMemrefPass.cpp + + DEPENDS + StructuredToMemrefConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRSCFTransforms + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms + TritonTilingExtIR + TritonStructuredIR +) diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp new file mode 100644 index 000000000..b5e1165a7 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp @@ -0,0 +1,859 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR//MemRef.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#include +#include +#include + +#define DEBUG_TYPE "structured-to-memref" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc" + +static const std::string WRAP_SIDE_BY_SIDE = "wrap_side_by_side"; +static const std::string WRAP_STACKED = "wrap_stacked"; + +static memref::SubViewOp getSubview(int rank, ArrayRef dims, + Value source, Location loc, OpBuilder &b) { + auto sourceType = cast(source.getType()); + SmallVector offsets(rank, b.getIndexAttr(0)); + SmallVector strides(rank, b.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); + + return b.create(loc, cast(dstType), source, + offsets, dims, strides); +} + +namespace { + +struct MakeTensorPtrConverter + : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + static Type getElementTypeStructuredPtr(tts::MakeTensorPtrOp op) { + assert(!op.isBlockPtr()); + // tensor<1024x!tt.ptr> + auto ptrType = cast( + cast(op.getType()).getElementType()); + return ptrType.getPointeeType(); + } + + static Type getElementTypeBlockPtr(tts::MakeTensorPtrOp op) { + assert(op.isBlockPtr()); + // !tt.ptr, 1> + auto shapedType = cast( + cast(op.getType()).getPointeeType()); + return shapedType.getElementType(); + } + + static MemRefType getResultMemrefType(tts::MakeTensorPtrOp op, int64_t offset, + ArrayRef staticStrides, + ArrayRef resultShape) { + auto layout = + StridedLayoutAttr::get(op.getContext(), offset, staticStrides); + Type elemType; + if (op.isBlockPtr()) { + elemType = getElementTypeBlockPtr(op); + } else { + elemType = getElementTypeStructuredPtr(op); + } + return MemRefType::get(resultShape, elemType, layout); + } + + // If there are dimensions with size 1 and stride 0, replace 0 stride with + // the product of sizes of all lower dimensions. This avoids creating memref + // with zero stride. 
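+  // For example (sketch): sizes [4, 1, 8] with strides [8, 0, 1] become
+  // strides [8, 8, 1] -- the size-1 dim is given stride 8, the product of
+  // the sizes below it (1 * 8), which addresses the same elements without
+  // a zero stride.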
+ static llvm::SmallVector + getMixedStridesForMemref(tts::MakeTensorPtrOp op, OpBuilder &b) { + llvm::SmallVector strides; + auto accumulate = 1; + for (auto [size, stride] : + llvm::reverse(llvm::zip(op.getSizes(), op.getMixedStrides()))) { + auto strideIntAttr = getIntAttr(stride); + if (size == 1 && strideIntAttr && strideIntAttr.value() == 0) { + strides.push_back(b.getIndexAttr(accumulate)); + } else { + strides.push_back(stride); + } + accumulate *= size; + } + std::reverse(strides.begin(), strides.end()); + return strides; + } + + static OpFoldResult accumulateTargetOffset(tts::MakeTensorPtrOp op, + OpBuilder &b) { + Location loc = op->getLoc(); + OpFoldResult targetOffset = b.getIndexAttr(0); + for (auto o : op.getMixedOffsets()) { + targetOffset = addOFRs(targetOffset, o, loc, b); + } + return targetOffset; + } + + std::pair + createSideBySideCastOps(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto resultShape = cast(op.getType()).getShape(); + + auto targetOffset = + ofrToIndexValue(accumulateTargetOffset(op, rewriter), loc, rewriter); + + //////////////////////////////////////////////////////////////////////////// + // + // Handling side-by-side wraparound + // + // Note: We do not support cases where the target has already overflown the + // number of columns! This is because in PtrAnalysis, the offset has already + // been collapsed into a single dimension, so it is ambiguous to determine + // whether the offset actually overflows or just refers to an element on the + // subsequent rows. + // + // Same limitations apply to the stacked wraparound case. + // + //////////////////////////////////////////////////////////////////////////// + // + // nextOffset - targetOffset = colSize + // d1 + d2 = colSize + // N + // x clampedOffset + // --------------------------*----------------*-----* + // | | nextOffset (might + // | targetOffset | overflow) + // y *----- *----------------| + // | | | | + // M |----- -----------------| + // | d2 d1 | + // -------------------------------------------- + // + // x = targetOffset % N + // nextOffset = x + colSize + // clampedOffset = min(nextOffset, N) + // d1 = clampedOffset - x + // + //////////////////////////////////////////////////////////////////////////// + + auto resultType = getResultMemrefType( + op, /* offset */ ShapedType::kDynamic, + /* staticStrides */ + SmallVector(resultShape.size(), ShapedType::kDynamic), + /* result shape */ + SmallVector{ + + // Row stays the same + resultShape[0], + + // Column is dynamic, in most cases, this + // should be the same as the original column. + // The last chunk may be smaller due to + // wrapping around. 
+ ShapedType::kDynamic}); + + Value rowSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[1])); + + Value modN = ofrToIndexValue(op.getMixedShape()[1], loc, rewriter); + + Value x = rewriter.create(loc, targetOffset, modN); + Value y = rewriter.create(loc, targetOffset, x); + + SmallVector strideVals = + ofrsToIndexValues(op.getMixedStrides(), loc, rewriter); + + // First chunk + Value nextOffset = rewriter.create(loc, x, colSize); + Value clampedOffset = + rewriter.create(loc, nextOffset, modN); + Value d1 = rewriter.create(loc, clampedOffset, x); + SmallVector sizes1{rowSize, d1}; + + auto cast1 = rewriter.create( + loc, resultType, adaptor.getBase(), targetOffset, sizes1, strideVals); + + // Second chunk + Value d2 = rewriter.create(loc, colSize, d1); + SmallVector sizes2{rowSize, d2}; + + auto cast2 = rewriter.create( + loc, resultType, adaptor.getBase(), y, sizes2, strideVals); + + return {cast1, cast2}; + } + + std::pair + createStackedCastOps(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto loc = op->getLoc(); + auto resultShape = cast(op.getType()).getShape(); + + assert(resultShape.size() == 2); + + auto targetOffset = + ofrToIndexValue(accumulateTargetOffset(op, rewriter), loc, rewriter); + + //////////////////////////////////////////////////////////////////////////// + // + // Handling stacked wraparound + // + // We do not support cases where the target offset has already overflown the + // number of rows. See side-by-side wraparound for details. + // + //////////////////////////////////////////////////////////////////////////// + // We're loading a tensor of dim (rowSize, colSize) + // d1 + d2 = rowSize + // d2 is the number of rows that overflow + // + // cols + // + // wrappedAroundOff + // --------------*------------*-------- + // | d2 | | | + // | |------------| | + // rows| | + // | | + // | targetOffset | + // | *------------| | + // | | | | + // | d1 | | | + // | | clampedOff | | + // --------------*--------------------- + // | overflow | + // *------------- + // nextOff + // + // wrappedAroundOff = targetOffset % cols + // clampedOff = (rows * strideRows) + wrappedAroundOff + // ~~~~~~~~~~~~~~~~~ + // ^ + // | + // We have already computed + // rows * strideRows = modRow = shape[1] + // in TritonToStructured + // + // clampedOff - targetOffset + // d1 = -------------------- + // strideRows + + auto resultType = getResultMemrefType( + op, /* offset */ ShapedType::kDynamic, + /* staticStrides */ + SmallVector(resultShape.size(), ShapedType::kDynamic), + /* result shape */ + SmallVector{ + // Row is dynamic, in most cases, this should + // be the same as the original row. The last + // chunk may be smaller due to wrapping + // around. + ShapedType::kDynamic, + + // Col stays the same. 
+ resultShape[1], + }); + + Value rowSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = rewriter.create( + loc, rewriter.getIndexAttr(op.getSizes()[1])); + + Value strideRow = ofrToIndexValue(op.getMixedStrides()[0], loc, rewriter); + Value strideCol = ofrToIndexValue(op.getMixedStrides()[1], loc, rewriter); + + Value modRow = op.getShape()[0]; + + // First chunk + Value wrappedAroundOff = + rewriter.create(loc, targetOffset, strideRow); + Value clampedOff = + rewriter.create(loc, modRow, wrappedAroundOff); + Value d1 = rewriter.create(loc, clampedOff, targetOffset); + d1 = rewriter.create(loc, d1, strideRow); + + SmallVector sizes1{d1, colSize}; + memref::ReinterpretCastOp cast1 = + rewriter.create( + loc, resultType, adaptor.getBase(), targetOffset, sizes1, + ValueRange{strideRow, strideCol}); + + // Second chunk + Value d2 = rewriter.create(loc, rowSize, d1); + SmallVector sizes2{d2, colSize}; + memref::ReinterpretCastOp cast2 = + rewriter.create( + loc, resultType, adaptor.getBase(), wrappedAroundOff, sizes2, + ValueRange{strideRow, strideCol}); + + return {cast1, cast2}; + } + + LogicalResult rewriteSplitPtr(tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto parentShape = op.getStaticShape(); + + SmallVector casts; + StringRef wrapType; + + if (parentShape[0] == ShapedType::kDynamic) { + // Stacked case + assert(parentShape[1] == 0); + auto [cast1, cast2] = createStackedCastOps(op, adaptor, rewriter); + casts = {cast1.getResult(), cast2.getResult()}; + wrapType = WRAP_STACKED; + } else { + assert(parentShape[0] == 0); + auto [cast1, cast2] = createSideBySideCastOps(op, adaptor, rewriter); + casts = {cast1.getResult(), cast2.getResult()}; + wrapType = WRAP_SIDE_BY_SIDE; + } + + auto combinedCast = rewriter.create( + op.getLoc(), op.getType(), casts); + + combinedCast->setAttr(wrapType, rewriter.getUnitAttr()); + + rewriter.replaceOp(op, combinedCast); + + return success(); + } + + LogicalResult rewritePtr(ArrayRef resultShape, bool isBlockPtr, + tts::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto mixedStrides = getMixedStridesForMemref(op, rewriter); + SmallVector staticStrides; + SmallVector dynamicStrides; + dispatchIndexOpFoldResults(mixedStrides, dynamicStrides, staticStrides); + + auto targetOffset = accumulateTargetOffset(op, rewriter); + auto staticTargetOffset = getIntAttr(targetOffset); + auto resultType = getResultMemrefType( + op, staticTargetOffset.value_or(ShapedType::kDynamic), staticStrides, + resultShape); + + // The base ptr, which is from one of the args, would have already been + // converted to memref<*> at this point, so get the base from adaptor. + // + // For block pointers, the base could come from a sequence of `tt.addptr`, + // which at this point has already been lowered to a sequence of + // `memref.reinterpret_cast` ops. The offset in such cases are dynamic. + // (see test/Conversion/StructuredToMemref/block_ptr_complex_offset.mlir) + // + // For non-block pointer cases, the base is the reinterpret_cast of a + // function argument. Assert that the offset is a constant 0 in such cases. 
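+    // Illustrative example (sketch): if the base was lowered to
+    //   %base = memref.reinterpret_cast %arg : memref<*xf32> to
+    //             memref<1xf32, strided<[1], offset: ?>>
+    // with a dynamic offset %o, the reinterpret_cast emitted below folds %o
+    // into targetOffset rather than discarding it.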
+    auto ptr = adaptor.getBase();
+    if (auto reinterpretCast = ptr.getDefiningOp<memref::ReinterpretCastOp>()) {
+      auto offset = reinterpretCast.getMixedOffsets()[0];
+      auto intAttr = getIntAttr(offset);
+      assert(isBlockPtr || (intAttr.has_value() && intAttr.value() == 0));
+      targetOffset = addOFRs(targetOffset, reinterpretCast.getMixedOffsets()[0],
+                             op->getLoc(), rewriter);
+    }
+
+    auto castOp = rewriter.create<memref::ReinterpretCastOp>(
+        op.getLoc(), resultType, ptr, targetOffset, op.getMixedSizes(),
+        mixedStrides);
+
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+
+  LogicalResult
+  rewriteStructuredPtr(tts::MakeTensorPtrOp op, OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    ArrayRef<int64_t> resultShape = cast<ShapedType>(op.getType()).getShape();
+    return rewritePtr(resultShape, false, op, adaptor, rewriter);
+  }
+
+  LogicalResult rewriteBlockPtr(tts::MakeTensorPtrOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    // Block pointers are basically the same as structured pointers except that
+    // the return types are !tt.ptr<tensor<...>> instead of
+    // tensor<...x!tt.ptr<...>>
+    ArrayRef<int64_t> resultShape =
+        cast<ShapedType>(
+            cast<triton::PointerType>(op.getType()).getPointeeType())
+            .getShape();
+    return rewritePtr(resultShape, true, op, adaptor, rewriter);
+  }
+
+public:
+  LogicalResult
+  matchAndRewrite(tts::MakeTensorPtrOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!llvm::is_sorted(op.getOrder(), std::greater<>())) {
+      emitError(op.getLoc()) << "non-decreasing dimension order on tensor "
+                                "pointers is not yet supported";
+      return failure();
+    }
+
+    if (op.isBlockPtr()) {
+      return rewriteBlockPtr(op, adaptor, rewriter);
+    }
+
+    if (op.isStructuredPtr()) {
+      return rewriteStructuredPtr(op, adaptor, rewriter);
+    }
+
+    if (op.isSplitPtr()) {
+      return rewriteSplitPtr(op, adaptor, rewriter);
+    }
+
+    return failure();
+  }
+};
+
+struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
+private:
+  using OpConversionPattern<tts::LoadOp>::OpConversionPattern;
+
+  void createSideBySideCopies(Value block1, Value block2, Value dst,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter) const {
+    auto zero =
+        rewriter.create(loc, rewriter.getIndexAttr(0));
+
+    auto one =
+        rewriter.create(loc, rewriter.getIndexAttr(1));
+
+    Value block1Row = rewriter.create(loc, block1, 0);
+    Value block1Col = rewriter.create(loc, block1, 1);
+
+    Value block2Row = rewriter.create(loc, block2, 0);
+    Value block2Col = rewriter.create(loc, block2, 1);
+
+    auto block1Dst =
+        rewriter.create(loc, dst, /* offsets */
+                        ValueRange{zero, zero},
+                        /* sizes */
+                        ValueRange{block1Row, block1Col},
+                        /* strides */
+                        ValueRange{one, one});
+
+    auto block2Dst =
rewriter.create(loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); + } + + memref::SubViewOp createSubview(Value src, ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides, Location loc, + ConversionPatternRewriter &rewriter) const { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, cast(dstType), + src, offsets, sizes, strides); + } + + std::pair + getSideBySideSubviews(ArrayRef dims, Value block1, Value block2, + Location loc, + ConversionPatternRewriter &rewriter) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult col1 = + rewriter.create(loc, block1, 1).getResult(); + OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter); + OpFoldResult subviewCol2 = + subOFRs(subviewColFull, subviewCol1, loc, rewriter); + + SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); + SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); + auto sv1 = createSubview(block1, offsets, {subviewRowFull, subviewCol1}, + strides, loc, rewriter); + auto sv2 = createSubview(block2, offsets, {subviewRowFull, subviewCol2}, + strides, loc, rewriter); + + return {sv1, sv2}; + } + + std::pair + getStackedSubviews(ArrayRef dims, Value block1, Value block2, + const Location loc, + ConversionPatternRewriter &rewriter) const { + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult row1 = + rewriter.create(loc, block1, 0).getResult(); + OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter); + OpFoldResult subviewRow2 = + subOFRs(subviewRowFull, subviewRow1, loc, rewriter); + + SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); + SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); + auto sv1 = createSubview(block1, offsets, {subviewRow1, subviewColFull}, + strides, loc, rewriter); + auto sv2 = createSubview(block2, offsets, {subviewRow2, subviewColFull}, + strides, loc, rewriter); + return {sv1, sv2}; + } + + LogicalResult + rewriteStructuredLoad(tts::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(!op.hasMask()); + + auto loc = op->getLoc(); + auto ptr = adaptor.getPtr(); + auto other = op.getOther(); + + auto tensorType = cast(op.getType()); + auto elemType = tensorType.getElementType(); + + auto alloc = rewriter.create( + loc, MemRefType::get(tensorType.getShape(), elemType)); + + // No mask + assert(!other && "other value used in non-masked load"); + + if (auto unrealizedCast = ptr.getDefiningOp()) { + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (unrealizedCast->hasAttr(WRAP_SIDE_BY_SIDE)) { + createSideBySideCopies(block1, block2, alloc, loc, rewriter); + } else if (unrealizedCast->hasAttr(WRAP_STACKED)) { + createStackedCopies(block1, block2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + } else { + rewriter.create(loc, ptr, alloc); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } + + LogicalResult rewriteMaskedLoad(tts::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + 
assert(op.hasMask()); + + auto loc = op->getLoc(); + auto ptr = adaptor.getPtr(); + + auto tensorType = cast(op.getType()); + auto elemType = tensorType.getElementType(); + + auto alloc = rewriter.create( + loc, MemRefType::get(tensorType.getShape(), elemType)); + + SmallVector mixedDims = op.getMixedMaskDims(); + + // Fill load destination with other value + if (op.getOther()) { + // For each dimension check if dims[i] < shape[i], or-accumulate + // the result + auto shape = tensorType.getShape(); + auto accBase = + rewriter.create(loc, rewriter.getBoolAttr(false)) + .getResult(); + for (size_t i = 0; i < shape.size(); i++) { + auto shapei = rewriter.create( + loc, rewriter.getIndexAttr(shape[i])); + + Value dimi = dyn_cast(mixedDims[i]); + if (!dimi) { + dimi = rewriter.create( + loc, rewriter.getIndexAttr(op.getStaticMaskDims()[i])); + } + + Value cmp = rewriter.create( + loc, arith::CmpIPredicate::slt, dimi, shapei); + accBase = rewriter.create(loc, accBase, cmp); + } + + // condition the memset on the or-accumulation + // initialize with padding prior to CopyOp + rewriter.create(loc, accBase, [&](OpBuilder &b, Location loc) { + b.create(loc, ValueRange{op.getOther()}, + ValueRange{alloc}); + b.create(loc); + }); + } + + if (auto unrealizedCast = ptr.getDefiningOp()) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; + + if (unrealizedCast->hasAttr(WRAP_SIDE_BY_SIDE)) { + auto [subview1, subview2] = + getSideBySideSubviews(mixedDims, block1, block2, loc, rewriter); + createSideBySideCopies(subview1, subview2, alloc, loc, rewriter); + } else if (unrealizedCast->hasAttr(WRAP_STACKED)) { + auto [subview1, subview2] = + getStackedSubviews(mixedDims, block1, block2, loc, rewriter); + createStackedCopies(subview1, subview2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + + rewriter.eraseOp(unrealizedCast); + + } else { + memref::SubViewOp srcSubview = + getSubview(tensorType.getRank(), mixedDims, ptr, loc, rewriter); + memref::SubViewOp dstSubview = + getSubview(tensorType.getRank(), mixedDims, alloc, loc, rewriter); + rewriter.create(loc, srcSubview, dstSubview); + } + + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + + return success(); + } + +public: + LogicalResult + matchAndRewrite(tts::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.hasMask()) { + return rewriteMaskedLoad(op, adaptor, rewriter); + } else { + return rewriteStructuredLoad(op, adaptor, rewriter); + } + } +}; + +struct StoreConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + + static tensor::ExtractSliceOp + getExtractSlice(int rank, ArrayRef dims, Value source, + const Location loc, OpBuilder &b) { + auto sourceType = cast(source.getType()); + SmallVector offsets(rank, b.getIndexAttr(0)); + SmallVector strides(rank, b.getIndexAttr(1)); + + auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, + dims, strides); + + return b.create(loc, dstType, source, offsets, dims, + strides); + } + +public: + LogicalResult + matchAndRewrite(tts::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ptr = adaptor.getPtr(); + auto storeValue = op.getValue(); + auto rank = cast(storeValue.getType()).getRank(); + + if (op.hasMask()) { + auto mixedDims = 
op.getMixedMaskDims(); + + auto srcSlice = + getExtractSlice(rank, mixedDims, storeValue, loc, rewriter); + auto dstSubview = getSubview(rank, mixedDims, ptr, loc, rewriter); + + auto storeOp = rewriter.create( + loc, srcSlice, dstSubview); + storeOp.setWritable(true); + } else { + auto storeOp = rewriter.create( + loc, storeValue, ptr); + storeOp.setWritable(true); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct ScalarLoadConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getType().isIntOrIndexOrFloat()) { + return failure(); + } + + auto loc = op->getLoc(); + auto memrefPtr = adaptor.getPtr(); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + auto loadOp = rewriter.create(loc, memrefPtr, zeroMap, + std::nullopt); + rewriter.replaceOp(op, loadOp.getResult()); + + return success(); + } +}; + +struct ScalarStoreConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getValue().getType().isIntOrIndexOrFloat()) { + return failure(); + } + + auto loc = op->getLoc(); + auto memrefPtr = adaptor.getPtr(); + auto val = op.getValue(); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + + rewriter.create(loc, val, memrefPtr, zeroMap, + std::nullopt); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct UnrealizedCastConverter + : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; + +public: + UnrealizedCastConverter(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, + context) {} + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resType = op->getResultTypes()[0]; + auto input = op.getInputs()[0]; + auto inputType = input.getType(); + + if (!isa(resType) || + !isa(inputType)) { + return failure(); + } + + if (auto reinterpretCast = + input.getDefiningOp()) { + rewriter.replaceOp(op, reinterpretCast); + } else { + auto ptrType = cast(resType); + auto memrefType = + cast(getTypeConverter()->convertType(ptrType)); + + auto cast = rewriter.create( + op->getLoc(), memrefType, op.getInputs()[0], 0 /*offset*/, + SmallVector{1} /*sizes*/, + SmallVector{1} /*strides*/); + + rewriter.replaceOp(op, cast); + } + + return success(); + } +}; + +} // namespace + +void mlir::triton::populateStructuredToMemrefConversionPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter) { + patterns.add(typeConverter, patterns.getContext()); + patterns.add( + patterns.getContext()); +} diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp new file mode 100644 index 000000000..8bf53cf61 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -0,0 +1,416 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LogicalResult.h"
+#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include "triton/Dialect/Triton/IR/Types.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+
+#include
+#include
+
+#define DEBUG_TYPE "structured-to-memref"
+
+using namespace mlir;
+using namespace triton;
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_STRUCTUREDTOMEMREF
+#include "triton-shared/Conversion/StructuredToMemref/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+namespace {
+
+static MemRefType getMemrefTypeForScalarPtr(triton::PointerType ptrType,
+                                            MLIRContext *context) {
+  SmallVector<int64_t> strides{1};
+  auto layout = StridedLayoutAttr::get(context, ShapedType::kDynamic, strides);
+
+  auto elemType = ptrType.getPointeeType();
+  auto memrefType = MemRefType::get({1}, elemType, layout);
+  return memrefType;
+}
+
+class TritonFunctionSignatureConverter : public TypeConverter {
+public:
+  TritonFunctionSignatureConverter() {
+    // The order of type conversion is important: later ones are tried earlier.
+    addConversion([](Type type) { return type; });
+    addConversion([](triton::PointerType ptrType) {
+      return UnrankedMemRefType::get(ptrType.getPointeeType(), 0);
+    });
+    // Used for converting memref<*> back to the tt.ptr type; these casts are
+    // then handled when we convert the addptr op later.
+    addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                                 ValueRange inputs,
+                                 Location loc) -> Value {
+      return builder
+          .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    });
+
+    addArgumentMaterialization([&](OpBuilder &builder, Type resultType,
+                                   ValueRange inputs,
+                                   Location loc) -> Value {
+      return builder
+          .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    });
+  }
+};
+
+class LoopTypeConverter : public TypeConverter {
+public:
+  LoopTypeConverter(MLIRContext *context) {
+    // The order of type conversion is important: later ones are tried earlier.
+    addConversion([](Type type) { return type; });
+    addConversion([context](triton::PointerType ptrType) {
+      return getMemrefTypeForScalarPtr(ptrType, context);
+    });
+
+    // A tensor of pointers can be passed in as scf.for's init-args; in such
+    // cases, we convert the type to a memref with dynamic offsets and
+    // strides.
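+    // Example (sketch): tensor<128x!tt.ptr<f32>> becomes
+    //   memref<128xf32, strided<[?], offset: ?>>
+    // so the loop-carried value is just a memref view with a dynamic layout.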
+    addConversion(
+        [context](RankedTensorType tensorType) -> std::optional<Type> {
+          if (auto ptrType = llvm::dyn_cast<triton::PointerType>(
+                  tensorType.getElementType())) {
+            auto layout = StridedLayoutAttr::get(
+                context, ShapedType::kDynamic,
+                SmallVector<int64_t>(tensorType.getRank(),
+                                     ShapedType::kDynamic));
+            Type elemType = ptrType.getPointeeType();
+            return MemRefType::get(tensorType.getShape(), elemType, layout);
+          }
+
+          return std::nullopt;
+        });
+
+    // Convert the current memref type to a memref type with dynamic offsets
+    // and strides through another reinterpret_cast with the same offsets.
+    // Canonicalization will simplify this sequence by removing the initial
+    // reinterpret_cast.
+    addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
+                                 ValueRange inputs,
+                                 Location loc) -> Value {
+      auto reinterpretCast =
+          inputs[0].getDefiningOp<memref::ReinterpretCastOp>();
+      return builder.create<memref::ReinterpretCastOp>(
+          loc, memrefType, inputs[0], reinterpretCast.getMixedOffsets()[0],
+          reinterpretCast.getMixedSizes(), reinterpretCast.getMixedStrides());
+    });
+  }
+};
+
+struct ScalarAddptrConverter
+    : public OneToNOpConversionPattern<triton::AddPtrOp> {
+  using OneToNOpConversionPattern<triton::AddPtrOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (isa<ShapedType>(op.getType())) {
+      return failure();
+    }
+
+    auto loc = op->getLoc();
+
+    auto offsetIndex = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), op.getOffset());
+
+    auto ptrInfo = adaptor.getPtr();
+    assert(ptrInfo.size() == 2);
+    auto ptr = ptrInfo[0];
+    auto offset = ptrInfo[1];
+
+    auto newOffset = rewriter.create<arith::AddIOp>(loc, offset, offsetIndex);
+
+    auto castOp = rewriter.create<memref::ReinterpretCastOp>(
+        loc,
+        getMemrefTypeForScalarPtr(
+            cast<triton::PointerType>(op.getPtr().getType()),
+            rewriter.getContext()),
+        ptr, getAsOpFoldResult(newOffset) /*offset*/,
+        ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)} /*sizes*/,
+        ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)} /*strides*/);
+
+    rewriter.replaceOp(op, SmallVector<Value>{castOp.getResult(), newOffset},
+                       adaptor.getResultMapping());
+
+    return success();
+  }
+};
+
+static std::optional<SmallVector<Value>>
+buildCastAndOffsetOps(OpBuilder &builder, TypeRange resultTypes, Value input,
+                      Location loc) {
+  assert(resultTypes.size() == 2 && isa<MemRefType>(resultTypes[0]) &&
+         isa<IndexType>(resultTypes[1]) &&
+         "Unexpected result types when converting addptr");
+  assert(isa<triton::PointerType>(input.getType()) &&
+         "Unexpected input type when converting addptr");
+
+  // There are only two types of ops that can produce a result of type tt.ptr:
+  // 1) tt.addptr, which is already handled by ScalarAddptrConverter
+  // 2) unrealized_conversion_cast ops, which are inserted during the
+  //    conversion of function arguments.
+  // We assert that the input can only come from an
+  // unrealized_conversion_cast.
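+  // Illustrative expansion (sketch): for an argument-produced pointer
+  //   %ptr = builtin.unrealized_conversion_cast %arg : memref<*xf32>
+  //            to !tt.ptr<f32>
+  // the ops built below are roughly
+  //   %buf  = memref.reinterpret_cast %arg to offset: [0], sizes: [1],
+  //           strides: [1]
+  //   %zero = arith.constant 0 : index
+  // yielding the {memref, index} pair that stands in for the tt.ptr value.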
+  auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>();
+  assert(castOp && "Unexpected defining op for input of type tt.ptr");
+
+  // Compute the memref type
+  auto buffer = castOp.getOperand(0);
+  auto bufferType = cast<BaseMemRefType>(buffer.getType());
+  auto layout =
+      StridedLayoutAttr::get(builder.getContext(), ShapedType::kDynamic, {1});
+  auto memrefType = MemRefType::get({1}, bufferType.getElementType(), layout);
+
+  // Create ops to convert the triton input type to a pair of {memref, index}
+  auto cast = builder.create<memref::ReinterpretCastOp>(
+      loc, memrefType, buffer, 0 /*offset*/, ArrayRef<int64_t>{1} /*sizes*/,
+      ArrayRef<int64_t>{1} /*strides*/);
+  auto zero =
+      builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
+
+  return SmallVector<Value>{cast, zero};
+}
+
+static Value buildCastOp(OpBuilder &builder, Type resultType,
+                         ValueRange inputs, Location loc) {
+  assert(isa<triton::PointerType>(resultType));
+  assert(inputs.size() && isa<MemRefType>(inputs[0].getType()) &&
+         isa<IndexType>(inputs[1].getType()));
+  return builder
+      .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
+      .getResult(0);
+}
+
+class StructuredToMemrefPass
+    : public triton::impl::StructuredToMemrefBase<StructuredToMemrefPass> {
+  using StructuredToMemrefBase<StructuredToMemrefPass>::StructuredToMemrefBase;
+
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert();
+  }
+
+  LogicalResult convertArgsToMemrefType() {
+    auto moduleOp = getOperation();
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+    TritonFunctionSignatureConverter typeConverter;
+
+    // Update function signatures and calls to use memrefs
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType());
+    });
+
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return typeConverter.isLegal(op.getResultTypes()) &&
+             typeConverter.isLegal(op.getOperandTypes());
+    });
+
+    populateFunctionOpInterfaceTypeConversionPattern(
+        patterns, typeConverter);
+
+    return applyPartialConversion(moduleOp, target, std::move(patterns));
+  }
+
+  // We leverage the 1->N conversion infrastructure to convert tt.addptr for
+  // scalars to memref.reinterpret_cast.
+  //
+  // A tt.addptr has the following form:
+  //
+  //   %new_ptr = tt.addptr %ptr %offset
+  //
+  // where %new_ptr and %ptr have tt.ptr type, and %offset is of index type.
+  //
+  // With this form, there can be a chain of tt.addptr where we keep adding
+  // offsets to an existing pointer:
+  //
+  //   %ptr_1 = tt.addptr %arg0 %offset
+  //   %ptr_2 = tt.addptr %ptr_1 %offset
+  //   %ptr_3 = tt.addptr %ptr_2 %offset
+  //
+  // Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that
+  // the pointers can be used by affine.load and affine.store (lowered from
+  // tt.load and tt.store).
+  //
+ // + // A memref.reinterpret_cast op also takes an offset and returns a memref in a + // similar fashion to tt.addptr: + // + // %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: + // [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: + // ?>> + // + // However, since the semantic of memref.reinterpret_cast is different, + // the following lowering would be incorrect for the sequence of tt.addptr + // above: + // + // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset] + // %cast_2 = memref.reinterpret_cast %cast_1 to offset [%offset] + // %cast_3 = memref.reinterpret_cast %cast_2 to offset [%offset] + // + // The above sequence is equivalent to: + // + // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset] + // %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset] + // %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset] + // + // In other word, memref.reinterpret_cast ignores the current offset of the + // input buffer. + // + // Therefore, we have to manually track the offset for each addptr by lowering + // to the following form: + // + // %offset_1 = arith.addi %cst_0 %offset + // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset_1] + // + // %offset_2 = arith.addi %offset_1 %offset + // %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset_2] + // + // %offset_3 = arith.addi %offset_2 %offset + // %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset_3] + // + // Each tt.addptr is lowered to a pair of arith.addi that accumulates the + // current offset before using that offset to the reinterpret_cast. + LogicalResult convertAddPtrToReinterpretCast() { + auto moduleOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + + auto context = &getContext(); + TypeConverter converter; + converter.addConversion([](Type type) { return type; }); + + // We are doing a 1->2 type conversion here, where a triton pointer type + // maps to a pair of {memref, index} type for the the buffer and offset. + converter.addConversion( + [context](triton::PointerType ptrType, SmallVectorImpl &types) + -> std::optional { + types = SmallVector{getMemrefTypeForScalarPtr(ptrType, context), + IndexType::get(context)}; + return success(); + }); + + // Hooks to compute the correct materialization, "argument" and "source" + // materialization are used when we need to convert a pair of {memref, + // index} type back to the original triton pointer type. + // These are used when there are ops that still need to use the original + // pointer type. For instance, we convert the result of tt.addptr from + // tt.ptr type to a pair of {memref, index}, but the original ptr result is + // still being used by another tt.load or tt.store. + converter.addArgumentMaterialization(buildCastOp); + converter.addSourceMaterialization(buildCastOp); + + // Compute the target materialization, given a value with the pointer type, + // convert that value to a pair of {memref, index} type. 
+#if 0 // FIXME: Incompatible MLIR interface
+    converter.addTargetMaterialization(buildCastAndOffsetOps);
+#endif
+
+    patterns.add<ScalarAddptrConverter>(converter, context);
+
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns)))) {
+      return failure();
+    }
+
+    PassManager pm(&getContext(), moduleOp.getOperationName());
+    pm.addPass(createCanonicalizerPass());
+    if (failed(runPipeline(pm, getOperation()))) {
+      return failure();
+    }
+
+    return success();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    if (failed(convertArgsToMemrefType())) {
+      signalPassFailure();
+      return;
+    }
+
+    if (failed(convertAddPtrToReinterpretCast())) {
+      signalPassFailure();
+      return;
+    }
+
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+
+    target.addLegalDialect<
+        func::FuncDialect, arith::ArithDialect, math::MathDialect,
+        linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect,
+        cf::ControlFlowDialect, tensor::TensorDialect,
+        bufferization::BufferizationDialect, ttx::TritonTilingExtDialect,
+        memref::MemRefDialect>();
+
+    target.addIllegalDialect();
+
+    target.addDynamicallyLegalOp([](Operation *op) {
+      auto resType = op->getResultTypes()[0];
+      return !isa(resType);
+    });
+
+    LoopTypeConverter loopTypeConverter(patterns.getContext());
+
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        loopTypeConverter, patterns, target);
+
+    triton::populateStructuredToMemrefConversionPatterns(patterns,
+                                                         loopTypeConverter);
+
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+      signalPassFailure();
+    }
+
+    // Erase dead code and fold constants created during lowering
+    PassManager pm(&getContext(), moduleOp.getOperationName());
+    pm.addPass(createCanonicalizerPass());
+    if (failed(runPipeline(pm, getOperation()))) {
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+triton::createStructuredToMemrefPass() {
+  return std::make_unique<StructuredToMemrefPass>();
+}
diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/CMakeLists.txt
new file mode 100644
index 000000000..f20b2102c
--- /dev/null
+++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_triton_library(TritonArithToLinalg
+  TritonArithToLinalg.cpp
+  TritonArithToLinalgPass.cpp
+
+  DEPENDS
+  TritonArithToLinalgConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRLinalgTransforms
+  MLIRArithDialect
+  MLIRDialectUtils
+  MLIRIR
+  MLIRMathDialect
+  MLIRPass
+  MLIRTensorDialect
+  MLIRTransforms
+  MLIRSupport
+  TritonIR
+  TritonTransforms
+  TritonTilingExtIR
+  TritonStructuredIR
+)
diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
new file mode 100644
index 000000000..cecd8e7b0
--- /dev/null
+++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
@@ -0,0 +1,96 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+
+#include
+#include
+
+#define DEBUG_TYPE "triton-arith-to-linalg"
+#include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h"
+
+using namespace mlir;
+using namespace triton;
+
+#define GEN_PASS_CLASSES
+#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc"
+
+void mlir::triton::populateTritonArithToLinalgCanonicalizationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add, MinMaxConverter>(
+      patterns.getContext());
+}
+
+void mlir::triton::populateTritonArithToLinalgConversionPatterns(
+    bool pidsToFuncArgs, bool addptrToLinalg, bool assertToCf,
+    RewritePatternSet &patterns) {
+
+  if (pidsToFuncArgs) {
+    patterns.add(
+        patterns.getContext());
+  }
+  if (addptrToLinalg) {
+    patterns.add(patterns.getContext());
+  }
+  if (assertToCf) {
+    patterns.add(patterns.getContext());
+  }
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+
+  populateExternElementwiseOpToMLIROps(patterns);
+
+  // Reduce converters
+  // Triton's reduce op is identical to the linalg.reduce op, so we can clone
+  // the `tt.reduce` body to `linalg.reduce`. Unfortunately, we still need to
+  // perform pattern matching to know which reduce ops we are dealing with
+  // so that we know how to initialize the initial reduce values correctly.
+  //
+  // We could do this in a generic way without pattern matching by always using
+  // the first elements along the reduction axis and performing the reduction
+  // on the remaining elements. However, this results in creating sub-tensors
+  // that aren't always multiples of 2, which is sub-optimal for certain
+  // hardware.
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+
+  // Note: the ordering here matters!
+  // These patterns are added last so they will be tried last.
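+  // For example (schematic IR): if no higher-priority converter above claims
+  // a tensor op such as
+  //   %0 = arith.addf %a, %b : tensor<128xf32>
+  // the catch-all patterns registered below rewrite it into a linalg.generic
+  // op, which is why they must come last.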
+ linalg::populateElementwiseToLinalgConversionPatterns(patterns); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp new file mode 100644 index 000000000..12e34c602 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp @@ -0,0 +1,227 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-arith-to-linalg" + +using namespace mlir; +using namespace triton; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_TRITONARITHTOLINALG +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class TritonArithToLinalgPass + : public triton::impl::TritonArithToLinalgBase { + using TritonArithToLinalgBase< + TritonArithToLinalgPass>::TritonArithToLinalgBase; + + static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + + // Add additional I32 arguments to represent: + // - num_programs, 3 in total, one for each axis of the launch grid + // - program_id, 3 in total, one for each axis of the launch grid + static void addProgramInfo(triton::FuncOp func) { + OpBuilder b(func); + + auto origFuncType = func.getFunctionType(); + auto origInputTypes = origFuncType.getInputs(); + SmallVector newInputTypes(origInputTypes); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); + + auto newFuncType = + b.getFunctionType(newInputTypes, origFuncType.getResults()); + + func.setFunctionType(newFuncType); + + // Add empty attributes for each new argument if needed + if (func.getAllArgAttrs()) { + SmallVector newArgAttrs; + func.getAllArgAttrs(newArgAttrs); + newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); + func.setAllArgAttrs(newArgAttrs); + } + + // Add the corresponding arguments to function body + for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { + func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); + } + } + + LogicalResult applyTensorConcatDecomposition() { + auto moduleOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + tensor::populateDecomposeTensorConcatPatterns(patterns); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + return failure(); + } + return success(); + } + +public: 
+ void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + { + RewritePatternSet patterns(&getContext()); + populateTritonArithToLinalgCanonicalizationPatterns(patterns); + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, ttx::TritonTilingExtDialect, + tts::TritonStructuredDialect, mk::MagicKernelDialect>(); + + target.addLegalOp(); + + target.addLegalOp(); + + target.addDynamicallyLegalDialect( + [](Operation *op) { + // Lower dense constant to linalg.fill + if (auto constOp = dyn_cast(op)) { + if (!isa(constOp.getResult().getType())) { + return true; + } + + if (auto denseAttr = + dyn_cast(constOp.getValue())) { + if (denseAttr.isSplat() && + isa(denseAttr.getElementType())) { + return false; + } + } + return true; + } + + bool operateOnTensors = + llvm::all_of(op->getOperandTypes(), [](Type type) { + return isa(type); + }); + + return !operateOnTensors; + }); + + if (pidsToFuncArgs) { + target.addIllegalOp(); + } + + if (addptrToLinalg) { + target.addDynamicallyLegalOp([](triton::AddPtrOp op) { + return !isa(op.getResult().getType()); + }); + } + + if (!assertToCf) { + target.addLegalOp(); + } + + triton::populateTritonArithToLinalgConversionPatterns( + pidsToFuncArgs, addptrToLinalg, assertToCf, patterns); + + if (pidsToFuncArgs) { + for (auto func : getOperation().getOps()) { + addProgramInfo(func); + } + } + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + + if (failed(applyTensorConcatDecomposition())) { + signalPassFailure(); + } + + // Convert tt.func and tt.return into func's counterparts + if (ttToFuncFunc) { + moduleOp.walk([&](triton::FuncOp func) { + OpBuilder builder(func); + + auto name = func.getName(); + auto type = func.getFunctionType(); + + SmallVector argAttrs, resAttrs; + func.getAllArgAttrs(argAttrs); + func.getAllResultAttrs(resAttrs); + + auto funcFunc = builder.create(func.getLoc(), name, type); + funcFunc.setAllArgAttrs(argAttrs); + funcFunc.setAllResultAttrs(resAttrs); + + auto &funcFuncBody = funcFunc.getBody(); + auto &funcBody = func.getBody(); + + IRMapping map; + funcBody.cloneInto(&funcFuncBody, map); + + for (Block &block : funcFuncBody.getBlocks()) { + auto term = block.getTerminator(); + // Only convert to func.return if the terminator is a tt.return. + // Otherwise, we will accidentally convert cf.br ops which are also + // considered terminators. 
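+        // For example (schematic IR): in
+        //   ^bb0: cf.br ^bb1
+        //   ^bb1: tt.return
+        // only the tt.return may be rewritten to func.return; the cf.br
+        // terminator must be left untouched.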
+ if (isa(term)) { + builder.setInsertionPoint(term); + builder.create(func.getLoc(), term->getOperands()); + term->erase(); + } + } + func.erase(); + }); + } + } +}; + +} // namespace + +std::unique_ptr> +triton::createTritonArithToLinalgPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt new file mode 100644 index 000000000..d703f9150 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt @@ -0,0 +1,31 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +add_triton_library(TritonToCoreDialects + TritonToCoreDialectsPass.cpp + + DEPENDS + TritonToCoreDialectsConversionPassIncGen + + LINK_LIBS PUBLIC + TritonTilingExtIR + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonAnalysis + TritonIR + TritonTransforms + ZTCAnalysis + + TritonArithToLinalg + StructuredToMemref + TritonToStructured +) diff --git a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp new file mode 100644 index 000000000..8085836bf --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" +#include "triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h" +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h.inc" + +namespace { + +class TritonToCoreDialectsPass + : public TritonToCoreDialectsBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createTritonToStructuredPass()); + + // Erase dead code and fold constants created during lowering + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + pm.addPass(createTritonArithToLinalgPass()); + pm.addPass(createStructuredToMemrefPass()); + + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + pm.addPass(createCSEPass()); + 
pm.addPass(createCanonicalizerPass()); + + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +triton::createTritonToCoreDialectsPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/CMakeLists.txt new file mode 100644 index 000000000..e9ebf49c3 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/CMakeLists.txt @@ -0,0 +1,28 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +add_triton_library(TritonToLinalg + TritonToLinalg.cpp + TritonToLinalgPass.cpp + + DEPENDS + TritonToLinalgConversionPassIncGen + + LINK_LIBS PUBLIC + TritonTilingExtIR + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonAnalysis + TritonIR + TritonTransforms + ZTCAnalysis +) diff --git a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp new file mode 100644 index 000000000..ea6d32593 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" + +#define DEBUG_TYPE "triton-to-linalg" +#include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" + +void mlir::triton::populateTritonToLinalgCanonicalizationPatterns( + RewritePatternSet &patterns) { + patterns.add, MinMaxConverter>( + patterns.getContext()); +} + +void mlir::triton::populateTritonToLinalgConversionPatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns, + unsigned int launchGridRank) { + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + 
patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+
+  populateExternElementwiseOpToMLIROps(patterns);
+
+  // Reduce converters
+  // Triton's reduce op is identical to the linalg.reduce op, so we can clone
+  // the `tt.reduce` body to `linalg.reduce`. Unfortunately, we still need to
+  // perform pattern matching to know which reduce ops we are dealing with
+  // so that we know how to initialize the initial reduce values correctly.
+  //
+  // We could do this in a generic way without pattern matching by always using
+  // the first elements along the reduction axis and performing the reduction
+  // on the remaining elements. However, this results in creating sub-tensors
+  // that aren't always multiples of 2, which is sub-optimal for certain
+  // hardware.
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
+
+  // Note: the ordering here matters!
+  // MetaOpConverter has PatternBenefit == 10, which should take precedence
+  // over these linalg patterns, but to be safe, add these patterns last so
+  // that they will be tried last. Incorrect ordering, or giving MetaOpConverter
+  // a lower PatternBenefit, would result in element-wise meta ops being
+  // converted to linalg.generic ops.
+  linalg::populateElementwiseToLinalgConversionPatterns(patterns);
+}
diff --git a/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp
new file mode 100644
index 000000000..25b7db85f
--- /dev/null
+++ b/third_party/tsingmicro/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp
@@ -0,0 +1,229 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Analysis/UseAnalysis.h"
+#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h"
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "triton-to-linalg"
+
+using namespace mlir;
+using namespace triton;
+
+#define GEN_PASS_CLASSES
+#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc"
+
+namespace {
+
+class TritonTypeConverter : public TypeConverter {
+public:
+  TritonTypeConverter() {
+    // The order of type conversion is important: later ones are tried earlier.
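+    // For instance (illustrative): given the registrations below, a value of
+    // type tensor<128x!tt.ptr<f32>> is handled by the TensorType conversion
+    // (registered last, hence tried first) and becomes memref<128xf32>; the
+    // identity conversion registered first is only reached when nothing else
+    // matches.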
+ addConversion([](Type type) { return type; }); + addConversion([](triton::PointerType ptrType) { + return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); + }); + addConversion([](TensorType tensorType) -> Type { + auto elemType = tensorType.getElementType(); + if (auto ptrType = dyn_cast(elemType)) { + elemType = ptrType.getPointeeType(); + } + return MemRefType::get(tensorType.getShape(), elemType); + }); + } +}; + +class TritonToLinalgPass : public TritonToLinalgBase { + + static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + + // Add additional I32 arguments to represent: + // - num_programs, 3 in total, one for each axis of the launch grid + // - program_id, 3 in total, one for each axis of the launch grid + static void addProgramInfo(triton::FuncOp func) { + OpBuilder b(func); + + auto origFuncType = func.getFunctionType(); + auto origInputTypes = origFuncType.getInputs(); + SmallVector newInputTypes(origInputTypes); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); + + auto newFuncType = + b.getFunctionType(newInputTypes, origFuncType.getResults()); + + func.setFunctionType(newFuncType); + + // Add empty attributes for each new argument if needed + if (func.getAllArgAttrs()) { + SmallVector newArgAttrs; + func.getAllArgAttrs(newArgAttrs); + newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); + func.setAllArgAttrs(newArgAttrs); + } + + // Add the corresponding arguments to function body + for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { + func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); + } + } + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + { + RewritePatternSet patterns(&getContext()); + populateTritonToLinalgCanonicalizationPatterns(patterns); + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } + + moduleOp.walk([this](triton::FuncOp op) { + if (failed(runUseAnalysis(op))) { + signalPassFailure(); + } + }); + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + TritonTypeConverter tritonTypeConverter; + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, memref::MemRefDialect, + ttx::TritonTilingExtDialect>(); + + target.addLegalOp(); + + // Update function signature to use memrefs + target.addDynamicallyLegalOp([&](triton::FuncOp op) { + return tritonTypeConverter.isSignatureLegal(op.getFunctionType()); + }); + + // Lower dense constant to linalg.fill + target.addDynamicallyLegalOp([](arith::ConstantOp op) { + if (!isa(op.getResult().getType())) { + return true; + } + + if (auto denseAttr = dyn_cast(op.getValue())) { + if (denseAttr.isSplat() && + isa(denseAttr.getElementType())) { + return false; + } + } + return true; + }); + + target.addDynamicallyLegalOp([](Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type t) { + if (isa(t)) { + return false; + } + if (auto shapedType = dyn_cast(t)) { + return shapedType.getElementType().isIntOrFloat(); + } + assert(t.isIntOrIndexOrFloat()); + return true; + }); + }); + + target.addDynamicallyLegalDialect( + [](Operation *op) { + if 
(op->hasAttr("MetaUse")) { + return false; + } + + if (isa(op)) { + return true; + } + + bool operateOnTensors = + llvm::all_of(op->getOperandTypes(), [](Type type) { + return isa(type); + }); + + return !operateOnTensors; + }); + + triton::populateTritonToLinalgConversionPatterns( + tritonTypeConverter, patterns, LAUNCH_GRID_RANK); + + for (auto func : getOperation().getOps()) + addProgramInfo(func); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) + signalPassFailure(); + + // Convert tt.func and tt.return into func's counterparts + moduleOp.walk([&](triton::FuncOp func) { + OpBuilder builder(func); + + auto name = func.getName(); + auto type = func.getFunctionType(); + + SmallVector argAttrs, resAttrs; + func.getAllArgAttrs(argAttrs); + func.getAllResultAttrs(resAttrs); + + auto funcFunc = builder.create(func.getLoc(), name, type); + funcFunc.setAllArgAttrs(argAttrs); + funcFunc.setAllResultAttrs(resAttrs); + + auto &funcFuncBody = funcFunc.getBody(); + auto &funcBody = func.getBody(); + + IRMapping map; + funcBody.cloneInto(&funcFuncBody, map); + + for (Block &block : funcFuncBody.getBlocks()) { + auto term = block.getTerminator(); + builder.setInsertionPoint(term); + builder.create(func.getLoc(), term->getOperands()); + term->erase(); + } + func.erase(); + }); + + // Erase dead code and fold constants created during lowering + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> triton::createTritonToLinalgPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToStructured/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TritonToStructured/CMakeLists.txt new file mode 100644 index 000000000..743d09138 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToStructured/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(TritonToStructured + TritonToStructuredPass.cpp + + DEPENDS + TritonStructuredTableGen + TritonToStructuredConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + MLIRReconcileUnrealizedCasts + TritonIR + TritonTransforms + ZTCAnalysisStructured + TritonStructuredIR +) diff --git a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp new file mode 100644 index 000000000..c479a06ef --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -0,0 +1,344 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/AnalysisStructured/PtrAnalysis.h" +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/OneToNTypeConversion.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include + +#define DEBUG_TYPE "triton-to-structured" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc" + +namespace { + +class TritonToStructuredPass + : public TritonToStructuredBase { + + static TupleType getStructuredStateTupleType(MLIRContext *context, Type t) { + SmallVector tupleTypes{t}; + auto [offsetTypes, strideTypes] = + *tts::GetStructuredStateOp::getOffsetAndStrideTypes(context, t); + tupleTypes.append(offsetTypes); + tupleTypes.append(strideTypes); + return TupleType::get(context, tupleTypes); + } + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + LogicalResult convertToPointerTupleWithOffsetsAndStrides() { + auto moduleOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + + auto context = &getContext(); + TypeConverter converter; + converter.addConversion([](Type type) { return type; }); + + // We are doing a 1->1 type conversion here, where a triton pointer type + // maps to a tuple of {pointer, offset_0, offset_1,..., stride_0, + // stride_1,...} type. + // + // Case 1: Unstructured pointers (tensor>) + converter.addConversion([context](RankedTensorType tensorType, + SmallVectorImpl &types) + -> std::optional { + // Important note: + // We only care about tensor of index / int (in addition to pointer type) + // because only values of int and index type can potentially be part of a + // pointer arithmetic sequence. + if (!isa(tensorType.getElementType()) && + !tensorType.getElementType().isIntOrIndex()) { + // There's a subtle difference between returning failure() and + // std::nullopt. From the documentation: + // + // If std::nullopt is returned, the converter is allowed to try another + // conversion function to perform the conversion. + // + // Say we have type tensor<4x256xbf16> which is a RankedTensorType. Even + // though this RankedTensorType matches the converter that handles the + // tuple conversion, we want to keep this type as is because the inner + // type isn't a pointer. + // + // By returning failure(), the TypeConverters will stop trying the + // remaining converters. 
In our case, the last type converter, which
+          // simply returns the same type, is skipped. And because the
+          // conversion for this type has failed, the whole conversion process
+          // is also skipped.
+          //
+          // Relevant links to the implementation:
+          //
+          // https://github.com/llvm/llvm-project/blob/cb5dc1faa8b3702e0d03426ee5dfc5e1b903ec47/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2958
+          // https://github.com/llvm/llvm-project/blob/cb5dc1faa8b3702e0d03426ee5dfc5e1b903ec47/mlir/lib/Transforms/Utils/DialectConversion.cpp#L3033
+          return std::nullopt;
+        }
+        types =
+            SmallVector<Type>{getStructuredStateTupleType(context, tensorType)};
+        return success();
+      });
+
+    // Case 2: Block pointers (!tt.ptr> or !tt.ptr)
+    converter.addConversion([context](triton::PointerType ptrType,
+                                      SmallVectorImpl<Type> &types)
+                                -> std::optional<LogicalResult> {
+      types = SmallVector<Type>{getStructuredStateTupleType(context, ptrType)};
+      return success();
+    });
+
+    // Hooks to compute the correct materialization: "argument" and "source"
+    // materialization are used when we need to convert the tuple type back to
+    // the original triton pointer type. These are used when there are ops that
+    // still need to use the original pointer type. For instance, we convert the
+    // result of tt.addptr from tt.ptr type to a tuple, but the original ptr
+    // result is still being used by another tt.load or tt.store.
+    auto materialize = [](OpBuilder &builder, Type resultType,
+                          ValueRange inputs, Location loc) {
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    };
+
+    converter.addArgumentMaterialization(materialize);
+    converter.addSourceMaterialization(materialize);
+
+    // Compute the target materialization: given a value with the pointer type,
+    // convert that value to a tuple type.
+#if 0 // FIXME: Incompatible MLIR interface
+    converter.addTargetMaterialization(
+        [](OpBuilder &builder, TypeRange resultTypes, Value input,
+           Location loc) -> std::optional<SmallVector<Value>> {
+          return builder
+              .create<UnrealizedConversionCastOp>(loc, resultTypes, input)
+              ->getResults();
+        });
+#endif
+
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns)))) {
+      return failure();
+    }
+
+    PassManager pm(&getContext(), moduleOp.getOperationName());
+    pm.addPass(createCanonicalizerPass());
+    if (failed(runPipeline(pm, getOperation()))) {
+      return failure();
+    }
+
+    return success();
+  }
+
+  LogicalResult decomposePointerTuple() {
+    auto moduleOp = getOperation();
+
+    auto context = &getContext();
+    TypeConverter converter;
+    converter.addConversion([](Type type) { return type; });
+
+    // We are doing a 1->N type conversion here, where a pointer tuple type
+    // maps to a sequence of {pointer, offset_0, offset_1,..., stride_0,
+    // stride_1,...}
+    converter.addConversion(
+        [context](TupleType tupleType, SmallVectorImpl<Type> &types)
+            -> std::optional<LogicalResult> {
+          tupleType.getFlattenedTypes(types);
+          return success();
+        });
+
+    // Hooks to compute the correct materialization: "argument" and "source"
+    // materialization are used when we need to convert a series of {pointer,
+    // offset_0, offset_1,..., stride_0, stride_1,...} type back to the
+    // "pointer tuple type".
+    //
+    // Because we actually want to get rid of the tuple type, return
+    // `inputs[0]`, which corresponds to a "triton pointer type".
+    // This approach will work as intended because the ops that currently take
+    // "pointer tuple type" are `unrealized_conversion_cast` ops which will get
+    // removed below during reconcile-unrealized-conversion-casts.
+    auto materialize = [](OpBuilder &builder, Type resultType,
+                          ValueRange inputs,
+                          Location loc) { return inputs[0]; };
+    converter.addArgumentMaterialization(materialize);
+    converter.addSourceMaterialization(materialize);
+
+    // For each value of "pointer tuple type" that gets decomposed into a
+    // sequence of {pointer, offset_0, offset_1,..., stride_0, stride_1,...},
+    // create a `tts.get_structured_state` op that serves as a placeholder.
+    // The return values for this op will be used as the init-args for scf.for.
+    // At the end of pointer analysis, we will use the PtrState to create the
+    // correct offsets, strides, and remove these ops.
+#if 0 // FIXME: Incompatible MLIR interface
+    converter.addTargetMaterialization([](OpBuilder &builder,
+                                          TypeRange resultTypes, Value input,
+                                          Location loc) {
+      auto placeholder = builder.create<tts::GetStructuredStateOp>(
+          loc, input.getDefiningOp()->getOperand(0));
+      assert(llvm::equal(placeholder.getResultTypes(), resultTypes));
+      return placeholder.getResults();
+    });
+#endif
+
+    RewritePatternSet patterns(&getContext());
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns)))) {
+      return failure();
+    }
+
+    // Note:
+    // Be careful not to run canonicalization here, because the
+    // tts.get_structured_state ops created above are just placeholders and
+    // don't have any effects. Canonicalization will remove them altogether.
+    PassManager pm(&getContext(), moduleOp.getOperationName());
+    pm.addPass(mlir::createReconcileUnrealizedCastsPass());
+    if (failed(runPipeline(pm, getOperation()))) {
+      signalPassFailure();
+    }
+
+    return success();
+  }
+
+  // Prepass that inserts `tts.get_structured_state` ops. These ops are used as
+  // placeholders to make passing structured pointer state into scf.for loop's
+  // init-args easier, especially with multiple levels of loops.
+  //
+  // Background:
+  //
+  // PtrAnalysis computes a PtrState for every operand (or triton value)
+  // involved in a sequence of pointer arithmetic; some examples include: a
+  // triton pointer, or offsets (which could be a tensor of indices or just a
+  // simple index value).
+  //
+  // If a triton value is updated and returned in a scf.for op, it means
+  // that we have to carry its offsets and strides in the scf.for's iter-args.
+  //
+  // Previously, we had to manually rewrite the loops to include the
+  // relevant information from a PtrState, which was rather involved and
+  // error-prone; this was also hard to scale up to multiple levels of loops
+  // because there are several book-keeping data structures that we have to
+  // maintain.
+  //
+  // With the introduction of the prepass that inserts
+  // `tts.get_structured_state` ops, the return values of these ops, which
+  // include a triton value with its original result type and its
+  // corresponding offsets and strides, are used as "placeholders" in the
+  // scf.for's init-args. We leverage the standard MLIR 1->N conversion
+  // infrastructure to perform this rewrite, which helps simplify the logic
+  // significantly.
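+  //
+  // A schematic example (types and operands abbreviated): a loop-carried
+  // pointer
+  //
+  //   scf.for ... iter_args(%ptr = %arg0) { ... }
+  //
+  // becomes, after the prepass,
+  //
+  //   %v, %off, %str = tts.get_structured_state %arg0
+  //   scf.for ... iter_args(%p = %v, %o = %off, %s = %str) { ... }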
+  //
+  // After PtrAnalysis finishes, the return values of these
+  // `tts.get_structured_state` ops will be remapped to the correct
+  // initialization of the value's offsets and strides through the value's
+  // computed PtrState.
+  //
+  // Implementation details:
+  // In essence, what we really want to do in the prepass is, for every value
+  // of triton-pointer-like type (tt.ptr or tensor>) and tensor of
+  // indices (tensor) which might be used in a sequence of pointer
+  // arithmetic, we want to create an op `tts.get_structured_state` that takes
+  // in the original triton value and returns a series of values:
+  //
+  // {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...}
+  //
+  // Applying the above conversion will also mean that any structural ops such
+  // as scf.for and scf.yield that originally take the triton pointer will
+  // then take {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...}.
+  //
+  // The 1->N type conversion is a perfect fit for this transformation.
+  // Unfortunately, we cannot do this in one pass, because the current 1->N
+  // type conversion implementation for scf.for ops doesn't provide us with a
+  // way to detect that a type conversion is recursive. So a triton_value type
+  // that gets converted to a {triton_value, offset_0, offset_1, ..., stride_0,
+  // stride_1,...} will recursively trigger other conversions.
+  //
+  // To fix this issue, we have to first convert triton_value to a tuple type.
+  // Finally, we decompose these tuples into the desired sequence.
+  //
+  // Note that even though the type conversion happens for every integer tensor
+  // appearing in loops' iter-args, this conversion is reversible. If the
+  // integer tensor isn't used in a pointer arithmetic sequence,
+  // canonicalization will remove all the `tts.get_structured_state` ops and
+  // revert the IR back to its original form.
+  LogicalResult runTritonToStructuredPrepass() {
+    if (failed(convertToPointerTupleWithOffsetsAndStrides())) {
+      return failure();
+    }
+
+    return decomposePointerTuple();
+  }
+
+  void runOnOperation() override {
+    if (!skipPrepass && failed(runTritonToStructuredPrepass())) {
+      signalPassFailure();
+      return;
+    }
+
+    if (runPrepassOnly) {
+      return;
+    }
+
+    auto moduleOp = getOperation();
+    mlir::tts::PtrAnalysis ptrAnalysis;
+    ptrAnalysis.initializeMaybeStructuredArgs(moduleOp);
+
+    if (failed(ptrAnalysis.rewriteOp(moduleOp, useUnsafeMask))) {
+      moduleOp->emitWarning("PtrAnalysis failed");
+    }
+
+    // Now that all the PtrStates have been populated, we can wire up the
+    // states with the tts.get_structured_state ops inserted in the prepass.
+ moduleOp.walk([&ptrAnalysis](tts::GetStructuredStateOp op) { + if (failed(ptrAnalysis.rewriteGetStructuredStateOp(op))) { + op.emitWarning("Rewriting GetStructuredStateOp failed."); + } + }); + } +}; +} // namespace + +std::unique_ptr> +triton::createTritonToStructuredPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt new file mode 100644 index 000000000..9a20a8c10 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_triton_library(Tx81MemrefToLLVM + Tx81MemrefToLLVM.cpp + Tx81MemrefToLLVMPass.cpp + + DEPENDS + Tx81MemrefToLLVMConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp new file mode 100644 index 000000000..8a144b8a9 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp @@ -0,0 +1,335 @@ +//===------------------- Tx81MemrefToLLVM.cpp------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include +#include + +#define DEBUG_TYPE "tx81-memref-to-llvm" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc" + +namespace { + +// Used to kcore load/store data from/to spm +const int64_t spmMappingOffset = 0x30400000; + +//===----------------------------------------------------------------------===// +// Tx81 Custom MemRef Op Conversion Patterns +//===----------------------------------------------------------------------===// + +struct TsmMemRefAllocOpLowering : public AllocLikeOpLLVMLowering { + TsmMemRefAllocOpLowering(const LLVMTypeConverter &converter) + : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), + converter) {} + + std::tuple + allocateBufferFromSPM(ConversionPatternRewriter &rewriter, Location loc, + Operation *op) const { + static uint64_t spmPointer = 0x10000; + + // create GEPOp for spm address. 
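+    // Illustrative picture of this bump allocation (assuming the initial
+    // spmPointer value of 0x10000): lowering
+    //   %a = memref.alloc() : memref<1024xf32>   // -> SPM offset 0x10000
+    //   %b = memref.alloc() : memref<256xi32>    // -> SPM offset 0x11000
+    // assigns consecutive offsets, since each alloc advances spmPointer by its
+    // size in bytes (1024 * 4 bytes = 0x1000 here).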
+    MemRefType memRefType = getMemRefResultType(op);
+    Value spmOffsetOp = rewriter.create(
+        loc, getIndexType(), rewriter.getI32IntegerAttr(spmPointer));
+    Type elementType = typeConverter->convertType(memRefType.getElementType());
+    auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+    Value spmAddr = rewriter.create(loc, elementPtrType);
+
+    spmAddr = rewriter.create(op->getLoc(),
+                              rewriter.getI64Type(), spmAddr);
+    spmAddr = rewriter.create(op->getLoc(), rewriter.getI64Type(),
+                              spmAddr, spmOffsetOp);
+
+    spmAddr = rewriter.create(op->getLoc(), elementPtrType,
+                              spmAddr);
+    Value allocatedPtr = spmAddr;
+    if (!allocatedPtr)
+      return std::make_tuple(Value(), Value());
+    Value alignedPtr = allocatedPtr;
+
+    // Update the spm pointer
+    auto elemCount = memRefType.getNumElements();
+    auto bitWidth = memRefType.getElementTypeBitWidth();
+    auto allocOp = dyn_cast<memref::AllocOp>(op);
+    if (allocOp.getAlignment().has_value())
+      bitWidth = allocOp.getAlignment().value();
+    uint64_t totalByte = (elemCount * bitWidth + 7) / 8;
+    spmPointer += totalByte;
+
+    return std::make_tuple(allocatedPtr, alignedPtr);
+  }
+
+  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+                                          Location loc, Value sizeBytes,
+                                          Operation *op) const override {
+    return allocateBufferFromSPM(rewriter, loc, op);
+  }
+};
+
+template <typename MemrefOp>
+struct MemrefLoadOrStoreOpLowering : public ConvertOpToLLVMPattern<MemrefOp> {
+
+  using ConvertOpToLLVMPattern<MemrefOp>::ConvertOpToLLVMPattern;
+  using OpAdaptor = typename MemrefOp::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(MemrefOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = op.getMemRefType();
+
+    Value dataPtr = ConvertToLLVMPattern::getStridedElementPtr(
+        op.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(), rewriter);
+
+    // TODO: Add the spm offset according to the memory space
+    MemRefDescriptor memRefDescriptor(adaptor.getMemref());
+    auto intPtrType = ConvertToLLVMPattern::getIntPtrType(
+        memRefDescriptor.getElementPtrType().getAddressSpace());
+    Value ptrValue =
+        rewriter.create(op.getLoc(), intPtrType, dataPtr);
+
+    // FIXME: Could this be created just once, since the offset is a constant
+    // op?
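+    // E.g. (illustrative): with spmMappingOffset = 0x30400000, an SPM-relative
+    // address 0x10000 becomes the kcore-visible address 0x30410000.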
+    auto spmMemoryOffset = rewriter.create(
+        op.getLoc(), rewriter.getI64Type(),
+        rewriter.getI64IntegerAttr(spmMappingOffset));
+    auto spmMemoryAddr = rewriter.create(
+        op.getLoc(), rewriter.getI64Type(),
+        SmallVector<Value>({ptrValue, spmMemoryOffset}));
+
+    auto ptrTy = LLVM::LLVMPointerType::get(
+        rewriter.getContext(),
+        *ConvertToLLVMPattern::getTypeConverter()->getMemRefAddressSpace(type));
+    auto spmMemoryAddrPtr =
+        rewriter.create(op.getLoc(), ptrTy, spmMemoryAddr);
+
+    // Whether we need a memory-space cast
+    if constexpr (std::is_same<MemrefOp, memref::LoadOp>()) {
+
+      rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
+          op, op.getType(), spmMemoryAddrPtr, 0, false, op.getNontemporal());
+    } else {
+      rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
+          op, adaptor.getValue(), dataPtr, 0, false, op.getNontemporal());
+    }
+
+    return success();
+  }
+};
+
+struct MemRefReinterpretCastOpLowering
+    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
+  using ConvertOpToLLVMPattern<
+      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = castOp.getSource().getType();
+
+    Value descriptor;
+    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
+                                               adaptor, &descriptor)))
+      return failure();
+    rewriter.replaceOp(castOp, {descriptor});
+    return success();
+  }
+
+private:
+  /// Extracts allocated, aligned pointers and offset from a ranked or unranked
+  /// memref type. In the unranked case, the fields are extracted from the
+  /// underlying ranked descriptor.
+  void extractPointersAndOffset(Location loc,
+                                ConversionPatternRewriter &rewriter,
+                                const LLVMTypeConverter &typeConverter,
+                                Value originalOperand, Value convertedOperand,
+                                Value *allocatedPtr, Value *alignedPtr,
+                                Value *offset = nullptr) const {
+    Type operandType = originalOperand.getType();
+    if (isa<MemRefType>(operandType)) {
+      MemRefDescriptor desc(convertedOperand);
+      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
+      *alignedPtr = desc.alignedPtr(rewriter, loc);
+      if (offset != nullptr)
+        *offset = desc.offset(rewriter, loc);
+      return;
+    }
+
+    // These will all cause assert()s on unconvertible types.
+    unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
+        cast<UnrankedMemRefType>(operandType));
+    auto elementPtrType =
+        LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
+
+    // Extract the pointer to the underlying ranked memref descriptor and cast
+    // it to ElemType**.
+    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+
+    // FIXME: workaround, take memRefDescPtr as a naked ptr.
+    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+    *allocatedPtr = underlyingDescPtr;
+    *alignedPtr = underlyingDescPtr;
+
+    if (offset != nullptr) {
+      *offset = rewriter.create(
+          loc, getIndexType(), rewriter.getI32IntegerAttr(0));
+    }
+  }
+
+  LogicalResult convertSourceMemRefToDescriptor(
+      ConversionPatternRewriter &rewriter, Type srcType,
+      memref::ReinterpretCastOp castOp,
+      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
+    MemRefType targetMemRefType =
+        cast<MemRefType>(castOp.getResult().getType());
+    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
+        typeConverter->convertType(targetMemRefType));
+    if (!llvmTargetDescriptorTy)
+      return failure();
+
+    // Create the descriptor.
+    Location loc = castOp.getLoc();
+    MemRefDescriptor desc(*descriptor);
+
+    // Set allocated and aligned pointers.
+ Value allocatedPtr, alignedPtr; + extractPointersAndOffset(loc, rewriter, *getTypeConverter(), + castOp.getSource(), adaptor.getSource(), + &allocatedPtr, &alignedPtr); + desc.setAllocatedPtr(rewriter, loc, allocatedPtr); + desc.setAlignedPtr(rewriter, loc, alignedPtr); + + // Set offset. + if (castOp.isDynamicOffset(0)) + desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]); + else + desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); + + // Set sizes and strides. + unsigned dynSizeId = 0; + unsigned dynStrideId = 0; + for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { + if (castOp.isDynamicSize(i)) + desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]); + else + desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); + + if (castOp.isDynamicStride(i)) + desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]); + else + desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); + } + *descriptor = desc; + return success(); + } +}; + +/// Materialize the MemRef descriptor represented by the results of +/// ExtractStridedMetadataOp. +class ExtractStridedMetadataOpLowering + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) + return failure(); + + // Create the descriptor. + MemRefDescriptor sourceMemRef(adaptor.getSource()); + Location loc = extractStridedMetadataOp.getLoc(); + Value source = extractStridedMetadataOp.getSource(); + + auto sourceMemRefType = cast(source.getType()); + int64_t rank = sourceMemRefType.getRank(); + SmallVector results; + results.reserve(2 + rank * 2); + + // Base buffer. + Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc); + Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc); + MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape( + rewriter, loc, *getTypeConverter(), + cast(extractStridedMetadataOp.getBaseBuffer().getType()), + baseBuffer, alignedBuffer); + results.push_back((Value)dstMemRef); + + // Offset. + results.push_back(sourceMemRef.offset(rewriter, loc)); + + // Sizes. + for (unsigned i = 0; i < rank; ++i) + results.push_back(sourceMemRef.size(rewriter, loc, i)); + // Strides. + for (unsigned i = 0; i < rank; ++i) + results.push_back(sourceMemRef.stride(rewriter, loc, i)); + + rewriter.replaceOp(extractStridedMetadataOp, results); + return success(); + } +}; + +/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index. +class ConvertExtractAlignedPointerAsIndex + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + BaseMemRefType sourceTy = extractOp.getSource().getType(); + + // FIXME: We want allocated ptr instead of aligned ptr. 
+ Value alignedPtr; + if (sourceTy.hasRank()) { + MemRefDescriptor desc(adaptor.getSource()); + alignedPtr = desc.allocatedPtr(rewriter, extractOp->getLoc()); + } else { + auto elementPtrTy = LLVM::LLVMPointerType::get( + rewriter.getContext(), sourceTy.getMemorySpaceAsInt()); + + UnrankedMemRefDescriptor desc(adaptor.getSource()); + Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc()); + + alignedPtr = UnrankedMemRefDescriptor::allocatedPtr( + rewriter, extractOp->getLoc(), descPtr, elementPtrTy); + } + + rewriter.replaceOpWithNewOp( + extractOp, getTypeConverter()->getIndexType(), alignedPtr); + return success(); + } +}; + +} // namespace + +void mlir::triton::populateTx81MemrefToLLVMConversionPatterns( + RewritePatternSet &patterns, LLVMTypeConverter &converter) { + // clang-format off + patterns.add, + MemrefLoadOrStoreOpLowering>( + converter); + // clang-format on +} \ No newline at end of file diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp new file mode 100644 index 000000000..d37c2824a --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp @@ -0,0 +1,88 @@ +//===------------------- Tx81MemrefToLLVMPass.cpp--------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "tx81-memref-to-llvm" + +using namespace mlir; + +namespace mlir { +namespace triton { +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +class Tx81MemrefToLLVMPass + : public mlir::triton::Tx81MemrefToLLVMBase { + using Tx81MemrefToLLVMBase::Tx81MemrefToLLVMBase; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + target.addIllegalOp(); + + target.addLegalDialect(); + + target.addLegalOp(); + + LowerToLLVMOptions options(context); + options.useBarePtrCallConv = false; + LLVMTypeConverter llvmTypeConverter(context, options); + triton::populateTx81MemrefToLLVMConversionPatterns(patterns, + llvmTypeConverter); + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> triton::createTx81MemrefToLLVMPass() { + 
return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt new file mode 100644 index 000000000..68517188e --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt @@ -0,0 +1,23 @@ +add_triton_library(Tx81ToLLVM + Tx81ToLLVM.cpp + + DEPENDS + Tx81ToLLVMConversionPassIncGen + MLIRMemRefToLLVM + + LINK_LIBS PUBLIC + MLIRMemRefToLLVM + MLIRArithDialect + MLIRArithToLLVM + MLIRFuncDialect + MLIRFuncToLLVM + MLIRLLVMDialect + MLIRMemRefDialect + MLIRMemRefToLLVM + MLIRArithToLLVM + MLIRAffineToStandard + MLIRLinalgToStandard + MLIRSCFDialect + MLIRSCFToControlFlow + MLIRTransforms +) diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp new file mode 100644 index 000000000..339197ffc --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp @@ -0,0 +1,208 @@ +//===- KernelArgBufferPass.cpp - Convert kernel args to single buffer -----===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This pass transforms kernel function signatures by converting multiple +// arguments into a single void* buffer containing all the arguments. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +namespace { + +class KernelArgBufferPass + : public PassWrapper> { +public: + StringRef getArgument() const final { return "kernel-arg-buffer"; } + StringRef getDescription() const final { + return "Convert kernel arguments to a single buffer argument"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; + +private: + // Identifies if a function should be processed + bool isKernelFunction(func::FuncOp func); + + // Creates a new function with a single void* argument + func::FuncOp createBufferizedFunction(OpBuilder &builder, func::FuncOp originalFunc); + + // Rewrites the function body to use the argument buffer + void rewriteFunctionBody(func::FuncOp originalFunc, func::FuncOp newFunc); +}; + +bool KernelArgBufferPass::isKernelFunction(func::FuncOp func) { + // For this example, we'll identify kernel functions by their name + // containing "_kernel". In a real implementation, you might use attributes + // or more sophisticated detection. 
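+  // E.g. (illustrative): "add_kernel" or "matmul_kernel_0d1d" would be treated
+  // as kernels, while "init_runtime" would not.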
+ return func.getName().contains("_kernel"); +} + +func::FuncOp KernelArgBufferPass::createBufferizedFunction(OpBuilder &builder, + func::FuncOp originalFunc) { + // Create a new function type with a single void* argument + auto voidPtrType = LLVM::LLVMPointerType::get(builder.getContext()); + auto newFuncType = FunctionType::get(originalFunc.getContext(), + {voidPtrType}, + originalFunc.getFunctionType().getResults()); + + // Create the new function with the same name but new type + auto newFunc = func::FuncOp::create(originalFunc.getLoc(), originalFunc.getName(), + newFuncType); + + // Copy over all attributes except those related to the function type + for (const auto &attr : originalFunc->getAttrs()) { + if (attr.getName() != "function_type" && attr.getName() != "arg_attrs" && + attr.getName() != "res_attrs") { + newFunc->setAttr(attr.getName(), attr.getValue()); + } + } + + return newFunc; +} + +void KernelArgBufferPass::rewriteFunctionBody(func::FuncOp originalFunc, + func::FuncOp newFunc) { + if (originalFunc.empty()) + return; + + Block &oldEntryBlock = originalFunc.getBlocks().front(); + Block &newEntryBlock = newFunc.getBlocks().front(); + + OpBuilder builder(&newEntryBlock, newEntryBlock.begin()); + Location loc = originalFunc.getLoc(); + + Value argsBuffer = newEntryBlock.getArgument(0); + SmallVector extractedArgs; + + // Offset tracking for buffer access + int64_t currentOffset = 0; + // Size of scalar values in bytes (specified as 8 bytes) + const int64_t scalarSize = 8; + + // Process each original argument + for (auto argIndex : llvm::seq(0, originalFunc.getNumArguments())) { + Type argType = originalFunc.getArgument(argIndex).getType(); + Value loadedArg; + + // Handle pointer types (like uint64_t*) + if (auto ptrType = dyn_cast(argType)) { + // For pointer types, we load the pointer value itself from the buffer + auto offsetValue = builder.create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(currentOffset)); + + // Get pointer to the current position in args buffer + auto elementPtr = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), argsBuffer, + ArrayRef{offsetValue}); + + // Cast to pointer-to-pointer type + auto castedPtr = builder.create( + loc, LLVM::LLVMPointerType::get(ptrType), elementPtr); + + // Load the pointer + loadedArg = builder.create(loc, castedPtr); + + // Increment offset (pointers are 8 bytes) + currentOffset += scalarSize; + } + // Handle scalar types (like int64_t, int) + else { + auto offsetValue = builder.create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(currentOffset)); + + // Get pointer to the current position in args buffer + auto elementPtr = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), argsBuffer, + ArrayRef{offsetValue}); + + // Cast to appropriate pointer type + auto castedPtr = builder.create( + loc, LLVM::LLVMPointerType::get(argType), elementPtr); + + // Load the scalar value + loadedArg = builder.create(loc, castedPtr); + + // Increment offset (all scalars use 8 bytes as specified) + currentOffset += scalarSize; + } + + extractedArgs.push_back(loadedArg); + } + + // Clone the original function body, replacing uses of old arguments + auto &oldRegion = originalFunc.getBody(); + auto &newRegion = newFunc.getBody(); + + // Move operations from old entry block to new entry block + for (auto &op : oldEntryBlock.getOperations()) { + if (&op == &oldEntryBlock.back() && op.hasTrait()) { + builder.clone(op); + } else { + auto clonedOp = builder.clone(op); + + // Replace 
uses of old arguments with new extracted values + for (unsigned i = 0; i < originalFunc.getNumArguments(); ++i) { + Value oldArg = oldEntryBlock.getArgument(i); + clonedOp->replaceUsesOfWith(oldArg, extractedArgs[i]); + } + } + } +} + +void KernelArgBufferPass::runOnOperation() { + ModuleOp module = getOperation(); + OpBuilder builder(module.getContext()); + + // Collect functions to process + SmallVector kernelFuncs; + for (auto func : module.getOps()) { + if (isKernelFunction(func)) { + kernelFuncs.push_back(func); + } + } + + // Process each kernel function + for (auto func : kernelFuncs) { + // Create new function with bufferized signature + builder.setInsertionPointAfter(func); + auto newFunc = createBufferizedFunction(builder, func); + + // Add entry block to the new function + newFunc.addEntryBlock(); + + // Rewrite function body to use the argument buffer + rewriteFunctionBody(func, newFunc); + + // Replace the old function with the new one + func.erase(); + } +} + +} // namespace + +std::unique_ptr createKernelArgBufferPass() { + return std::make_unique(); +} + +// Pass registration +namespace { +#define GEN_PASS_REGISTRATION +#include "KernelArgBufferPass.h.inc" +} // namespace \ No newline at end of file diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp new file mode 100644 index 000000000..66e198f54 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp @@ -0,0 +1,1326 @@ +//===--------------------- Tx81ToLLVM.cpp ---------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This file implements the patterns to convert operations from tx dialect to +// LLVM IR dialect. 
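+//
+// Most patterns below follow the same recipe: declare the matching CRT
+// routine, bitcast buffer operands to opaque pointers, materialize the op's
+// attributes as constants, and emit an llvm.call. As a rough sketch, an
+// element-wise add lowers to IR of this shape (value names illustrative):
+//
+//   %a   = llvm.bitcast %src0 : !llvm.ptr to !llvm.ptr
+//   %n   = arith.index_cast %elem_count : index to i32
+//   %fmt = llvm.mlir.constant(0 : i32) : i32
+//   %r   = llvm.call @__AddVV(%a, %b, %out, %n, %rnd, %fmt)
+//          : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i32, i32, i32) -> !llvm.ptr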
+//
+//===----------------------------------------------------------------------===//
+
+#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "tx81-to-llvm"
+
+using namespace mlir;
+
+#define GEN_PASS_CLASSES
+#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc"
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+// Names of the CRT runtime entry points targeted by the patterns below.
+const char addVVFuncName[] = "__AddVV";
+const char subVVFuncName[] = "__SubVV";
+const char mulVVFuncName[] = "__MulVV";
+const char divVVFuncName[] = "__DivVV";
+const char addVSFuncName[] = "__AddVS";
+const char subVSFuncName[] = "__SubVS";
+const char mulVSFuncName[] = "__MulVS";
+const char divVSFuncName[] = "__DivVS";
+const char reduceSumFuncName[] = "__ReduceSum";
+const char reduceMaxFuncName[] = "__ReduceMax";
+const char reduceMinFuncName[] = "__ReduceMin";
+const char fp16ToFp32FuncName[] = "__FP16_FP32";
+const char int16ToFp16FuncName[] = "__INT16_FP16";
+const char int16ToFp32FuncName[] = "__INT16_FP32";
+const char int32ToFp16FuncName[] = "__INT32_FP16";
+const char int32ToFp32FuncName[] = "__INT32_FP32";
+const char fp16ToInt8FuncName[] = "__FP16_INT8";
+const char fp16ToInt16FuncName[] = "__FP16_INT16";
+const char fp16ToInt32FuncName[] = "__FP16_INT32";
+const char fp32ToInt8FuncName[] = "__FP32_INT8";
+const char fp32ToInt16FuncName[] = "__FP32_INT16";
+const char fp32ToInt32FuncName[] = "__FP32_INT32";
+
+// Declares the named Tx81 CRT runtime function (creating the declaration at
+// module scope if it does not exist yet) and returns its address.
+Value declareTx81Function(ModuleOp module, OpBuilder &builder, Location loc,
+                          StringRef name, Type resultType,
+                          ArrayRef<Type> argumentTypes) {
+  // Check if the function already exists
+  Operation *funcOp = module.lookupSymbol(name);
+  if (funcOp)
+    return builder.create<LLVM::AddressOfOp>(
+        loc, LLVM::LLVMPointerType::get(builder.getContext()), name);
+
+  // Create function type
+  Type funcType = LLVM::LLVMFunctionType::get(resultType, argumentTypes,
+                                              /*isVarArg=*/false);
+
+  // Create a function declaration
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(module.getBody());
+
+  builder.create<LLVM::LLVMFuncOp>(loc, name, funcType,
+ LLVM::Linkage::External); + + builder.restoreInsertionPoint(ip); + + // Return function pointer + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), name); +} + +static Value adjustElemCountType(ConversionPatternRewriter &rewriter, + Location loc, Value elemCount) { + Value newElemCount = elemCount; + if (isa(elemCount.getType())) { + newElemCount = rewriter.create( + loc, rewriter.getI32Type(), elemCount); + } else if (isa(elemCount.getType())) { + auto elemCountType = dyn_cast(elemCount.getType()); + if (elemCountType.isInteger(64)) + newElemCount = rewriter.create( + loc, rewriter.getI32Type(), elemCount); + } + return newElemCount; +} + +static Value castIndexToInt32(ConversionPatternRewriter &rewriter, Location loc, + Value indexOp) { + return rewriter.create(loc, rewriter.getI32Type(), + indexOp); +} + +//===----------------------------------------------------------------------===// +// Arith Operation Conversion Patterns +//===----------------------------------------------------------------------===// + +// Convert constant operations to LLVM constants +struct ConstantOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the constant value + auto constAttr = op.getValue(); + + // Get the result type + auto resultType = getTypeConverter()->convertType(op.getResult().getType()); + + // Handle different attribute types + if (auto intAttr = dyn_cast(constAttr)) { + // Convert integer attribute + rewriter.replaceOpWithNewOp(op, resultType, intAttr); + return success(); + } else if (auto floatAttr = dyn_cast(constAttr)) { + // Convert float attribute + rewriter.replaceOpWithNewOp(op, resultType, floatAttr); + return success(); + } else if (auto boolAttr = dyn_cast(constAttr)) { + // Convert bool attribute to i1 + rewriter.replaceOpWithNewOp( + op, resultType, + rewriter.getIntegerAttr(resultType, boolAttr.getValue())); + return success(); + } + + return failure(); + } +}; + +// Convert arith.index_cast to appropriate LLVM conversions +struct IndexCastOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get source and result types + auto srcType = adaptor.getIn().getType(); + auto dstType = getTypeConverter()->convertType(op.getResult().getType()); + + // Convert from index to specific integer type + if (isa(srcType) && isa(dstType)) { + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getIn()); + return success(); + } + + // Convert from specific integer type to index + if (isa(srcType) && isa(dstType)) { + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getIn()); + return success(); + } + + // Handle integer to integer casts + if (isa(srcType) && isa(dstType)) { + unsigned srcWidth = cast(srcType).getWidth(); + unsigned dstWidth = cast(dstType).getWidth(); + + if (srcWidth < dstWidth) { + // Sign extend if source is signed, zero extend otherwise + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn()); + } else if (srcWidth > dstWidth) { + // Truncate + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getIn()); + } else { + // Same width, just pass through + rewriter.replaceOp(op, adaptor.getIn()); + } + return success(); + } + + return failure(); + } +}; + +// Convert arith.addi to LLVM add 
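+// (for example, "%c = arith.addi %a, %b : i32" lowers directly to
+// "%c = llvm.add %a, %b : i32"; the arith.muli pattern below is analogous)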
+struct AddIOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +// Convert arith.muli to LLVM mul +struct MulIOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Tx81 Operation Conversion Patterns +//===----------------------------------------------------------------------===// + +// Convert tx81.rdma to LLVM call to crt __Rdma function +struct RdmaOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::RdmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Rdma runtime function if not already declared + /* + void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int + shape_w, int shape_c, int stride_n, int stride_h, int stride_w, int + *strides, uint32_t fmt) + */ + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // src + i8PtrTy, // target + i32Ty, // shape_n + i32Ty, // shape_h + i32Ty, // shape_w + i32Ty, // shape_c + i32Ty, // stride_n + i32Ty, // stride_h + i32Ty, // stride_w + i32Ty // fmt + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Rdma", + i8PtrTy, argTypes); + + // Get the operands + Value src = adaptor.getSource(); + src = rewriter.create(op.getLoc(), i8PtrTy, src); + + // Get the operands + Value target = adaptor.getTarget(); + target = rewriter.create(op.getLoc(), i8PtrTy, target); + + ValueRange shape = adaptor.getShape(); + Value shape0 = castIndexToInt32(rewriter, op->getLoc(), shape[0]); + Value shape1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]); + Value shape2 = castIndexToInt32(rewriter, op->getLoc(), shape[2]); + Value shape3 = castIndexToInt32(rewriter, op->getLoc(), shape[3]); + + ValueRange strides = adaptor.getStrides(); + Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[0]); + Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]); + Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[2]); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call to __Rdma + auto call = rewriter.create( + op.getLoc(), TypeRange{i8PtrTy}, "__Rdma", // funcPtr, + ValueRange{src, target, shape0, shape1, shape2, shape3, stride0, + stride1, stride2, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.wdma to LLVM call to __Wdma function +struct WdmaOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + 
LogicalResult + matchAndRewrite(tx::WdmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Wdma runtime function if not already declared + /* + void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int + shape_w, int shape_c, int stride_n, int stride_h, int stride_w, int + *strides, uint32_t fmt) + */ + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // src + i8PtrTy, // target + i32Ty, // shape_n + i32Ty, // shape_h + i32Ty, // shape_w + i32Ty, // shape_c + i32Ty, // stride_n + i32Ty, // stride_h + i32Ty, // stride_w + i32Ty // fmt + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Wdma", + i8PtrTy, argTypes); + + // Get the operands + Value src = adaptor.getSource(); + + // Need to bitcast src to i8* + src = rewriter.create(op.getLoc(), i8PtrTy, src); + + // Get the operands + Value target = adaptor.getTarget(); + + // Need to bitcast src to i8* + target = rewriter.create(op.getLoc(), i8PtrTy, target); + + ValueRange shape = adaptor.getShape(); + Value shape0 = castIndexToInt32(rewriter, op->getLoc(), shape[0]); + Value shape1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]); + Value shape2 = castIndexToInt32(rewriter, op->getLoc(), shape[2]); + Value shape3 = castIndexToInt32(rewriter, op->getLoc(), shape[3]); + + ValueRange strides = adaptor.getStrides(); + Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[0]); + Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]); + Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[2]); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call to __Wdma + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Wdma", // funcPtr, + ArrayRef{src, target, shape0, shape1, shape2, shape3, stride0, + stride1, stride2, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.mask_move to LLVM call to __MaskMove function +struct MaskMoveOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::MaskMoveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __MaskMove runtime function if not already declared + // Signature: void* __MaskMove(void* source, void* target, uint32_t + // elem_count, int32_t* masks, uint32_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // source + i8PtrTy, // target + i32Ty, // elem_count + i32PtrTy, // masks + i32Ty // fmt + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + "__MaskMove", i8PtrTy, argTypes); + + // Get the operands + Value src = adaptor.getSource(); + + // Need to bitcast src to i8* + src = rewriter.create(op.getLoc(), i8PtrTy, 
src);
+
+    Value target = adaptor.getTarget();
+
+    // Need to bitcast target to i8*
+    target = rewriter.create<LLVM::BitcastOp>(op.getLoc(), i8PtrTy, target);
+    Value elemCount = adaptor.getElemCount();
+    elemCount = castIndexToInt32(rewriter, op->getLoc(), elemCount);
+
+    // Handle mask arrays
+    // For now, no masks are materialized; pass a null pointer.
+    Value nullPtr = rewriter.create<LLVM::ZeroOp>(op.getLoc(), i32PtrTy);
+
+    // Handle format attribute
+    Value fmt = rewriter.create<LLVM::ConstantOp>(
+        op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt()));
+
+    // Create the call to __MaskMove
+    auto call = rewriter.create<LLVM::CallOp>(
+        op.getLoc(), i8PtrTy, "__MaskMove", // funcPtr,
+        ArrayRef<Value>{src, target, elemCount, nullPtr, fmt});
+
+    // Replace the op with the result of the call
+    rewriter.replaceOp(op, call.getResult());
+
+    return success();
+  }
+};
+
+// Convert tx81 reduce ops (sum/max/min) to an LLVM call into the CRT
+// runtime; funcPrefix selects the runtime entry point.
+template <typename Tx81Op, const char *funcPrefix>
+struct ReduceOpConversion : public OpConversionPattern<Tx81Op> {
+  using OpConversionPattern<Tx81Op>::OpConversionPattern;
+  using OpAdaptor = typename Tx81Op::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(Tx81Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get the module for function declarations
+    auto module = op->template getParentOfType<ModuleOp>();
+
+    // Declare the runtime function if not already declared
+    // Signature:
+    // __ReduceSum(uint64_t *src, uint64_t *dst, uint32_t dim, uint16_t src_n,
+    // uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt)
+    auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+    auto i32Ty = rewriter.getI32Type();
+
+    // Types for function declaration
+    SmallVector<Type> argTypes = {i8PtrTy, i8PtrTy, i32Ty, i32Ty,
+                                  i32Ty,   i32Ty,   i32Ty, i32Ty};
+
+    Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(),
+                                        funcPrefix, i8PtrTy, argTypes);
+
+    // Convert operands
+    Value src = adaptor.getSrc();
+    // Need to bitcast src to i8*
+    src = rewriter.create<LLVM::BitcastOp>(op.getLoc(), i8PtrTy, src);
+    Value dst = adaptor.getDst();
+    // Need to bitcast dst to i8*
+    dst = rewriter.create<LLVM::BitcastOp>(op.getLoc(), i8PtrTy, dst);
+
+    // Convert dim attribute to Value
+    Value dim = rewriter.create<LLVM::ConstantOp>(
+        op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getDim()));
+
+    // Convert shape attribute to Value
+    Value shape_n = rewriter.create<LLVM::ConstantOp>(op.getLoc(), i32Ty,
+                                                      op.getShape()[0]);
+    Value shape_h = rewriter.create<LLVM::ConstantOp>(op.getLoc(), i32Ty,
+                                                      op.getShape()[1]);
+    Value shape_w = rewriter.create<LLVM::ConstantOp>(op.getLoc(), i32Ty,
+                                                      op.getShape()[2]);
+    Value shape_c = rewriter.create<LLVM::ConstantOp>(op.getLoc(), i32Ty,
+                                                      op.getShape()[3]);
+
+    // Handle format attribute
+    Value fmt = rewriter.create<LLVM::ConstantOp>(
+        op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt()));
+
+    // Create the call
+    rewriter.create<LLVM::CallOp>(
+        op.getLoc(), i8PtrTy, funcPrefix, // funcPtr,
+        ArrayRef<Value>{src, dst, dim, shape_n, shape_h, shape_w, shape_c,
+                        fmt});
+
+    // Erase the old op
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
+// Convert tx81 element-wise binary ops to an LLVM call into the CRT runtime.
+template <typename Tx81Op, const char *funcPrefix>
+struct ElementWiseOpConversion : public OpConversionPattern<Tx81Op> {
+  using OpConversionPattern<Tx81Op>::OpConversionPattern;
+  using OpAdaptor = typename Tx81Op::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(Tx81Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get the module for function declarations
+    auto module = op->template getParentOfType<ModuleOp>();
+
+    // Declare the runtime function if not already declared
+    // Signature: void* __Add(void* a, void* b, void* out, uint32_t elem_count,
+    // uint32_t
rnd_mode, uint32_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i8PtrTy, + + i32Ty, i32Ty, i32Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + Value srcB = adaptor.getInput1(); + // Need to bitcast src to i8* + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle round attribute + Value rnd_mode = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getRndMode())); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, rnd_mode, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// FIXME: Use trait to refactor the BinaryVSOpConversion and +// ElementWiseOpConversion +template +struct BinaryVSOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void* __Add(void* a, void* b, void* out, uint32_t elem_count, + // uint32_t rnd_mode, uint32_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i32Ty, i8PtrTy, + i32Ty, i32Ty, i32Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + + Value srcB = adaptor.getValue(); + + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle round attribute + Value rnd_mode = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getRndMode())); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, rnd_mode, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.NormalConvertOp op to LLVM +template +struct NormalConvertOpConversion : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename NormalConvertOp::Adaptor; + + LogicalResult + matchAndRewrite(NormalConvertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void (*FP16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, + // uint64_t dst_addr, uint32_t elem_count); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value input = adaptor.getInput(); + Value output = adaptor.getOutput(); + Value elemCount = adaptor.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Bitcast all pointers to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + output = rewriter.create(op.getLoc(), i8PtrTy, output); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{input, output, elemCount}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.RoundConvertOp op to LLVM +template +struct RoundConvertOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename RoundConvertOp::Adaptor; + + LogicalResult + matchAndRewrite(RoundConvertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void (*INT16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, + // uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value input = adaptor.getInput(); + Value output = adaptor.getOutput(); + Value elemCount = adaptor.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + Value rnd_mode = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getRndMode())); + + // Bitcast all pointers to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + output = rewriter.create(op.getLoc(), i8PtrTy, output); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{input, output, elemCount, rnd_mode}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.gemm to LLVM call to __Gemm function +struct GemmOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::GemmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the 
__Gemm runtime function if not already declared + // Signature: void* __Gemm(void* a, void* b, void* bias, int32_t* dims, + // void* psum, bool trans_a, bool trans_b, + // uint32_t batch_a, uint32_t batch_b, bool en_relu, + // bool en_bias, bool en_neg_scale, void* neg_scale, + // bool en_pos_scale, void* pos_scale, + // uint32_t src_fmt, uint32_t dst_fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i64Ty = rewriter.getI64Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i1Ty = rewriter.getI1Type(); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, i8PtrTy, i8PtrTy, i32PtrTy, i8PtrTy, i1Ty, + i1Ty, i32Ty, i32Ty, i1Ty, i1Ty, i1Ty, + i8PtrTy, i1Ty, i8PtrTy, i32Ty, i32Ty}; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Gemm", + i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getSrcA(); + Value srcB = adaptor.getSrcB(); + Value srcBias = adaptor.getSrcBias(); + Value psumAddr = adaptor.getPsumAddr(); + Value srcNegScale = adaptor.getSrcNegScale(); + Value srcPosScale = adaptor.getSrcPosScale(); + + // Bitcast all pointers to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + srcBias = rewriter.create(op.getLoc(), i8PtrTy, srcBias); + psumAddr = + rewriter.create(op.getLoc(), i8PtrTy, psumAddr); + srcNegScale = + rewriter.create(op.getLoc(), i8PtrTy, srcNegScale); + srcPosScale = + rewriter.create(op.getLoc(), i8PtrTy, srcPosScale); + + // Handle dims array - need to convert from attribute to runtime array + auto dimsAttr = op.getDims(); + SmallVector dimsValues; + for (auto dimAttr : dimsAttr) + dimsValues.push_back(cast(dimAttr).getInt()); + + // Allocate memory for the dims array + Value dimsArraySize = rewriter.create( + op.getLoc(), i64Ty, + rewriter.getI64IntegerAttr(dimsValues.size() * sizeof(int64_t))); + + // Use malloc to allocate memory for dims array + Value mallocFunc = declareTx81Function(module, rewriter, op.getLoc(), + "malloc", i8PtrTy, {i64Ty}); + Value dimsArray = + rewriter + .create(op.getLoc(), i8PtrTy, "malloc", // mallocFunc, + ArrayRef{dimsArraySize}) + .getResult(); + + // Cast to i64* + Value dimsArrayI64Ptr = + rewriter.create(op.getLoc(), i64PtrTy, dimsArray); + + // Store each dimension in the array + for (size_t i = 0; i < dimsValues.size(); i++) { + // Create the index + Value idx = rewriter.create( + op.getLoc(), i64Ty, rewriter.getI32IntegerAttr(i)); + + // Create GEP to get pointer to array element + Value elemPtr = rewriter.create( + op.getLoc(), i64PtrTy, i64Ty, dimsArrayI64Ptr, ArrayRef{idx}); + + // Create the dimension value + Value dimValue = rewriter.create( + op.getLoc(), i64Ty, rewriter.getI64IntegerAttr(dimsValues[i])); + + // Store the value + rewriter.create(op.getLoc(), dimValue, elemPtr); + } + + // Convert boolean attributes + Value transA = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getTransSrcA())); + Value transB = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getTransSrcB())); + Value enLeakyRelu = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnLeakyRelu())); + Value enBias = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnBias())); + Value enNegScale = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnNegScale())); + Value 
enPosScale = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnPosScale())); + + // Convert integer attributes + Value batchA = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getBatchSrcA())); + Value batchB = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getBatchSrcB())); + Value srcFmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getSrcFmt())); + Value dstFmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getDstFmt())); + + // Create the call to __Gemm + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Gemm", // funcPtr, + ArrayRef{srcA, srcB, srcBias, dimsArrayI64Ptr, psumAddr, transA, + transB, batchA, batchB, enLeakyRelu, enBias, enNegScale, + srcNegScale, enPosScale, srcPosScale, srcFmt, dstFmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Convert tx81.memset to LLVM call to __Memset function +struct MemsetOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::MemsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Memset runtime function if not already declared + // Signature: void* __Memset(void* dst, int64_t value, uint32_t elem_count, + // int32_t* strides, int32_t* iterations, uint16_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64Ty = rewriter.getI64Type(); + auto i32Ty = rewriter.getI32Type(); + auto i32PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i32Ty, i32Ty, + i32PtrTy, i32PtrTy, i16Ty}; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + "__Memset", i8PtrTy, argTypes); + + // Get operands + Value src = adaptor.getSrc(); + Value value = adaptor.getValue(); + Value elemCount = adaptor.getElemCount(); + elemCount = castIndexToInt32(rewriter, op->getLoc(), elemCount); + + // Handle strides and iterations arrays + // For simplicity, we'll create null pointers + Value nullPtr = rewriter.create(op.getLoc(), i32PtrTy); + + src = rewriter.create(op.getLoc(), i8PtrTy, src); + + // Convert fmt attribute to Value + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call to __Memset + auto call = rewriter.create( + op.getLoc(), i8PtrTy, "__Memset", // funcPtr, + ArrayRef{src, value, elemCount, nullPtr, nullPtr, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +// Conversion pattern for linalg.fill operation with tensor arguments +struct LinalgFillOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinalgFillOpConversion(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(linalg::FillOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The operation should have tensor as output + if (op.getOutputs().size() != 1) { + return rewriter.notifyMatchFailure(op, "expects single output tensor"); + } + + // Check if the output is a 
tensor type + Value outputTensor = op.getOutputs()[0]; + auto tensorType = dyn_cast(outputTensor.getType()); + if (!tensorType) { + return rewriter.notifyMatchFailure(op, "expects ranked tensor type"); + } + + // Check for static shape + if (!tensorType.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, + "dynamic shapes not yet supported"); + } + + auto context = rewriter.getContext(); + auto loc = op.getLoc(); + Value value = adaptor.getInputs()[0]; + + // Get the element type + Type elemType = tensorType.getElementType(); + + // Convert the tensor type to the LLVM pointer type + auto llvmPtrType = dyn_cast(typeConverter->convertType(tensorType)); + if (!llvmPtrType) { + return rewriter.notifyMatchFailure( + op, "failed to convert tensor type to LLVM pointer type"); + } + + // Calculate total number of elements + int64_t totalElements = 1; + for (int64_t dim : tensorType.getShape()) { + totalElements *= dim; + } + + // Get index type + auto indexType = rewriter.getI64Type(); + + // Implement the following steps: + // 1. Allocate memory for the tensor + // 2. Fill it using memset if applicable + // 3. Return the pointer as the result + + // Calculate element size in bytes + int64_t elemSizeInBytes = 0; + if (auto intType = dyn_cast(elemType)) { + elemSizeInBytes = + (intType.getWidth() + 7) / 8; // Round up to nearest byte + } else if (auto floatType = dyn_cast(elemType)) { + elemSizeInBytes = + (floatType.getWidth() + 7) / 8; // Round up to nearest byte + } else { + return rewriter.notifyMatchFailure(op, "unsupported element type"); + } + + // Calculate total size in bytes + auto totalSizeInBytes = totalElements * elemSizeInBytes; + auto totalSizeVal = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(totalSizeInBytes)); + + // Allocate memory + auto mallocFunc = + getOrInsertMalloc(rewriter, op->getParentOfType()); + auto allocated = rewriter.create( + loc, LLVM::LLVMPointerType::get(context), mallocFunc, + ArrayRef{totalSizeVal}); + + auto llvmVoidPtr = LLVM::LLVMPointerType::get(context); + + // Cast the allocated memory to the appropriate pointer type + auto castPtr = rewriter.create(loc, llvmPtrType, + allocated.getResult()); + + // Check if we can use memset for filling + bool useMemset = false; + Value byteValue; + + // For memset to work correctly, we need to have a consistent byte pattern + if (auto constOp = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) { + // For integer constants + auto intVal = intAttr.getInt(); + // Check if all bytes in the pattern are the same + bool allBytesEqual = true; + uint8_t firstByte = intVal & 0xFF; + for (unsigned i = 1; i < elemSizeInBytes; i++) { + if (((intVal >> (i * 8)) & 0xFF) != firstByte) { + allBytesEqual = false; + break; + } + } + + if (allBytesEqual) { + useMemset = true; + byteValue = rewriter.create( + loc, rewriter.getIntegerType(8), + rewriter.getIntegerAttr(rewriter.getIntegerType(8), firstByte)); + } + } else if (auto floatAttr = dyn_cast(constOp.getValue())) { + // For floating point constants + if (floatAttr.getValue().isZero()) { + // Zero float can use memset with zero byte value + useMemset = true; + byteValue = rewriter.create( + loc, rewriter.getIntegerType(8), rewriter.getI8IntegerAttr(0)); + } + } + } + + if (useMemset) { + // Use memset for filling + auto memsetFunc = + getOrInsertMemset(rewriter, op->getParentOfType()); + rewriter.create( + loc, llvmVoidPtr, memsetFunc, + ArrayRef{castPtr, byteValue, totalSizeVal}); + } else { + // Create a loop to manually fill the 
tensor with the value + // We'll use SCF dialect for structured loops + auto llvmElemType = typeConverter->convertType(elemType); + + // Create loop initialization + auto zero = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(0)); + auto upperBound = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(totalElements)); + auto one = rewriter.create( + loc, indexType, rewriter.getI64IntegerAttr(1)); + + // Create the fill loop + auto loopOp = + rewriter.create(loc, zero, upperBound, one, ValueRange{}); + + // Set insertion point inside the loop + rewriter.setInsertionPointToStart(loopOp.getBody()); + + // Calculate pointer for the current element + auto currentPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(context), + LLVM::LLVMPointerType::get(context), castPtr, + ArrayRef({loopOp.getInductionVar()})); + + // Store the fill value to the current memory location + rewriter.create(loc, value, currentPtr); + + // Reset insertion point after the loop + rewriter.setInsertionPointAfter(loopOp); + } + + // Replace the original op with the casted pointer + rewriter.replaceOp(op, castPtr); + return success(); + } + +private: + // Helper to get or insert malloc function declaration + FlatSymbolRefAttr getOrInsertMalloc(PatternRewriter &rewriter, + ModuleOp module) const { + auto context = rewriter.getContext(); + auto mallocName = "malloc"; + if (module.lookupSymbol(mallocName)) { + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } + + // Create malloc function declaration + auto llvmVoidPtr = LLVM::LLVMPointerType::get(context); + auto mallocType = + LLVM::LLVMFunctionType::get(llvmVoidPtr, {rewriter.getI64Type()}, + /*isVarArg=*/false); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module->getLoc(), mallocName, mallocType); + + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } + + // Helper to get or insert memset function declaration + FlatSymbolRefAttr getOrInsertMemset(PatternRewriter &rewriter, + ModuleOp module) const { + auto context = rewriter.getContext(); + auto memsetName = "memset"; + if (module.lookupSymbol(memsetName)) { + return SymbolRefAttr::get(rewriter.getContext(), memsetName); + } + + // Create memset function declaration + auto voidPtrType = LLVM::LLVMPointerType::get(context); + auto memsetType = LLVM::LLVMFunctionType::get( + context, + voidPtrType, // memset returns the destination pointer + ArrayRef{voidPtrType, rewriter.getI8Type(), + rewriter.getI64Type()}, + /*isVarArg=*/false); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module->getLoc(), memsetName, memsetType); + + return SymbolRefAttr::get(rewriter.getContext(), memsetName); + } +}; + +// Conversion pattern for tensor.empty operation +class TensorEmptyOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + TensorEmptyOpConversion(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the result tensor type + TensorType resultType = op.getType(); + + // Verify we can handle this tensor type + if (!resultType.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, + "dynamic shapes not yet supported"); + } + 
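+    // The rest of this pattern emits, roughly (an illustrative sketch, not
+    // verbatim output; the constants depend on the tensor type):
+    //   %sz    = llvm.mlir.constant(4 : i64) : i64     // element size
+    //   %cnt   = llvm.mlir.constant(1024 : i64) : i64  // total elements
+    //   %bytes = llvm.mul %cnt, %sz : i64
+    //   %buf   = llvm.call @malloc(%bytes) : (i64) -> !llvm.ptr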
+ // Convert the tensor type to LLVM pointer type + auto llvmPtrType = dyn_cast( + getTypeConverter()->convertType(resultType)); + if (!llvmPtrType) { + return rewriter.notifyMatchFailure( + op, "failed to convert tensor type to LLVM pointer type"); + } + + // Get element type + Type elementType = resultType.getElementType(); + + // Create LLVM operations to allocate memory + // 1. Calculate the total allocation size in bytes + auto loc = op.getLoc(); + int64_t totalElements = 1; + for (int64_t dim : resultType.getShape()) { + totalElements *= dim; + } + + auto elementSize = rewriter.create( + loc, rewriter.getI64Type(), + rewriter.getI64IntegerAttr(getElementTypeSize(elementType))); + + auto totalSize = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(totalElements)); + + auto allocSize = rewriter.create(loc, rewriter.getI64Type(), + totalSize, elementSize); + + // 2. Allocate memory using malloc + auto mallocFunc = + getOrInsertMalloc(rewriter, op->getParentOfType()); + auto allocated = rewriter.create(loc, llvmPtrType, mallocFunc, + ArrayRef{allocSize}); + + // Replace the tensor.empty operation with our allocation + rewriter.replaceOp(op, allocated.getResult()); + return success(); + } + +private: + // Helper to get element type size in bytes + int64_t getElementTypeSize(Type type) const { + if (auto floatType = dyn_cast(type)) { + return floatType.getWidth() / 8; + } else if (auto intType = dyn_cast(type)) { + return intType.getWidth() / 8; + } + // Default for other types + return 1; + } + + // Helper to get or insert malloc function declaration + FlatSymbolRefAttr getOrInsertMalloc(PatternRewriter &rewriter, + ModuleOp module) const { + auto mallocName = "malloc"; + if (module.lookupSymbol(mallocName)) { + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } + + // Create malloc function declaration + auto llvmVoidPtr = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto mallocType = + LLVM::LLVMFunctionType::get(llvmVoidPtr, {rewriter.getI64Type()}, + /*isVarArg=*/false); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module->getLoc(), mallocName, mallocType); + + return SymbolRefAttr::get(rewriter.getContext(), mallocName); + } +}; + +// The conversion pass +class Tx81ToLLVMPass : public Tx81ToLLVMBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + + // Setup LLVM lowering options object which should live across the call to + // applyFull/PartialConversion. + LowerToLLVMOptions options(context); + options.useBarePtrCallConv = false; + + // Setup conversion target + target.addLegalDialect(); + // Handle the tx81 op to llvm.call and support kcore load/store op's spm + // offset + target.addIllegalDialect(); + + // Setup rewrite patterns + RewritePatternSet patterns(context); + + // NOTE: LLVMTypeConverter should be enough for MLIR core dialects. 
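+    // Note: with useBarePtrCallConv = false (set above), each memref lowers
+    // to the full LLVM memref descriptor (allocated/aligned pointers, offset,
+    // sizes and strides) rather than to a single bare pointer.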
+ LLVMTypeConverter llvmTypeConverter(context, options); + + // Add the Tx81 to LLVM conversion patterns + // clang-format off + patterns.add, + NormalConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + ReduceOpConversion, + ReduceOpConversion, + ReduceOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + BinaryVSOpConversion, + BinaryVSOpConversion, + BinaryVSOpConversion, + BinaryVSOpConversion, + RdmaOpConversion, + WdmaOpConversion, + MaskMoveOpConversion, + GemmOpConversion, + MemsetOpConversion>( + context); + // clang-format on + + // Add call op conversion + populateCallOpTypeConversionPattern(patterns, llvmTypeConverter); + + // Add return op conversion + populateReturnOpTypeConversionPattern(patterns, llvmTypeConverter); + + // Apply the conversion + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> triton::createTx81ToLLVMPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp new file mode 100644 index 000000000..0507f6ab6 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp @@ -0,0 +1,80 @@ +//===--------------------- Tx81ToLLVMPass.cpp -----------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/Debug.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" +#include +#include +#include + +#define DEBUG_TYPE "tx81-to-llvm" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h.inc" + +namespace { + + +class Tx81ToLLVMPass : public Tx81ToLLVMBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + + // Setup LLVM lowering options object which should live across the call to + // applyFull/PartialConversion. + LowerToLLVMOptions options(context); + options.useBarePtrCallConv = false; + + // Setup conversion target + target.addLegalDialect(); + target.addIllegalDialect(); + + // Setup rewrite patterns + RewritePatternSet patterns(context); + + // NOTE: LLVMTypeConverter should be enough for MLIR core dialects. 
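+    // TensorToLLVMTypeConverter is presumably a project-local extension of
+    // LLVMTypeConverter that additionally maps ranked tensor types to
+    // !llvm.ptr, matching the tensor-returning patterns in Tx81ToLLVM.cpp.
+    // (Assumption based on usage here; the converter is defined elsewhere.)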
+ TensorToLLVMTypeConverter converter(context, options); + + triton::populateTx81ToLLVMConversionPatterns(patterns, target, converter); + + // Apply the conversion + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +triton::createTx81ToLLVMPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Dialect/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..4c1e72494 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(TritonTilingExt) +add_subdirectory(TritonStructured) +add_subdirectory(MagicKernel) +add_subdirectory(TsingMicroTx81) diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/MagicKernel/CMakeLists.txt new file mode 100644 index 000000000..41f904167 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(MagicKernelIR + IR/MagicKernelDialect.cpp + Transforms/BufferizableOpInterfaceImpl.cpp + + DEPENDS + MagicKernelTableGen + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp new file mode 100644 index 000000000..d71761179 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/IR/MagicKernelDialect.cpp @@ -0,0 +1,33 @@ +//===------------------- MagicKernelDialect.cpp ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" + +using namespace mlir; +using namespace mlir::mk; + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +void MagicKernelDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "magic-kernel/Dialect/IR/MagicKernelOps.cpp.inc" + >(); + // TODO: Add BufferizableOpInterface to all ops that can be bufferized + declarePromisedInterfaces(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "magic-kernel/Dialect/IR/MagicKernelOps.cpp.inc" + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.cpp.inc" diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 000000000..cd7cc2688 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,122 @@ +//===- BufferizableOpInterfaceImpl.cpp ----------------------------------- ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the BufferizableOpInterface for the destination-style
+// ops of the mk dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h"
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+/// Generic conversion for any DestinationStyleOpInterface on tensors.
+static LogicalResult
+bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
+                                     DestinationStyleOpInterface op,
+                                     const BufferizationOptions &options) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Nothing to do. This op is already bufferized.
+  if (op.hasPureBufferSemantics())
+    return success();
+
+  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
+  // basis.
+  if (!op.hasPureTensorSemantics())
+    return op->emitError() << "op does not have pure tensor semantics";
+
+  // New input operands for the cloned op.
+  SmallVector<Value> newInputBuffers;
+  newInputBuffers.reserve(op.getNumDpsInputs());
+  for (OpOperand *opOperand : op.getDpsInputOperands()) {
+    if (op.isScalar(opOperand)) {
+      newInputBuffers.push_back(opOperand->get());
+      continue;
+    }
+    FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+    if (failed(buffer))
+      return failure();
+    newInputBuffers.push_back(*buffer);
+  }
+
+  // New output operands for the cloned op.
+  SmallVector<Value> newOutputBuffers;
+  for (OpResult opResult : op->getOpResults()) {
+    OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
+    FailureOr<Value> resultBuffer =
+        getBuffer(rewriter, opOperand->get(), options);
+    if (failed(resultBuffer))
+      return failure();
+    newOutputBuffers.push_back(*resultBuffer);
+  }
+
+  // Merge input/output operands.
+  SmallVector<Value> newOperands = newInputBuffers;
+  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
+
+  // Set insertion point now that potential alloc/dealloc are introduced.
+  rewriter.setInsertionPoint(op);
+  // Clone the op, but use the new operands. Since the new op does not have
+  // any tensor results, it does not return anything.
+  OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
+                       op->getAttrs());
+
+  Operation *newOp = Operation::create(state);
+
+  // We don't want the rewriter to track an incomplete operation, so insert
+  // the new operation only after it has been fully constructed.
+  rewriter.insert(newOp);
+
+  // Replace the results of the old op with the new output buffers.
+  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
+
+  return success();
+}
+
+/// Bufferization of mk ops. Replace with a new mk op that operates entirely on
+/// memrefs.
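+///
+/// Illustrative example (the op name and types are hypothetical, chosen only
+/// to show the shape of the rewrite):
+///
+///   %r = mk.add_vv ins(%a, %b : tensor<128xf32>, tensor<128xf32>)
+///                  outs(%init : tensor<128xf32>) -> tensor<128xf32>
+///
+/// bufferizes to:
+///
+///   mk.add_vv ins(%a_buf, %b_buf : memref<128xf32>, memref<128xf32>)
+///             outs(%r_buf : memref<128xf32>)
+///
+/// with uses of %r replaced by the bufferized value of %r_buf.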
+template <typename OpTy>
+struct MKOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<MKOpInterface<OpTy>,
+                                                     OpTy> {
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    return bufferizeDestinationStyleOpInterface(
+        rewriter, cast<DestinationStyleOpInterface>(op), options);
+  }
+};
+
+/// Helper structure that iterates over all mkOps in `OpTys` and registers
+/// the `BufferizableOpInterface` with each of them.
+template <typename... Ops> struct MKOpInterfaceHelper {
+  static void registerOpInterface(MLIRContext *ctx) {
+    (Ops::template attachInterface<MKOpInterface<Ops>>(*ctx), ...);
+  }
+};
+
+void mlir::mk::registerBufferizableOpInterfaceExternalModels(
+    mlir::DialectRegistry &registry) {
+  registry.addExtension(
+      +[](MLIRContext *ctx, mlir::mk::MagicKernelDialect *dialect) {
+        // TODO: Register all mk ops.
+        MKOpInterfaceHelper::registerOpInterface(ctx);
+      });
+}
\ No newline at end of file
diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonStructured/CMakeLists.txt
new file mode 100644
index 000000000..f33061b2d
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/CMakeLists.txt
new file mode 100644
index 000000000..27aac38fa
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_triton_library(TritonStructuredIR
+  TritonStructuredOps.cpp
+  TritonStructuredDialect.cpp
+
+  DEPENDS
+  TritonStructuredTableGen
+
+  LINK_LIBS PUBLIC
+  TritonIR
+  MLIRIR
+  )
diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp
new file mode 100644
index 000000000..2af19b8a2
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp
@@ -0,0 +1,22 @@
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+
+using namespace mlir;
+using namespace mlir::tts;
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
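+///
+/// The ops registered through GET_OP_LIST below (MakeTensorPtrOp, LoadOp,
+/// StoreOp, GetStructuredStateOp) model structured pointer arithmetic; their
+/// custom builders are implemented in TritonStructuredOps.cpp further down.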
+void TritonStructuredDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.cpp.inc"
+
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp.inc"
diff --git a/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp
new file mode 100644
index 000000000..cf55d834a
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp
@@ -0,0 +1,179 @@
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "triton/Dialect/Triton/IR/Types.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cstdint>
+#include <optional>
+#include <utility>
+
+#define GET_OP_CLASSES
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.h.inc"
+
+using namespace mlir;
+using namespace mlir::tts;
+
+namespace mlir {
+namespace tts {
+
+void MakeTensorPtrOp::build(OpBuilder &b, OperationState &state, Value base,
+                            ArrayRef<int64_t> sizes,
+                            ArrayRef<OpFoldResult> strides,
+                            ArrayRef<OpFoldResult> offsets,
+                            ArrayRef<OpFoldResult> shape,
+                            ArrayRef<int32_t> order) {
+  SmallVector<int64_t> staticStrides, staticOffsets, staticShape;
+  SmallVector<Value> dynamicStrides, dynamicOffsets, dynamicShape;
+
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
+
+  Type resType;
+  auto basePtr = cast<triton::PointerType>(base.getType());
+  auto elemType = basePtr.getPointeeType();
+  // non-block pointer
+  if (order.empty()) {
+    resType = RankedTensorType::get(sizes, basePtr);
+  }
+  // block pointer
+  else {
+    resType = triton::PointerType::get(RankedTensorType::get(sizes, elemType),
+                                       basePtr.getAddressSpace());
+  }
+
+  build(b, state, resType, base, sizes, dynamicStrides, dynamicOffsets,
+        dynamicShape, b.getDenseI64ArrayAttr(staticStrides),
+        b.getDenseI64ArrayAttr(staticOffsets),
+        b.getDenseI64ArrayAttr(staticShape), order);
+}
+
+void LoadOp::build(OpBuilder &b, OperationState &state, Value ptr,
+                   ArrayRef<OpFoldResult> dims, Value other) {
+  SmallVector<int64_t> staticDims;
+  SmallVector<Value> dynamicDims;
+
+  dispatchIndexOpFoldResults(dims, dynamicDims, staticDims);
+
+  // non-block pointer type
+  auto ptrTensorType = dyn_cast<RankedTensorType>(ptr.getType());
+  // block pointer type
+  auto tensorPtrType = dyn_cast<triton::PointerType>(ptr.getType());
+
+  Type resType;
+  if (ptrTensorType) {
+    auto ptrType = cast<triton::PointerType>(ptrTensorType.getElementType());
+    auto elemType = ptrType.getPointeeType();
+    resType = RankedTensorType::get(ptrTensorType.getShape(), elemType);
+
+  } else if (tensorPtrType) {
+    auto tensorType = cast<RankedTensorType>(tensorPtrType.getPointeeType());
+    resType = RankedTensorType::get(tensorType.getShape(),
+                                    tensorType.getElementType());
+  }
+  build(b, state, resType, ptr, dynamicDims,
+        b.getDenseI64ArrayAttr(staticDims), other);
+}
+
+void StoreOp::build(OpBuilder &b, OperationState &state, Value ptr,
+                    Value value, ArrayRef<OpFoldResult> dims) {
+  SmallVector<int64_t> staticDims;
+  SmallVector<Value> dynamicDims;
+
+  dispatchIndexOpFoldResults(dims, dynamicDims, staticDims);
+
+  build(b, state, ptr, value, dynamicDims, b.getDenseI64ArrayAttr(staticDims));
+}
+
+LogicalResult GetStructuredStateOp::verify() {
+  auto expectedOffsetAndStrideTypes =
+      getOffsetAndStrideTypes(getContext(), getInput().getType());
+
+  if (!expectedOffsetAndStrideTypes.has_value()) {
+    return failure();
+  }
+
+  auto [expectedOffsetTypes, expectedStrideTypes] =
+      *expectedOffsetAndStrideTypes;
+
+  return success(expectedOffsetTypes.size() == getOffsets().size() &&
+                 llvm::equal(expectedOffsetTypes, getOffsets().getTypes()) &&
+                 expectedStrideTypes.size() == getStrides().size() &&
+                 llvm::equal(expectedStrideTypes, getStrides().getTypes()));
+}
+
+void GetStructuredStateOp::build(OpBuilder &b, OperationState &state,
+                                 Value val) {
+  auto type = val.getType();
+
+  // Builder cannot fail, so we default to empty offset and stride types.
+  // The invalid op will be rejected by the verifier later.
+  auto [offsetTypes, strideTypes] =
+      getOffsetAndStrideTypes(b.getContext(), type)
+          .value_or(std::make_pair(SmallVector<Type>{}, SmallVector<Type>{}));
+
+  build(b, state, val.getType(), offsetTypes, strideTypes, val);
+}
+
+std::optional<std::pair<SmallVector<Type>, SmallVector<Type>>>
+GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, Type type) {
+  auto sizes = getOffsetAndStrideSegmentSizes(type);
+  if (!sizes.has_value()) {
+    return std::nullopt;
+  }
+  return std::make_pair(
+      SmallVector<Type>(sizes->first, IndexType::get(context)),
+      SmallVector<Type>(sizes->second, IndexType::get(context)));
+}
+
+std::optional<std::pair<int32_t, int32_t>>
+GetStructuredStateOp::getOffsetAndStrideSegmentSizes(Type type) {
+  int32_t offsetSegmentSize = 0;
+  int32_t strideSegmentSize = 0;
+
+  if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type)) {
+    if (tensorType.getElementType().isIntOrIndex()) {
+      // Tensors of offsets
+      // Important note:
+      // We only care about tensor of index / int (in addition to pointer type)
+      // because only values of int and index type can potentially be part of a
+      // pointer arithmetic sequence.
+      offsetSegmentSize = strideSegmentSize = tensorType.getRank();
+    } else if (auto ptrType = dyn_cast<triton::PointerType>(
+                   tensorType.getElementType())) {
+      // Unstructured pointers (tensor<!tt.ptr>)
+      // Each tensor of rank k gets k values for its offsets and k values for
+      // its strides, all of which have Index type.
+      offsetSegmentSize = strideSegmentSize = tensorType.getRank();
+    }
+  }
+  // Block pointers (!tt.ptr<tensor> or scalar !tt.ptr)
+  else if (auto ptrType = llvm::dyn_cast<triton::PointerType>(type)) {
+    if (auto tensorType =
+            llvm::dyn_cast<RankedTensorType>(ptrType.getPointeeType())) {
+      // Each tensor of rank k gets k values for its offsets and k values for
+      // its strides, all of which have Index type.
+      offsetSegmentSize = strideSegmentSize = tensorType.getRank();
+    } else {
+      // The only relevant state that can be updated in loops for scalar
+      // pointers is the offset. No need to include stride here.
+      offsetSegmentSize = 1;
+    }
+  } else {
+    return std::nullopt;
+  }
+
+  return std::make_pair(offsetSegmentSize, strideSegmentSize);
+}
+
+} // namespace tts
+} // namespace mlir
diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/CMakeLists.txt
new file mode 100644
index 000000000..f33061b2d
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/BufferizableOpInterfaceImpl.cpp b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 000000000..c7f1c8174
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,134 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+
+using namespace mlir;
+using namespace linalg;
+using namespace mlir::bufferization;
+
+//
+// This file implements the bufferizable interface for TritonTilingExtOps.
+// The interface is required for bufferization (i.e: converting from tensors to
+// memrefs).
+// Since the bufferization semantics of TritonTilingExtOps are identical to
+// linalg ops, the implementation was borrowed almost verbatim from
+// mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+// with the exception that the code to handle linalg's region has been removed.
+// (the original implementation is in an anonymous namespace, so we cannot
+// reuse)
+//
+namespace {
+
+/// Generic conversion for any DestinationStyleOpInterface on tensors.
+static LogicalResult bufferizeTritonTilingExtDestinationStyleOpInterface(
+    RewriterBase &rewriter, DestinationStyleOpInterface op,
+    const BufferizationOptions &options) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Nothing to do. This op is already bufferized.
+  if (op.hasPureBufferSemantics())
+    return success();
+
+  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
+  // basis.
+  if (!op.hasPureTensorSemantics())
+    return op->emitError() << "op does not have tensor semantics";
+
+  // New input operands for the cloned op.
+  SmallVector<Value> newInputBuffers;
+  newInputBuffers.reserve(op.getNumDpsInputs());
+  for (OpOperand *opOperand : op.getDpsInputOperands()) {
+    if (op.isScalar(opOperand)) {
+      newInputBuffers.push_back(opOperand->get());
+      continue;
+    }
+    FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+    if (failed(buffer))
+      return failure();
+    newInputBuffers.push_back(*buffer);
+  }
+
+  // New output operands for the cloned op.
+  SmallVector<Value> newOutputBuffers;
+  for (OpResult opResult : op->getOpResults()) {
+    OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
+    FailureOr<Value> resultBuffer =
+        getBuffer(rewriter, opOperand->get(), options);
+    if (failed(resultBuffer))
+      return failure();
+    newOutputBuffers.push_back(*resultBuffer);
+  }
+
+  // Merge input/output operands.
+  SmallVector<Value> newOperands = newInputBuffers;
+  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
+
+  // Set insertion point now that potential alloc/dealloc are introduced.
+  rewriter.setInsertionPoint(op);
+  // Clone the op, but use the new operands. Move the existing block into the
+  // new op. Since the new op does not have any tensor results, it does not
+  // return anything.
+  clone(rewriter, op, /*resultTypes=*/TypeRange{}, newOperands);
+
+  // Replace the results of the old op with the new output buffers.
+  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
+
+  return success();
+}
+
+template <typename OpTy>
+struct TritonTilingExtOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<
+          TritonTilingExtOpInterface<OpTy>, OpTy> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Operand is read if it is used in the computation.
+    return cast<DestinationStyleOpInterface>(op).isDpsInput(&opOperand);
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Operand is written to if it is not an input/init.
+    return cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    return bufferizeTritonTilingExtDestinationStyleOpInterface(
+        rewriter, cast<DestinationStyleOpInterface>(op), options);
+  }
+};
+
+template <typename... Ops> struct TritonTilingExtOpInterfaceHelper {
+  static void registerOpInterface(MLIRContext *ctx) {
+    (Ops::template attachInterface<TritonTilingExtOpInterface<Ops>>(*ctx),
+     ...);
+  }
+};
+} // namespace
+
+void mlir::ttx::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  // clang-format off
+  registry.addExtension(+[](MLIRContext *ctx, ttx::TritonTilingExtDialect *dialect) {
+    TritonTilingExtOpInterfaceHelper<
+      ttx::CumSumOp
+    >::registerOpInterface(ctx);
+  });
+  // clang-format on
+}
diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CMakeLists.txt
new file mode 100644
index 000000000..b6b07162c
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_triton_library(TritonTilingExtIR
+  BufferizableOpInterfaceImpl.cpp
+  CumSum.cpp
+  TritonTilingExtDialect.cpp
+
+  DEPENDS
+  TritonTilingExtInterfacesIncGen
+  TritonTilingExtOpsIncGen
+
+  LINK_LIBS PUBLIC
+  TritonIR
+  MLIRAffineAnalysis
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRLinalgUtils
+  )
diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CumSum.cpp b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CumSum.cpp
new file mode 100644
index 000000000..619b653ea
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/CumSum.cpp
@@ -0,0 +1,112 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// This file implements cumulative sum (CumSum) using the TilingInterface. Only
+// supports tensors of rank 1 & 2 and axis == rank - 1 (i.e: we can split the
+// computation of each row and compute them independently). The semantics of
+// tiling for other axes are more complex and require working with
+// non-contiguous memory.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "ttx-cumsum"
+
+using namespace mlir;
+using namespace mlir::ttx;
+
+void ttx::CumSumOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                          Value input, IntegerAttr axis, Value output,
+                          ArrayRef<NamedAttribute> attributes) {
+  SmallVector<Value> inputs{input};
+  SmallVector<Value> outputs{output};
+  odsState.addOperands(inputs);
+  odsState.addOperands(outputs);
+  odsState.addAttribute(
+      "operand_segment_sizes",
+      odsBuilder.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
+                                       static_cast<int32_t>(outputs.size())}));
+
+  odsState.addAttribute(getAxisAttrStrName(), axis);
+  odsState.addAttributes(attributes);
+  odsState.addTypes(SmallVector<Type>{output.getType()});
+}
+
+mlir::LogicalResult ttx::CumSumOp::verify() {
+  auto inputType = getInput().getType();
+  if (!isa<TensorType>(inputType) && !isa<MemRefType>(inputType)) {
+    return emitOpError(
+        "CumSum op expects input to be either tensor or memref.");
+  }
+
+  auto outputType = getOutput().getType();
+  if (!isa<TensorType>(outputType) && !isa<MemRefType>(outputType)) {
+    return emitOpError(
+        "CumSum op expects output to be either tensor or memref.");
+  }
+
+  if (dyn_cast<ShapedType>(inputType).getShape() !=
+      dyn_cast<ShapedType>(outputType).getShape()) {
+    return emitOpError("Input and output types must be the same.");
+  }
+
+  int64_t rank = getRank();
+  if (rank != 1 && rank != 2) {
+    return emitOpError("CumSum op only takes tensors of rank 1 & 2.");
+  }
+
+  int64_t axis = getAxis();
+  if (axis != rank - 1) {
+    return emitOpError("CumSum computation only supports axis == rank - 1");
+  }
+
+  return success();
+}
+
+AffineMap ttx::CumSumOp::getInputIndexingMap(MLIRContext *context,
+                                             unsigned int index,
+                                             ArrayRef<OpFoldResult> sizes) {
+  assert(index == 0);
+  return AffineMap::getMultiDimIdentityMap(getRank(), context);
+}
+
+AffineMap ttx::CumSumOp::getOutputIndexingMap(MLIRContext *context,
+                                              unsigned int index,
+                                              ArrayRef<OpFoldResult> sizes) {
+  assert(index == 0);
+  return AffineMap::getMultiDimIdentityMap(getRank(), context);
+}
+
+SmallVector<utils::IteratorType> ttx::CumSumOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iterators;
+  iterators.append(getRank() - 1, utils::IteratorType::parallel);
+  iterators.push_back(utils::IteratorType::reduction);
+  return iterators;
+}
+
+SmallVector<Range> ttx::CumSumOp::getIterationDomain(OpBuilder &b) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(*this);
+  auto loc = getLoc();
+  auto zero = b.getIndexAttr(0);
+  auto one = b.getIndexAttr(1);
+  SmallVector<Range> iterationDomain;
+
+  // Return the bounds for all dimensions. The caller is responsible for not
+  // tiling the innermost dimension, otherwise the semantics of the resulting
+  // op are incorrect.
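+  // For example (illustrative shape): for an input of type tensor<4x128xf32>
+  // this returns { {0, 4, 1}, {0, 128, 1} }, i.e. two half-open ranges with
+  // step 1, of which only the leading (parallel) dimension may be tiled.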
+  for (auto i = 0; i < getRank(); i++) {
+    OpFoldResult upperbound = linalg::createFoldedDimOp(b, loc, getInput(), i);
+    iterationDomain.push_back(Range{zero, upperbound, one});
+  }
+  return iterationDomain;
+}
diff --git a/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp
new file mode 100644
index 000000000..da4d97674
--- /dev/null
+++ b/third_party/tsingmicro/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp
@@ -0,0 +1,404 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Value.h"
+
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ttx;
+using namespace mlir::linalg;
+
+namespace mlir {
+namespace ttx {
+
+Value getSlice(OpBuilder &b, Location loc, Value source,
+               ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+               ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Value>(source.getType())
+      .Case([&](RankedTensorType t) -> Value {
+        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                                strides);
+      })
+      .Case([&](MemRefType type) -> Value {
+        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
+                                           strides);
+      })
+      .Default([&](Type t) { return nullptr; });
+}
+
+//
+// getTiledImplementation
+//
+// Given an array of offsets and sizes, return the corresponding tiled version
+// of the current op.
+//
+// This method is responsible for creating the extract slice ops for each
+// operand of the op (including input and output operands).
+//
+// As an example, assuming we tile a linalg.matmul ins(%0, %1) out(%out)
+//
+// This method then generates:
+//
+// %in_slice_0 = extract_slice from %0
+// %in_slice_1 = extract_slice from %1
+// %out_slice = extract_slice from %out
+// %tile = linalg.matmul ins(%in_slice_0, %in_slice_1) out(%out_slice)
+//
+// To generate these extract slice, we go over each operand, get the
+// corresponding affine map to compute the correct offsets and sizes.
+//
+// Now let's describe how we compute the correct offsets and sizes from
+// an affine map.
+//
+// - Offsets:
+// An affine map describes how to access a tensor (i.e: the indices into a
+// tensor), so getting the offsets (also indices) from an affine map is just
+// simply "applying" the sub-map on the offset (calling
+// makeComposedFoldedAffineApply which also does constant folding
+// automatically).
+//
+// For example:
+// Let's assume we have the following nested loops:
+// for i in range(0, 10):
+//   for j in range(0, 20):
+//     for k in range(0, 30):
+//       dst[i][j][k] = src[i * 2][j + k]
+//
+// Assume that we describe the iteration space based on dst. So:
+// - dst's affine map is (d0, d1, d2) -> (d0, d1, d2)
+// - src's affine map is (d0, d1, d2) -> (d0 * 2, d1 + d2)
+//
+// Now let's say we want to tile the operator with offset (0, 1, 2).
+//
+// For dst, we apply this (0, 1, 2) to its affine map and get (0, 1, 2)
+//
+// For src, we have to plug in the offsets into the affine map to get:
+//
+// (0 * 2, 1 + 2) = (0, 3)
+//
+// This is exactly what the implementation does as well.
+// The call to getSubMap gets the i'th result expression, then the call to
+// makeComposedFoldedAffineApply applies the `offsets` array to the i'th result
+// expression in the affine map.
+//
+//
+// - Sizes:
+// Size is slightly more complex; notice that there are 3 steps to compute
+// sizes:
+//
+// 1) call linalg::computeTileSizes on the provided `sizes`
+// 2) apply the affine map
+// 3) add 1 to the result
+//
+// The reason for this complexity is that the affine maps describe the indices
+// of the iteration space with a half-open interval (i.e.: we always go from 0
+// until right before the upper bound). So if we simply apply the affine map on
+// the sizes, we will have incorrect results.
+//
+// Consider this snippet again:
+// for i in range(0, 16):
+//   for j in range(0, 32):
+//     for k in range(0, 64):
+//       dst[i][j][k] = src[i * 2][j + k]
+//
+// Assume we want the operator to have a tile size of (16, 32, 64) -- so no
+// tiling at all. If we apply the affine map of src (d0, d1, d2) -> (d0 * 2, d1
+// + d2), we have
+//
+// (16 * 2, 32 + 64) -> (32, 96)
+//
+// However, consider the second dimension of source:
+// - j goes from 0 till 31 inclusive
+// - k goes from 0 till 63 inclusive
+//
+// So the max index of src's second dimension is 31 + 63 = 94. Since index
+// starts from 0, this means the second dimension has 95 elements. But the
+// formula gives us a tile size of 96! The same argument can be applied for
+// the first dimension as well: the number of elements is 15 * 2 + 1 = 31, but
+// the computed tile size is 32.
+//
+// So simply applying the indexing map to compute tile size is INCORRECT.
+// This happens because the indexing map operates on [0, size), while tile
+// sizes are inclusive.
+//
+// The correct formula is:
+// ((d0 - 1) * 2 + 1), (d1 - 1) + (d2 - 1) + 1 which gives
+// (15 * 2 + 1, 32 - 1 + 64 - 1 + 1) -> (31, 95)
+//
+// So again, the steps are:
+// - Subtract 1 from the sizes (what linalg::computeTileSizes does)
+// - Apply the affine map
+// - Add 1 to the result
+//
+template <typename TritonTilingExtOpTy>
+FailureOr<TilingResult>
+getTiledImplementation(TritonTilingExtOpTy op, OpBuilder &b,
+                       ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes) {
+  Location loc = op->getLoc();
+  SmallVector<Value> valuesToTile = op->getOperands();
+  SmallVector<Value> tiledValues;
+  auto oneAttr = b.getI64IntegerAttr(1);
+
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    unsigned int index = opOperand.getOperandNumber();
+    auto val = valuesToTile[index];
+    auto type = dyn_cast<ShapedType>(val.getType());
+
+    if (!type) {
+      tiledValues.push_back(val);
+      continue;
+    }
+
+    auto rank = type.getRank();
+    SmallVector<OpFoldResult> newOffsets;
+    SmallVector<OpFoldResult> newSizes;
+    SmallVector<OpFoldResult> newStrides(rank, oneAttr);
+
+    llvm::SmallVector<OpFoldResult> composedTileSizes =
+        linalg::computeTileSizes(b, loc, sizes, {});
+
+    AffineMap map = op.getIndexingMap(b.getContext(), index, sizes);
+    for (int64_t i = 0; i < rank; i++) {
+      AffineMap m = map.getSubMap(i);
+      {
+        OpFoldResult upperboundClosed =
+            affine::makeComposedFoldedAffineApply(b, loc, m, composedTileSizes);
+        AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
+        OpFoldResult size = affine::makeComposedFoldedAffineApply(
+            b, loc, s0 + 1, upperboundClosed);
+        newSizes.push_back(size);
+      }
+      {
+        OpFoldResult offset =
+            affine::makeComposedFoldedAffineApply(b, loc, m, offsets);
+        newOffsets.push_back(offset);
+      }
+    }
+
+    tiledValues.push_back(
+        getSlice(b, loc, val, newOffsets, newSizes, newStrides));
+  }
+
+  SmallVector<Type> resultTensorTypes = llvm::to_vector(
+      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+        return tiledValues[opOperand.getOperandNumber()].getType();
+      }));
+
+  Operation *tiledOp = clone(b, op, resultTensorTypes, tiledValues);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+//
+// getResultTilePosition
+// This method returns the resultOffsets and resultSizes through references
+// for the tiled operator. While `getTiledImplementation` is responsible for
+// generating the extract slice for all operands, `getResultTilePosition` is
+// responsible for returning the offsets and sizes which the tiling engine will
+// then use to generate the corresponding insert_slice ops.
+//
+// Because the slice we insert back to the output tensor is the same as the
+// slice that we extracted from the output tensor, this method just repeats the
+// offset and size computation in `getTiledImplementation`.
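+//
+// For example (illustrative numbers): with an identity output indexing map,
+// tile offsets (0, 8) and sizes (4, 8) are returned unchanged, and the tiling
+// engine then emits something like
+//   %r = tensor.insert_slice %tile into %out[0, 8] [4, 8] [1, 1]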
+//
+template <typename TritonTilingExtOpTy>
+LogicalResult getResultTilePosition(TritonTilingExtOpTy op, OpBuilder &b,
+                                    unsigned resultNumber,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVector<OpFoldResult> &resultOffsets,
+                                    SmallVector<OpFoldResult> &resultSizes) {
+  Location loc = op.getLoc();
+
+  AffineMap outputMap =
+      op.getOutputIndexingMap(b.getContext(), resultNumber, sizes);
+
+  Value result = op.getDpsInitOperand(resultNumber)->get();
+  auto rank = dyn_cast<ShapedType>(result.getType()).getRank();
+
+  llvm::SmallVector<OpFoldResult> composedTileSizes =
+      linalg::computeTileSizes(b, loc, sizes, {});
+  for (int64_t i = 0; i < rank; i++) {
+    AffineMap m = outputMap.getSubMap(i);
+    {
+      OpFoldResult upperboundClosed =
+          affine::makeComposedFoldedAffineApply(b, loc, m, composedTileSizes);
+      AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
+      OpFoldResult size = affine::makeComposedFoldedAffineApply(
+          b, loc, s0 + 1, upperboundClosed);
+      resultSizes.push_back(size);
+    }
+    {
+      OpFoldResult offset =
+          affine::makeComposedFoldedAffineApply(b, loc, m, offsets);
+      resultOffsets.push_back(offset);
+    }
+  }
+
+  return success();
+}
+
+// This method is borrowed verbatim from
+// mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+//
+// This is invoked when the current op produces a result that is used
+// as an input to another op that is being tiled. The method essentially handles
+// producing a new op where the result matches the given offsets and sizes.
+// If the method succeeds, the two new operators will be fused in the same loop.
+//
+// As an example, consider the following IR where the linalg.generic is being
+// tiled (unnecessary details omitted for brevity):
+//
+// clang-format: off
+//
+// func.func @some_op_1(
+//   %arg0: tensor<8x2x256x512xbf16>,
+//   %arg1: tensor<8x256x1024xbf16>
+// ) -> tensor<8x256x1024xbf16>
+//   %1 = linalg.init_tensor [8, 256, 1024] : tensor<8x256x1024xbf16>
+//   %2 = linalg.init_tensor [8, 256, 1024] : tensor<8x256x1024xbf16>
+//   %3 = ttx.some_op
+//          ins(%arg0 : tensor<8x2x256x512xbf16>)
+//          outs(%1 : tensor<8x256x1024xbf16>) -> tensor<8x256x1024xbf16>
+//   %4 = linalg.generic
+//          ins(%3, %arg1 : tensor<8x256x1024xbf16>, tensor<8x256x1024xbf16>)
+//          outs(%2 : tensor<8x256x1024xbf16>) {
+//   ^bb0(%arg2: bf16, %arg3: bf16, %arg4: bf16):
+//     %add = arith.addf %arg2, %arg3 : bf16
+//     linalg.yield %add : bf16
+//   } -> tensor<8x256x1024xbf16>
+//   return %4 : tensor<8x256x1024xbf16>
+// }
+//
+// clang-format: on
+//
+// We tile linalg.generic, but one of its inputs is %3 which is the result of
+// ttx.some_op. So the tiling engine will invoke
+// generateResultTileValue of ttx.some_op to determine if it's
+// possible to create a tiled version of it, thereby making it possible to fuse
+// both operators together in a loop.
+template <typename TritonTilingExtOpTy>
+FailureOr<TilingResult>
+generateResultTileValue(TritonTilingExtOpTy op, OpBuilder &b,
+                        unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes) {

+  // Check that the indexing map used for the output is a projected
+  // permutation. This could be relaxed with a more general approach that can
+  // map the offsets and sizes from the result to iteration space tiles
+  // (filling in full extent for dimensions not used to access the result).
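+  //
+  // E.g. (d0, d1, d2) -> (d2, d0) is a projected permutation (every result
+  // expression is a distinct dimension), while (d0, d1) -> (d0 + d1) is not
+  // and would be rejected below.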
+  AffineMap indexingMap = op.getOutputIndexingMap(b.getContext(), 0, sizes);
+  if (!indexingMap.isProjectedPermutation()) {
+    return op.emitOpError(
+        "unhandled tiled implementation generation when result is not "
+        "accessed using a permuted projection");
+  }
+
+  auto numLoops = op.getLoopIteratorTypes().size();
+  SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
+      iterationTileSizes(numLoops);
+  if (!indexingMap.isPermutation()) {
+    SmallVector<Range> iterationDomain = op.getIterationDomain(b);
+    for (auto range : llvm::enumerate(iterationDomain)) {
+      iterationTileOffsets[range.index()] = range.value().offset;
+      iterationTileSizes[range.index()] = range.value().size;
+    }
+  }
+  for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) {
+    assert(resultExpr.value().getKind() == AffineExprKind::DimId);
+    // HACK: LLVM casting utilities do not work here for out-of-tree builds,
+    // as there is no template specialization for this cast in the base
+    // build.
+    AffineDimExpr affineDimExpr(static_cast<AffineExpr::ImplType *>(
+        const_cast<void *>(resultExpr.value().getAsOpaquePointer())));
+    unsigned dimPosition = affineDimExpr.getPosition();
+    iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
+    iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
+  }
+
+  FailureOr<TilingResult> tilingResult =
+      op.getTiledImplementation(b, iterationTileOffsets, iterationTileSizes);
+  if (tilingResult->tiledOps.size() != 1)
+    return op.emitOpError("failed to generate tiled implementation");
+
+  return TilingResult{
+      tilingResult->tiledOps,
+      SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+}
+
+// This method is borrowed directly from linalg.generic's implementation
+// in mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+// This marks all operands that are part of the input group as having a read
+// effect, and all operands that are part of the output group as having both
+// read and write effects.
+static void getTritonTilingExtEffectsImpl(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects,
+    ValueRange results, ArrayRef<OpOperand *> inputOperands,
+    const MutableOperandRange &outputOperands) {
+  for (auto operand : inputOperands) {
+    if (!llvm::isa<MemRefType>(operand->get().getType()))
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+  }
+  for (auto &operand : outputOperands) {
+    if (!llvm::isa<MemRefType>(operand.get().getType()))
+      continue;
+
+    effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+  }
+}
+
+template <typename TritonTilingExtOpTy>
+void getEffects(
+    TritonTilingExtOpTy op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getTritonTilingExtEffectsImpl(effects, op.getOperation()->getResults(),
+                                op.getDpsInputOperands(),
+                                op.getDpsInitsMutable());
+}
+
+} // namespace ttx
+} // namespace mlir
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+void TritonTilingExtDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.cpp.inc" + +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.cpp.inc" + +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOpsDialect.cpp.inc" diff --git a/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/CMakeLists.txt b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/CMakeLists.txt new file mode 100644 index 000000000..58039db5f --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(Tx81IR + IR/Tx81Dialect.cpp + IR/Tx81Ops.cpp + + DEPENDS + Tx81TableGen + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp new file mode 100644 index 000000000..d819a4e17 --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Dialect.cpp @@ -0,0 +1,30 @@ +//===-------------------------- Tx81Dialect.cpp ---------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" + +using namespace mlir; +using namespace mlir::tx; + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +void Tx81Dialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tsingmicro-tx81/Dialect/IR/Tx81Enums.cpp.inc" +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.cpp.inc" + +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.cpp.inc" diff --git a/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp new file mode 100644 index 000000000..9db877dce --- /dev/null +++ b/third_party/tsingmicro/lib/Dialect/TsingMicroTx81/IR/Tx81Ops.cpp @@ -0,0 +1,10 @@ +//===-------------------------- Tx81Ops.cpp -------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h" +using namespace mlir; +using namespace mlir::tx; diff --git a/third_party/tsingmicro/name.conf b/third_party/tsingmicro/name.conf new file mode 100644 index 000000000..0271c2eb3 --- /dev/null +++ b/third_party/tsingmicro/name.conf @@ -0,0 +1 @@ +ztc \ No newline at end of file diff --git a/third_party/tsingmicro/python/triton_tsingmicro.cc b/third_party/tsingmicro/python/triton_tsingmicro.cc new file mode 100644 index 000000000..ff232d947 --- /dev/null +++ b/third_party/tsingmicro/python/triton_tsingmicro.cc @@ -0,0 +1,7 @@ +#include + +namespace py = pybind11; + +// The TsingMicro backend with ztc doesn't do compilation from within python +// but rather externally through ztc-opt, so we leave this function blank. +void init_triton_tsingmicro(py::module &&m) {} From 9925e48f6ed1c17f4e5eb7f7fe67cc3594129e98 Mon Sep 17 00:00:00 2001 From: zhzhcookie Date: Mon, 12 May 2025 10:23:15 +0800 Subject: [PATCH 02/12] [BACKEND] fix tsingmicro code format (#1) --- third_party/tsingmicro/backend/compiler.py | 124 +- third_party/tsingmicro/backend/cpu_driver.py | 100 +- third_party/tsingmicro/backend/driver.cpp | 478 +++---- third_party/tsingmicro/backend/driver.py | 103 +- third_party/tsingmicro/crt/CMakeLists.txt | 2 +- .../tsingmicro/crt/gcc_flash_xiaohui.ld | 2 +- .../crt/include/Tx81/instr_adapter.h | 24 +- .../crt/include/Tx81/instr_adapter_plat.h | 1128 +++++++++++------ .../tsingmicro/crt/include/Tx81/instr_def.h | 1035 +++++++-------- .../crt/include/Tx81/runtime/hrt_common.h | 928 +++++++------- .../crt/include/Tx81/runtime/hrt_interface.h | 74 +- third_party/tsingmicro/crt/lib/Tx81/argmax.c | 11 +- third_party/tsingmicro/crt/lib/Tx81/argmin.c | 11 +- third_party/tsingmicro/crt/lib/Tx81/arith.c | 55 +- .../tsingmicro/crt/lib/Tx81/bf16_fp16.c | 10 +- .../tsingmicro/crt/lib/Tx81/bf16_fp32.c | 8 +- .../tsingmicro/crt/lib/Tx81/bf16_int16.c | 10 +- .../tsingmicro/crt/lib/Tx81/bf16_int32.c | 10 +- .../tsingmicro/crt/lib/Tx81/bf16_int8.c | 8 +- .../tsingmicro/crt/lib/Tx81/bf16_tf32.c | 8 +- .../tsingmicro/crt/lib/Tx81/bilinear.c | 23 +- third_party/tsingmicro/crt/lib/Tx81/bit2fp.c | 12 +- third_party/tsingmicro/crt/lib/Tx81/common.c | 12 +- third_party/tsingmicro/crt/lib/Tx81/concat.c | 24 +- third_party/tsingmicro/crt/lib/Tx81/conv.c | 39 +- third_party/tsingmicro/crt/lib/Tx81/cos.c | 10 +- third_party/tsingmicro/crt/lib/Tx81/count.c | 15 +- third_party/tsingmicro/crt/lib/Tx81/exp.c | 8 +- third_party/tsingmicro/crt/lib/Tx81/explp.c | 8 +- .../tsingmicro/crt/lib/Tx81/fp16_bf16.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp16_fp32.c | 8 +- .../tsingmicro/crt/lib/Tx81/fp16_int16.c | 8 +- .../tsingmicro/crt/lib/Tx81/fp16_int32.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp16_int8.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp16_tf32.c | 8 +- .../tsingmicro/crt/lib/Tx81/fp32_bf16.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp32_fp16.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp32_int16.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp32_int32.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp32_int8.c | 10 +- .../tsingmicro/crt/lib/Tx81/fp32_tf32.c | 10 +- .../tsingmicro/crt/lib/Tx81/gatherscatter.c | 21 +- third_party/tsingmicro/crt/lib/Tx81/gemm.c | 40 +- third_party/tsingmicro/crt/lib/Tx81/img2col.c | 20 +- .../tsingmicro/crt/lib/Tx81/int16_bf16.c | 10 +- .../tsingmicro/crt/lib/Tx81/int16_fp16.c | 8 +- .../tsingmicro/crt/lib/Tx81/int16_fp32.c | 10 +- .../tsingmicro/crt/lib/Tx81/int16_tf32.c | 10 +- 
.../tsingmicro/crt/lib/Tx81/int32_bf16.c | 10 +- .../tsingmicro/crt/lib/Tx81/int32_fp16.c | 10 +- .../tsingmicro/crt/lib/Tx81/int32_fp32.c | 10 +- .../tsingmicro/crt/lib/Tx81/int32_tf32.c | 10 +- .../tsingmicro/crt/lib/Tx81/int8_bf16.c | 11 +- .../tsingmicro/crt/lib/Tx81/int8_fp16.c | 11 +- .../tsingmicro/crt/lib/Tx81/int8_fp32.c | 11 +- .../tsingmicro/crt/lib/Tx81/int8_tf32.c | 11 +- .../tsingmicro/crt/lib/Tx81/leakyrelu.c | 13 +- third_party/tsingmicro/crt/lib/Tx81/ln.c | 8 +- third_party/tsingmicro/crt/lib/Tx81/log2.c | 8 +- third_party/tsingmicro/crt/lib/Tx81/lut16.c | 14 +- third_party/tsingmicro/crt/lib/Tx81/lut32.c | 14 +- .../tsingmicro/crt/lib/Tx81/mask_move.c | 10 +- third_party/tsingmicro/crt/lib/Tx81/memset.c | 12 +- third_party/tsingmicro/crt/lib/Tx81/mirror.c | 20 +- .../tsingmicro/crt/lib/Tx81/nchw2nhwc.c | 20 +- .../tsingmicro/crt/lib/Tx81/nhwc2nchw.c | 20 +- third_party/tsingmicro/crt/lib/Tx81/pad.c | 24 +- third_party/tsingmicro/crt/lib/Tx81/pow2.c | 8 +- third_party/tsingmicro/crt/lib/Tx81/randgen.c | 13 +- third_party/tsingmicro/crt/lib/Tx81/relu.c | 8 +- .../tsingmicro/crt/lib/Tx81/rotate180.c | 20 +- .../tsingmicro/crt/lib/Tx81/rotate270.c | 20 +- .../tsingmicro/crt/lib/Tx81/rotate90.c | 19 +- third_party/tsingmicro/crt/lib/Tx81/satrelu.c | 14 +- third_party/tsingmicro/crt/lib/Tx81/sigmoid.c | 14 +- third_party/tsingmicro/crt/lib/Tx81/sin.c | 8 +- .../tsingmicro/crt/lib/Tx81/softplus.c | 14 +- third_party/tsingmicro/crt/lib/Tx81/tanh.c | 8 +- .../tsingmicro/crt/lib/Tx81/tensornorm.c | 20 +- .../tsingmicro/crt/lib/Tx81/tf32_bf16.c | 10 +- .../tsingmicro/crt/lib/Tx81/tf32_fp16.c | 8 +- .../tsingmicro/crt/lib/Tx81/tf32_fp32.c | 8 +- .../tsingmicro/crt/lib/Tx81/tf32_int16.c | 10 +- .../tsingmicro/crt/lib/Tx81/tf32_int32.c | 10 +- .../tsingmicro/crt/lib/Tx81/tf32_int8.c | 10 +- .../tsingmicro/crt/lib/Tx81/transpose.c | 20 +- .../include/ExecutionEngine/CRunnerUtils.cpp | 7 +- .../include/ExecutionEngine/CRunnerUtils.h | 51 +- .../Dialect/IR/MagicKernelFuncOps.td | 2 +- .../Dialect/IR/MagicKernelInstrOps.td | 2 +- .../Conversion/LinalgToMK/LinalgToMK.h | 7 +- .../Dialect/IR/MagicKernelAttrDefs.td | 2 +- .../Dialect/IR/MagicKernelDialect.h | 4 +- .../Dialect/IR/MagicKernelDialect.td | 2 +- .../magic-kernel/Dialect/IR/MagicKernelOps.td | 2 +- .../Dialect/IR/MagicKernelTypes.td | 2 +- .../Analysis/OpFoldResultUtils.h | 8 +- .../TritonArithToLinalg/ConversionPatterns.h | 111 +- .../Conversion/MKToTx81/MKToTx81.h | 9 +- .../Tx81MemrefToLLVM/Tx81MemrefToLLVM.h | 2 +- .../Tx81ToLLVM/KernelArgBufferPass.h | 4 +- .../Dialect/IR/Tx81AttrDefs.td | 2 +- .../tsingmicro-tx81/Dialect/IR/Tx81Dialect.h | 2 +- .../tsingmicro-tx81/Dialect/IR/Tx81Ops.h | 2 +- .../tsingmicro-tx81/Dialect/IR/Tx81Ops.td | 2 +- .../tsingmicro-tx81/Dialect/IR/Tx81Types.td | 2 +- .../tsingmicro/lib/Analysis/MaskAnalysis.cpp | 14 +- .../lib/Analysis/OpFoldResultUtils.cpp | 42 +- .../lib/AnalysisStructured/PtrAnalysis.cpp | 4 +- .../lib/Conversion/LinalgToMK/LinalgToMK.cpp | 8 +- .../Conversion/LinalgToMK/LinalgToMKPass.cpp | 32 +- .../lib/Conversion/MKToTx81/MKToTx81Pass.cpp | 5 +- .../StructuredToMemrefPass.cpp | 18 +- .../TritonArithToLinalgPass.cpp | 2 +- .../TritonToCoreDialectsPass.cpp | 2 +- .../TritonToStructuredPass.cpp | 4 +- .../Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp | 2 +- .../Tx81ToLLVM/KernelArgBufferPass.cpp | 22 +- .../lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp | 3 +- .../Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp | 28 +- .../BufferizableOpInterfaceImpl.cpp | 2 +- third_party/tsingmicro/name.conf | 2 +- 122 
files changed, 3113 insertions(+), 2304 deletions(-) diff --git a/third_party/tsingmicro/backend/compiler.py b/third_party/tsingmicro/backend/compiler.py index 919961eaf..16dbd8c9f 100644 --- a/third_party/tsingmicro/backend/compiler.py +++ b/third_party/tsingmicro/backend/compiler.py @@ -12,24 +12,28 @@ import functools from pathlib import Path + def _get_ztc_opt_path() -> str: path = os.getenv("ZTC_OPT_PATH", "") if path == "": raise Exception("ZTC_OPT_PATH is not set.") return path + def _get_vendor_runtime_path() -> str: path = os.getenv("LIB_VENDOR_RUNTIME_PATH", "") if path == "": raise Exception("LIB_VENDOR_RUNTIME_PATH is not set.") return path + def _get_llvm_bin_path(bin_name: str) -> str: path = os.getenv("LLVM_BINARY_DIR", "") if path == "": raise Exception("LLVM_BINARY_DIR is not set.") return os.path.join(path, bin_name) + # The riscv c header files and libraries path. def _get_libc_root() -> str: path = os.getenv("LIB_C_ROOT", "") @@ -37,6 +41,7 @@ def _get_libc_root() -> str: raise Exception("LIB_C_ROOT is not set.") return path + def _dump_ir_if_needed(files): path = os.getenv("ZTC_DUMP_PATH", "") if not path: @@ -56,12 +61,11 @@ def _ttir_to_coreir(mod): Path(src_path).write_text(ttir_code) ztc_opt_path = _get_ztc_opt_path() _dump_ir_if_needed([src_path]) - subprocess.check_call([ztc_opt_path, src_path, - "--triton-to-core-dialects", - "--one-shot-bufferize", + subprocess.check_call([ + ztc_opt_path, src_path, "--triton-to-core-dialects", "--one-shot-bufferize", #"--mlir-print-debuginfo", - "-o", - dst_path]) + "-o", dst_path + ]) return Path(dst_path).read_text() @@ -79,11 +83,11 @@ def _coreir_to_mkir(mod): Path(src_path).write_text(coreir_code) ztc_opt_path = _get_ztc_opt_path() _dump_ir_if_needed([src_path]) - subprocess.check_call([ztc_opt_path, src_path, - "--core-dialects-to-mk", + subprocess.check_call([ + ztc_opt_path, src_path, "--core-dialects-to-mk", #"--mlir-print-debuginfo", - "-o", - dst_path]) + "-o", dst_path + ]) return Path(dst_path).read_text() @@ -101,16 +105,16 @@ def _coreir_to_txir(mod): Path(src_path).write_text(coreir_code) ztc_opt_path = _get_ztc_opt_path() _dump_ir_if_needed([src_path]) - subprocess.check_call([ztc_opt_path, src_path, - "--expand-strided-metadata", - "--mk-to-tx81", - "--lower-affine", # convert affine.load to memref.load, need exec before tx81-to-llvm since we will support spm offset to memref.load - "--cse", # unused memref.subview/memref.reinterpret + subprocess.check_call([ + ztc_opt_path, src_path, "--expand-strided-metadata", "--mk-to-tx81", + "--lower-affine", # convert affine.load to memref.load, need exec before tx81-to-llvm since we will support spm offset to memref.load + "--cse", # unused memref.subview/memref.reinterpret #"--mlir-print-debuginfo", - "-o", - dst_path]) + "-o", dst_path + ]) return Path(dst_path).read_text() + def _optimize_txir(txir: str): # We don't apply any optimizations now, but we can add passes if needed. 
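+    # If one becomes necessary, a cleanup pipeline could be plugged in here,
+    # e.g. (hypothetical sketch, reusing the ztc-opt driver used elsewhere in
+    # this file):
+    #   subprocess.check_call([_get_ztc_opt_path(), src_path, "--canonicalize",
+    #                          "--cse", "-o", dst_path])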
return txir @@ -126,30 +130,22 @@ def _txir_to_llir(mod): ztc_opt_path = _get_ztc_opt_path() _dump_ir_if_needed([src_path]) # Tx81 and core dialects to LLVM-MLIR - subprocess.check_call([ztc_opt_path, src_path, - "--tx81-memref-to-llvm", - "--tx81-to-llvm", - "--convert-scf-to-cf", - "--convert-math-to-llvm", - "--convert-func-to-llvm", - "--convert-cf-to-llvm", + subprocess.check_call([ + ztc_opt_path, src_path, "--tx81-memref-to-llvm", "--tx81-to-llvm", "--convert-scf-to-cf", + "--convert-math-to-llvm", "--convert-func-to-llvm", "--convert-cf-to-llvm", # Use tx81-memref-to-llvm custom pass for now. # "--finalize-memref-to-llvm", - "--convert-arith-to-llvm", # need exec last since arith.const conversion + "--convert-arith-to-llvm", # need exec last since arith.const conversion # Remove all unrealized casts created - "--reconcile-unrealized-casts", - "--canonicalize", + "--reconcile-unrealized-casts", "--canonicalize", #"--mlir-print-debuginfo", - "-o", - llvmir_path]) + "-o", llvmir_path + ]) _dump_ir_if_needed([llvmir_path]) # LLVM-MLIR to LLVM-IR mlir_translate_path = _get_llvm_bin_path("mlir-translate") - subprocess.check_call([mlir_translate_path, llvmir_path, - "--mlir-to-llvmir", - "-o", - llir_path]) + subprocess.check_call([mlir_translate_path, llvmir_path, "--mlir-to-llvmir", "-o", llir_path]) _dump_ir_if_needed([llir_path]) return Path(llir_path).read_text() @@ -163,47 +159,30 @@ def _mkir_to_llir(mkir: str): Path(mkir_path).write_text(mkir) mlir_opt_path = _get_llvm_bin_path("mlir-opt") # MagicKernel-MLIR to LLVM-MLIR - subprocess.check_call([mlir_opt_path, mkir_path, - "--convert-linalg-to-affine-loops", + subprocess.check_call([ + mlir_opt_path, mkir_path, "--convert-linalg-to-affine-loops", # Note: eliminate-empty-tensors fails when there are multiple func.return ops # in a single kernel which are the results of early returns. # See python/examples/test_early_return.py for examples. # We disable this pass for now since performance on CPU isn't the main # focus at the moment. # "--eliminate-empty-tensors", - "--empty-tensor-to-alloc-tensor", - "--one-shot-bufferize=allow-return-allocs-from-loops=true", - "--lower-affine", - "--convert-linalg-to-loops", - "--expand-strided-metadata", - "--convert-scf-to-cf", - "--convert-arith-to-llvm", - "--convert-math-to-llvm", - "--convert-complex-to-llvm", - "--convert-vector-to-llvm", - "--convert-index-to-llvm", - "--memref-expand", - "--finalize-memref-to-llvm", - "--convert-func-to-llvm", - "--convert-cf-to-llvm", + "--empty-tensor-to-alloc-tensor", "--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--lower-affine", "--convert-linalg-to-loops", "--expand-strided-metadata", "--convert-scf-to-cf", + "--convert-arith-to-llvm", "--convert-math-to-llvm", "--convert-complex-to-llvm", + "--convert-vector-to-llvm", "--convert-index-to-llvm", "--memref-expand", "--finalize-memref-to-llvm", + "--convert-func-to-llvm", "--convert-cf-to-llvm", # Lowering memrefs creates more affine.apply ops. # Lowering these affine ops again creates further arith ops, # so we have to run these two passes again here. 
- "--lower-affine", - "--convert-arith-to-llvm", + "--lower-affine", "--convert-arith-to-llvm", # Remove all unrealized casts created - "--canonicalize", - "--reconcile-unrealized-casts", - "--mlir-print-debuginfo", - "-o", - llvmir_path]) + "--canonicalize", "--reconcile-unrealized-casts", "--mlir-print-debuginfo", "-o", llvmir_path + ]) # LLVM-MLIR to LLVM-IR mlir_translate_path = _get_llvm_bin_path("mlir-translate") - subprocess.check_call([mlir_translate_path, llvmir_path, - "--mlir-to-llvmir", - "-o", - llir_path]) + subprocess.check_call([mlir_translate_path, llvmir_path, "--mlir-to-llvmir", "-o", llir_path]) _dump_ir_if_needed([mkir_path, llvmir_path, llir_path]) return Path(llir_path).read_text() @@ -225,14 +204,10 @@ def _llir_to_bin(llir: str, metadata): dst_path = "/tmp/kernel.o" Path(src_path).write_text(llir) clang_path = _get_llvm_bin_path("clang++") - subprocess.check_call([clang_path, src_path, - "-O2", - "-c", - "-fPIC", - "--target=riscv64-unknown-elf", - "-march=rv64imafdc", - "-o", - dst_path]) + subprocess.check_call([ + clang_path, src_path, "-O2", "-c", "-fPIC", "--target=riscv64-unknown-elf", "-march=rv64imafdc", "-o", + dst_path + ]) _dump_ir_if_needed([dst_path]) with open(dst_path, 'rb') as f: @@ -240,7 +215,6 @@ def _llir_to_bin(llir: str, metadata): return so - @dataclass(frozen=True) class CPUOptions: debug: bool = False @@ -287,15 +261,8 @@ def pack_metadata(self, metadata): # Note: We actually don't need any of these except for the name which is # used in the launch function in driver.py. Putting these in so we're # consistent with other backends - return ( - metadata.num_warps, - metadata.num_ctas, - metadata.shared, - metadata.cluster_dims[0], - metadata.cluster_dims[1], - metadata.cluster_dims[2], - metadata.name - ) + return (metadata.num_warps, metadata.num_ctas, metadata.shared, metadata.cluster_dims[0], + metadata.cluster_dims[1], metadata.cluster_dims[2], metadata.name) # Our compilation pipeline isn't in python like nvidia or amd, no need to load # dialects. See `ztc.cc` @@ -324,7 +291,6 @@ def add_stages(self, stages, options): stages["llir"] = lambda src, metadata: _optimize_llir(_txir_to_llir(src)) stages["so"] = lambda src, metadata: _llir_to_bin(src, metadata) - @functools.lru_cache() def hash(self): return self.target diff --git a/third_party/tsingmicro/backend/cpu_driver.py b/third_party/tsingmicro/backend/cpu_driver.py index b52b6363e..15631354b 100644 --- a/third_party/tsingmicro/backend/cpu_driver.py +++ b/third_party/tsingmicro/backend/cpu_driver.py @@ -12,6 +12,7 @@ from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget + # The riscv compiler def _get_llvm_bin_path() -> str: path = os.getenv("LLVM_BINARY_DIR", "") @@ -19,6 +20,7 @@ def _get_llvm_bin_path() -> str: raise Exception("LLVM_BINARY_DIR is not set.") return path + # The riscv c header files and libraries path. 
def _get_libc_root() -> str: path = os.getenv("LIB_C_ROOT", "") @@ -49,37 +51,43 @@ def _ty_to_cpp(ty): "fp64": "double", }[ty] + def _extracted_type(ty): if ty[0] == '*': return "PyObject*" return _ty_to_cpp(ty) + def _format_of(ty): return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "l", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", }[ty] + def _generate_launcher(constants, signature, kernel_name): arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiOOOO" + args_format args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants) + kernel_arg_decls = ', '.join( + _ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants) kernel_arg_decls += ', ' if kernel_arg_decls else '' - kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants) + kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" + for i, ty in signature.items() + if i not in constants) kernel_parameters += ', ' if kernel_parameters else '' return f""" @@ -246,16 +254,14 @@ def compile_module(launcher_src, kernel_placeholder_name): libc_lib = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "lib") include_dir = os.path.join(cpu_backend_path, "include") - def launch( - gridX, gridY, gridZ, stream, cu_function, - kernel_metadata, launch_metadata, - launch_enter_hook, launch_exit_hook, *args): + def launch(gridX, gridY, gridZ, stream, cu_function, kernel_metadata, launch_metadata, launch_enter_hook, + launch_exit_hook, *args): # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries. # Let's compile a kernel every time. # The cu_function parameter actually contains our assembly source code. # See CPUUtils.load_binary method. asm_src = cu_function - kernel_name = kernel_metadata[6] # see pack_metadata in compiler.py + kernel_name = kernel_metadata[6] # see pack_metadata in compiler.py src = launcher_src.replace(kernel_placeholder_name, kernel_name) key = hashlib.md5(src.encode("utf-8") + asm_src).hexdigest() @@ -265,31 +271,27 @@ def launch( cache_path = cache.get_file(filename) if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - asm_src_path = os.path.join(tmpdir, "kernel.s") - launcher_src_path = os.path.join(tmpdir, "main.cxx") - so_path = os.path.join(tmpdir, "kernel.so") - Path(asm_src_path).write_bytes(asm_src) - Path(launcher_src_path).write_text(src) - # Compile it together. 
- subprocess.check_call([ - clang, "-std=c++17", "--target=riscv64-unknown-elf", - launcher_src_path, asm_src_path, f"-I{libc_inc}", - f"-I{py_include_dir}", f"-I{include_dir}", f"-I{libc_lib}", - f"-L{py_lib_dir}", - "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path - ]) - - with open(so_path, "rb") as f: - cache_path = cache.put(f.read(), filename, binary=True) + with tempfile.TemporaryDirectory() as tmpdir: + asm_src_path = os.path.join(tmpdir, "kernel.s") + launcher_src_path = os.path.join(tmpdir, "main.cxx") + so_path = os.path.join(tmpdir, "kernel.so") + Path(asm_src_path).write_bytes(asm_src) + Path(launcher_src_path).write_text(src) + # Compile it together. + subprocess.check_call([ + clang, "-std=c++17", "--target=riscv64-unknown-elf", launcher_src_path, asm_src_path, + f"-I{libc_inc}", f"-I{py_include_dir}", f"-I{include_dir}", f"-I{libc_lib}", f"-L{py_lib_dir}", + "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path + ]) + + with open(so_path, "rb") as f: + cache_path = cache.put(f.read(), filename, binary=True) # Load and launch the compiled kernel. spec = importlib.util.spec_from_file_location(name, cache_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) - return mod.launch(gridX, gridY, gridZ, - kernel_metadata, launch_metadata, - launch_enter_hook, launch_exit_hook, + return mod.launch(gridX, gridY, gridZ, kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, *args) return launch @@ -313,8 +315,8 @@ def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) - class CPUUtils(object): + def __new__(cls): if not hasattr(cls, "instance"): cls.instance = super(CPUUtils, cls).__new__(cls) @@ -332,11 +334,8 @@ def __new__(cls): @staticmethod def get_device_properties(device): return { - "max_shared_mem": 2 ** 20, - "multiprocessor_count": None, - "sm_clock_rate": None, - "mem_clock_rate": None, - "mem_bus_width": None + "max_shared_mem": 2**20, "multiprocessor_count": None, "sm_clock_rate": None, "mem_clock_rate": None, + "mem_bus_width": None } # Important note: @@ -345,12 +344,11 @@ def get_device_properties(device): # module every time. @staticmethod def load_binary(name, kernel_asm, shared, device): - return ( - None, # module - kernel_asm, # function - None, # n_regs - None # n_spills - ) + return (None, # module + kernel_asm, # function + None, # n_regs + None # n_spills + ) class CPUDriver(DriverBase): diff --git a/third_party/tsingmicro/backend/driver.cpp b/third_party/tsingmicro/backend/driver.cpp index 8658ecf86..242a4433b 100644 --- a/third_party/tsingmicro/backend/driver.cpp +++ b/third_party/tsingmicro/backend/driver.cpp @@ -8,30 +8,29 @@ // Tx81 platform device side runtime interface for python. // //===----------------------------------------------------------------------===// -#include -#include #include +#include +#include #include #define PY_SSIZE_T_CLEAN #include +#include #include #include -#include struct Kernel_Param // Triton kernel arguments { - uint32_t gridX; - uint32_t gridY; - uint32_t gridZ; - // TODO... + uint32_t gridX; + uint32_t gridY; + uint32_t gridZ; + // TODO... }; -struct Kernel_Head -{ - uint32_t param_type; - uint32_t param_num; - uint32_t param_addr; - uint32_t xxxxx; +struct Kernel_Head { + uint32_t param_type; + uint32_t param_num; + uint32_t param_addr; + uint32_t xxxxx; }; // Raises a Python exception and returns false if code is not RET_SUCCESS. 
@@ -43,41 +42,41 @@ static bool tsmAssert(TSM_RETCODE code, const char *file, int line) { const char *str; // Map error codes to strings - switch(code) { - case RET_ERROR: - str = "General error"; - break; - case RET_PARAM1_ERROR: - case RET_PARAM2_ERROR: - case RET_PARAM3_ERROR: - str = "Parameter error"; - break; - case RET_DEVICE_OFFLINE: - str = "Device offline"; - break; - case RET_DEVICE_NOMEM: - str = "Device out of memory"; - break; - case RET_DEVICE_IN_IDLE: - str = "Device in idle state"; - break; - case RET_DEVICE_IN_ATTACH: - str = "Device already attached"; - break; - case RET_DEVICE_ATTACH_SUCCESS: - str = "Device attach success"; - break; - case RET_DEVICE_ATTACH_READY: - str = "Device attach ready"; - break; - case RET_DEVICE_LOSE_CONNECT: - str = "Device connection lost"; - break; - case RET_ENV_CLEAN_UP: - str = "Environment cleanup required"; - break; - default: - str = "Unknown error"; + switch (code) { + case RET_ERROR: + str = "General error"; + break; + case RET_PARAM1_ERROR: + case RET_PARAM2_ERROR: + case RET_PARAM3_ERROR: + str = "Parameter error"; + break; + case RET_DEVICE_OFFLINE: + str = "Device offline"; + break; + case RET_DEVICE_NOMEM: + str = "Device out of memory"; + break; + case RET_DEVICE_IN_IDLE: + str = "Device in idle state"; + break; + case RET_DEVICE_IN_ATTACH: + str = "Device already attached"; + break; + case RET_DEVICE_ATTACH_SUCCESS: + str = "Device attach success"; + break; + case RET_DEVICE_ATTACH_READY: + str = "Device attach ready"; + break; + case RET_DEVICE_LOSE_CONNECT: + str = "Device connection lost"; + break; + case RET_ENV_CLEAN_UP: + str = "Environment cleanup required"; + break; + default: + str = "Unknown error"; } char err[1024] = {0}; @@ -90,23 +89,23 @@ static bool tsmAssert(TSM_RETCODE code, const char *file, int line) { return false; } - static void prepare_input(std::vector devices, uint32_t dev_index, - std::shared_ptr chip_info) -{ + std::shared_ptr chip_info) { for (uint32_t i = 0; i < chip_info->input_num; ++i) { chip_info->input_dev_addr.push_back(0); if (TsmDeviceMalloc(devices[dev_index], chip_info->input_dev_addr[i], - chip_info->input_size[i]) != RET_SUCCESS) { - printf("[Chip id %u] Input%d, DeviceMalloc failed!\n", devices[dev_index]->chip_id, i); + chip_info->input_size[i]) != RET_SUCCESS) { + printf("[Chip id %u] Input%d, DeviceMalloc failed!\n", + devices[dev_index]->chip_id, i); TsmResetDevice(devices[dev_index]); return; } if (TsmMemcpyH2D((TsmDevicePtr)chip_info->input_dev_addr[i], - (void*) chip_info->input_host_addr[i], + (void *)chip_info->input_host_addr[i], chip_info->input_size[i]) != RET_SUCCESS) { - printf("[Chip id %u] Input%d, MemcpyH2D failed!\n", devices[dev_index]->chip_id, i); + printf("[Chip id %u] Input%d, MemcpyH2D failed!\n", + devices[dev_index]->chip_id, i); TsmResetDevice(devices[dev_index]); return; } @@ -114,14 +113,14 @@ static void prepare_input(std::vector devices, uint32_t dev_index, } static void prepare_output(std::vector devices, uint32_t dev_index, - std::shared_ptr chip_info) { + std::shared_ptr chip_info) { for (size_t i = 0; i < chip_info->output_num; ++i) { chip_info->output_dev_addr.push_back(0); printf("[Chip id %u] output[%lu] data(size: %lu)\n", devices[dev_index]->chip_id, i, chip_info->output_size[i]); if (TsmDeviceMalloc(devices[dev_index], chip_info->output_dev_addr[i], - chip_info->output_size[i]) != RET_SUCCESS) { + chip_info->output_size[i]) != RET_SUCCESS) { printf("[Chip id %u] output[%lu], DeviceMalloc failed!\n", devices[dev_index]->chip_id, i); 
      TsmResetDevice(devices[dev_index]);
@@ -130,21 +129,25 @@ static void prepare_output(std::vector devices, uint32_t dev_index,
  }
 }
 
-TSM_RETCODE kernel_result_process(std::vector devices, uint32_t dev_index,
-                                  std::shared_ptr hostboot,
-                                  std::shared_ptr chip_info,
-                                  TsmDevicePtr bootpm_dev, std::string case_dir) {
+TSM_RETCODE kernel_result_process(std::vector devices,
+                                  uint32_t dev_index,
+                                  std::shared_ptr hostboot,
+                                  std::shared_ptr chip_info,
+                                  TsmDevicePtr bootpm_dev,
+                                  std::string case_dir) {
   for (size_t i = 0; i < chip_info->output_num; ++i) {
     // Dynamic shape: the real output size must be computed after the run
     if (TsmMemcpyD2H(hostboot->get_bootpmbuffer(), bootpm_dev,
-        hostboot->get_maxlen()) != RET_SUCCESS) {
+                     hostboot->get_maxlen()) != RET_SUCCESS) {
       return RET_ERROR;
     }
     auto out_tensor = hostboot->get_dev_output_tensor_after_run(i);
     chip_info->output[i]->dim = out_tensor->dim;
-    std::memcpy(chip_info->output[i]->shape, out_tensor->shape, sizeof(out_tensor->shape));
-    chip_info->output_size[i] = hrt_get_dtype_size((DTYPE)chip_info->output[i]->dtype);
+    std::memcpy(chip_info->output[i]->shape, out_tensor->shape,
+                sizeof(out_tensor->shape));
+    chip_info->output_size[i] =
+        hrt_get_dtype_size((DTYPE)chip_info->output[i]->dtype);
     for (uint32_t j = 0; j < out_tensor->dim; ++j) {
       if (out_tensor->shape[j] > 0) {
         chip_info->output_size[i] *= out_tensor->shape[j];
@@ -153,14 +156,14 @@ TSM_RETCODE kernel_result_process(std::vector devices, uint32_t dev
 
     TsmHostPtr output_host_addr = (TsmHostPtr)malloc(chip_info->output_size[i]);
     if (chip_info->output_size[i] > 0) {
-      if (TsmMemcpyD2H((void*)output_host_addr, chip_info->output_dev_addr[i],
-                      chip_info->output_size[i]) != RET_SUCCESS) {
+      if (TsmMemcpyD2H((void *)output_host_addr, chip_info->output_dev_addr[i],
+                       chip_info->output_size[i]) != RET_SUCCESS) {
         return RET_ERROR;
       }
     }
 
     printf("[Chip id %u] output_dev_addr=%ld\n", devices[dev_index]->chip_id,
-        chip_info->output_dev_addr[i]);
+           chip_info->output_dev_addr[i]);
 
     // TODO: Processing output
 #if 0
@@ -189,7 +192,7 @@ TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) {
 }
 
 static void setHostBoot(std::shared_ptr &chip_info,
-    std::shared_ptr &hostboot) {
+                        std::shared_ptr &hostboot) {
   if (chip_info == nullptr) {
     printf("chip_info is null.\n");
     return;
@@ -201,22 +204,24 @@ static void setHostBoot(std::shared_ptr &chip_info,
   }
 
   for (size_t i = 0; i < chip_info->input_dev_addr.size(); ++i) {
-    hostboot->set_dev_input(i, chip_info->input_dev_addr[i], chip_info->input_size[i]);
+    hostboot->set_dev_input(i, chip_info->input_dev_addr[i],
+                            chip_info->input_size[i]);
     hostboot->set_dev_input_tensor(i, chip_info->input[i]);
   }
 
   for (size_t i = 0; i < chip_info->output_dev_addr.size(); ++i) {
-    hostboot->set_dev_output(i, chip_info->output_dev_addr[i], chip_info->output_size[i]);
+    hostboot->set_dev_output(i, chip_info->output_dev_addr[i],
+                             chip_info->output_size[i]);
   }
 
   for (size_t i = 0; i < chip_info->param_num; ++i) {
-    hostboot->set_dev_param(i, chip_info->param_dev_addr[i], chip_info->param_size[i]);
+    hostboot->set_dev_param(i, chip_info->param_dev_addr[i],
+                            chip_info->param_size[i]);
   }
 
   return;
 }
 
-
 // To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
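// (tsmAssert reports the failure through the CPython error API, which may
// only be called while holding the GIL; hence the restriction above.)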
#define TSM_CHECK_AND_RETURN_NULL(ans)                                         \
  do {                                                                         \
@@ -234,27 +239,28 @@ static void setHostBoot(std::shared_ptr &chip_info,
  } while (0)
 
 // Global state for Tx81 devices
-static std::vector g_tx81_devices;
+static std::vector g_tx81_devices;
 static bool g_runtime_initialized = false;
 
 // Initialize the Tx81 runtime if not already initialized
 static bool init_tx81_runtime_if_needed() {
   if (g_runtime_initialized) {
-      return true;
+    return true;
   }
 
   // Initialize the Tx81 runtime
   if (TsmInitRuntime() != RET_SUCCESS) {
-      PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 runtime");
-      return false;
+    PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 runtime");
+    return false;
   }
 
   // Get device count
   uint32_t device_num = 0;
   if (TsmGetDeviceNum(device_num) != RET_SUCCESS || device_num == 0) {
-      PyErr_SetString(PyExc_RuntimeError, "Failed to get Tx81 device count or no devices found");
-      TsmDeInitRuntime();
-      return false;
+    PyErr_SetString(PyExc_RuntimeError,
+                    "Failed to get Tx81 device count or no devices found");
+    TsmDeInitRuntime();
+    return false;
   }
 
   // Set up devices - for simplicity, we're using a 1x1 configuration
@@ -262,10 +268,11 @@ static bool init_tx81_runtime_if_needed() {
   uint32_t card_x = 1;
   uint32_t card_y = 1;
 
-  if (TsmSetDevice(first_phy_id, card_x, card_y, g_tx81_devices) != RET_SUCCESS) {
-      PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices");
-      TsmDeInitRuntime();
-      return false;
+  if (TsmSetDevice(first_phy_id, card_x, card_y, g_tx81_devices) !=
+      RET_SUCCESS) {
+    PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices");
+    TsmDeInitRuntime();
+    return false;
   }
 
   g_runtime_initialized = true;
@@ -302,9 +309,9 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
   // Extract device properties
   // Note: We're mapping Tx81 properties to fields expected by Triton
   int max_shared_mem = 1024 * 1024 * 4; // Default 4MB
-  //int multiprocessor_count = device->tile_num;
+  // int multiprocessor_count = device->tile_num;
   int multiprocessor_count = 1;
-  int sm_clock_rate = 1000;  // Placeholder
+  int sm_clock_rate = 1000; // Placeholder
   int mem_clock_rate = 2000; // Placeholder
   int mem_bus_width = 256;   // Placeholder
 
@@ -316,12 +323,11 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
   }
 #endif
 
-  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}",
-                       "max_shared_mem", max_shared_mem,
-                       "multiprocessor_count", multiprocessor_count,
-                       "sm_clock_rate", sm_clock_rate,
-                       "mem_clock_rate", mem_clock_rate,
-                       "mem_bus_width", mem_bus_width);
+  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
+                       max_shared_mem, "multiprocessor_count",
+                       multiprocessor_count, "sm_clock_rate", sm_clock_rate,
+                       "mem_clock_rate", mem_clock_rate, "mem_bus_width",
+                       mem_bus_width);
 }
 
 static PyObject *loadBinary(PyObject *self, PyObject *args) {
@@ -408,41 +414,41 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   int32_t n_regs = 256;
   int32_t n_spills = 0;
   // Return values to Python including module, function, n_regs, n_spills
-  return Py_BuildValue("(KKii)", "module {}", "void @add_kernel() {}", n_regs, n_spills);
+  return Py_BuildValue("(KKii)", "module {}", "void @add_kernel() {}", n_regs,
+                       n_spills);
 }
 
-
-
-static PyObject *launch(PyObject *self, PyObject* args) {
+static PyObject *launch(PyObject *self, PyObject *args) {
   std::vector devices; // TODO: obtain the device info from the call arguments
-
-  // Required inputs: devices, case_dir (kernel .so stored at a fixed path), input_host_addr/input_size/input_num,
+  // Required inputs: devices, case_dir (kernel .so stored at a fixed path),
+  // input_host_addr/input_size/input_num,
   // output_host_addr/output_size/output_num, param info (if there are weights)
-
-  TsmModel *new_model = new TsmModel();  // device-related parameters already live in dev
+  TsmModel *new_model = new TsmModel(); // device-related parameters already live in dev
   std::string option = "-O2";
   CompileOption compl_option = {};
   compl_option.comp_enable = 0;
-  compl_option.chip_x = 1;  // single card
+  compl_option.chip_x = 1; // single card
   compl_option.chip_y = 1;
   compl_option.check_enable = true;
   compl_option.enable_kcore_bin = 1;
   compl_option.enable_kcore_so = 1;
-  new_model->case_dir = "/tmp/todo"; // passed in as an argument: kernel .so path, same folder as streambin/kcorebin
-
-  if (TsmCompileMultiGraph(devices, *new_model, option, compl_option) != RET_SUCCESS) {
-    for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {
-      if (TsmResetDevice(devices[dev_index]) != RET_SUCCESS) {
-        printf("[Chip id %u] tx_engine: tx_reset, failed!\n", dev_index);
-      } else {
-        printf("[Chip id %u] tx_engine: tx_reset, success!\n", dev_index);
-      }
+  new_model->case_dir =
+      "/tmp/todo"; // passed in as an argument: kernel .so path, same folder as streambin/kcorebin
+
+  if (TsmCompileMultiGraph(devices, *new_model, option, compl_option) !=
+      RET_SUCCESS) {
+    for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {
+      if (TsmResetDevice(devices[dev_index]) != RET_SUCCESS) {
+        printf("[Chip id %u] tx_engine: tx_reset, failed!\n", dev_index);
+      } else {
+        printf("[Chip id %u] tx_engine: tx_reset, success!\n", dev_index);
      }
-    printf("TsmCompile failed.\n");
-    return NULL;
+    }
+    printf("TsmCompile failed.\n");
+    return NULL;
   }
 
   std::vector kmodel_vec = {new_model};
@@ -450,139 +456,155 @@ static PyObject *launch(PyObject *self, PyObject* args) {
   uint32_t input_num = 2;  // TODO: fill in from the kernel arguments
   uint32_t output_num = 1; // TODO: fill in from the kernel arguments
   uint32_t param_num = 0;  // number of weights
-  std::shared_ptr hostboot = std::make_shared(input_num, output_num, param_num);
+  std::shared_ptr hostboot =
+      std::make_shared(input_num, output_num, param_num);
   std::shared_ptr chip_info;
 
   // Populate the chip_info fields
   chip_info->input_num = input_num;
   chip_info->output_num = output_num;
   chip_info->param_num = param_num;
-  chip_info->imm_size = 0; // scratch size set to 0 for now; depends on the actual operator
+  chip_info->imm_size = 0; // scratch size set to 0 for now; depends on the actual operator
   // chip_info->tile_num = 16; // unused
   // chip_info->tile_x = 4;    // unused
   // chip_info->tile_y = 4;    // unused
 
-  for(uint32_t i = 0; i < chip_info->input_num; ++i) {
-    chip_info->input_size[i] = 6;  // TODO: fill in the actual input size
-    chip_info->input_host_addr = std::vector{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; // TODO: fill in the actual input addresses
+  for (uint32_t i = 0; i < chip_info->input_num; ++i) {
+    chip_info->input_size[i] = 6; // TODO: fill in the actual input size
+    chip_info->input_host_addr = std::vector{
+        0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; // TODO: fill in the actual input addresses
  }
-  for(uint32_t i = 0; i < chip_info->output_num; ++i) {
-    chip_info->output_size[i] = 1;  // TODO: fill in the actual output size
-    chip_info->output_host_addr = std::vector{0x0}; // TODO: fill in the actual output addresses
+  for (uint32_t i = 0; i < chip_info->output_num; ++i) {
+    chip_info->output_size[i] = 1; // TODO: fill in the actual output size
+    chip_info->output_host_addr =
+        std::vector{0x0}; // TODO: fill in the actual output addresses
  }
-  //for(uint32_t i = 0; i < chip_info->param_num; ++i) {
-  //  chip_info->param_size[i] = 0;  // TODO: fill in the actual weight size
-  //  chip_info->param_host_addr = 0x0;
-  //}
+  // for(uint32_t i = 0; i < chip_info->param_num; ++i) {
+  //   chip_info->param_size[i] = 0; // TODO: fill in the actual weight size
+  //   chip_info->param_host_addr = 0x0;
+  // }
 
   // prepare data / load kernel / run / unload kernel / get output data / release memory
   for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {
-    // input prepare
-    prepare_input(devices, dev_index, chip_info);
-    // output prepare
-    prepare_output(devices, dev_index, chip_info);
-
-    uint32_t chip_id = devices[dev_index]->chip_id;
-    TsmSetMonitorInfo(devices[dev_index]);
-
-    // load kernel
-    char module_symbol[] = "main_kernel";
-    TsmLoadKernel(devices[dev_index], kmodel_vec, module_symbol);
-    printf("TsmLoadKernel finish!...\n");
-
-    printf("[Chip id %u] Set boot-params...\n", chip_id);
-    size_t dyn_mod_size = sizeof(DynMods) + sizeof(DynModule);
-    TsmDevicePtr dev_dyn_mods_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_dyn_mods_ptr, dyn_mod_size) != RET_SUCCESS) {
-      return NULL;
-    }
-    TsmDevicePtr dev_tlv_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_tlv_ptr, sizeof(DynTLV_DynMods)) != RET_SUCCESS) {
-      return NULL;
-    }
-
-    TsmDevicePtr dev_kernel_head_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_head_ptr, sizeof(Kernel_Head)) != RET_SUCCESS) {
-      return NULL;
-    }
-    TsmDevicePtr dev_kernel_param_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_param_ptr, sizeof(Kernel_Param)) != RET_SUCCESS) {
-      return NULL;
-    }
-
-    Kernel_Head *host_kernel_head_ptr = (Kernel_Head*)malloc(sizeof(Kernel_Head));
-    Kernel_Param *host_kernel_param_ptr = (Kernel_Param*)malloc(sizeof(Kernel_Param));
-
-    host_kernel_head_ptr->param_type = 1;
-    host_kernel_head_ptr->param_num = 1;  // Number of kernel arguments
-    host_kernel_head_ptr->param_addr = dev_kernel_param_ptr; // store the device address of the kernel parameter block
-
-    // TODO: Setup the triton kernel arguments
-    host_kernel_param_ptr->gridX = 512;
-    host_kernel_param_ptr->gridY = 512;
-    host_kernel_param_ptr->gridZ = 512;
-
-    TsmMemcpyH2D(dev_kernel_head_ptr, host_kernel_head_ptr, sizeof(Kernel_Head));
-    TsmMemcpyH2D(dev_kernel_param_ptr, host_kernel_param_ptr, sizeof(Kernel_Param));
-
-    free(host_kernel_head_ptr);
-    free(host_kernel_param_ptr);
+    // input prepare
+    prepare_input(devices, dev_index, chip_info);
+    // output prepare
+    prepare_output(devices, dev_index, chip_info);
+
+    uint32_t chip_id = devices[dev_index]->chip_id;
+    TsmSetMonitorInfo(devices[dev_index]);
+
+    // load kernel
+    char module_symbol[] = "main_kernel";
+    TsmLoadKernel(devices[dev_index], kmodel_vec, module_symbol);
+    printf("TsmLoadKernel finish!...\n");
+
+    printf("[Chip id %u] Set boot-params...\n", chip_id);
+    size_t dyn_mod_size = sizeof(DynMods) + sizeof(DynModule);
+    TsmDevicePtr dev_dyn_mods_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_dyn_mods_ptr, dyn_mod_size) !=
+        RET_SUCCESS) {
+      return NULL;
+    }
+    TsmDevicePtr dev_tlv_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_tlv_ptr,
+                        sizeof(DynTLV_DynMods)) != RET_SUCCESS) {
+      return NULL;
    }
 
-    // TODO: No such API
-    setHostBoot(chip_info, hostboot);
-    set_multi_graph(kmodel_vec[0], hostboot, dev_dyn_mods_ptr, dev_tlv_ptr, dev_kernel_head_ptr);
+    TsmDevicePtr dev_kernel_head_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_head_ptr,
+                        sizeof(Kernel_Head)) != RET_SUCCESS) {
+      return NULL;
+    }
+    TsmDevicePtr dev_kernel_param_ptr;
+    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_param_ptr,
+                        sizeof(Kernel_Param)) != RET_SUCCESS) {
+      return NULL;
+    }
 
-    TsmDevicePtr bootpm_dev;
-    if (TsmDeviceMalloc(devices[dev_index], bootpm_dev, hostboot->get_maxlen()) != RET_SUCCESS) {
-      return NULL;
-    }
-    if (TsmMemcpyH2D(bootpm_dev, hostboot->get_bootpmbuffer(), hostboot->get_maxlen()) != RET_SUCCESS) {
-      return NULL;
-    }
+    Kernel_Head *host_kernel_head_ptr =
+        (Kernel_Head *)malloc(sizeof(Kernel_Head));
+    Kernel_Param *host_kernel_param_ptr =
+        (Kernel_Param *)malloc(sizeof(Kernel_Param));
+
+    host_kernel_head_ptr->param_type = 1;
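+    // Kernel_Head describes the parameter block the device will read:
+    // param_num counts the kernel arguments, and param_addr (set below)
+    // receives the *device* address of the Kernel_Param buffer allocated
+    // above (see the struct definitions at the top of this file).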
+    host_kernel_head_ptr->param_num = 1; // Number of kernel arguments
+    host_kernel_head_ptr->param_addr =
+        dev_kernel_param_ptr; // store the device address of the kernel
+                              // parameter block
+
+    // TODO: Setup the triton kernel arguments
+    host_kernel_param_ptr->gridX = 512;
+    host_kernel_param_ptr->gridY = 512;
+    host_kernel_param_ptr->gridZ = 512;
+
+    TsmMemcpyH2D(dev_kernel_head_ptr, host_kernel_head_ptr,
+                 sizeof(Kernel_Head));
+    TsmMemcpyH2D(dev_kernel_param_ptr, host_kernel_param_ptr,
+                 sizeof(Kernel_Param));
+
+    free(host_kernel_head_ptr);
+    free(host_kernel_param_ptr);
+
+    // TODO: No such API
+    setHostBoot(chip_info, hostboot);
+    set_multi_graph(kmodel_vec[0], hostboot, dev_dyn_mods_ptr, dev_tlv_ptr,
+                    dev_kernel_head_ptr);
+
+    TsmDevicePtr bootpm_dev;
+    if (TsmDeviceMalloc(devices[dev_index], bootpm_dev,
+                        hostboot->get_maxlen()) != RET_SUCCESS) {
+      return NULL;
+    }
+    if (TsmMemcpyH2D(bootpm_dev, hostboot->get_bootpmbuffer(),
+                     hostboot->get_maxlen()) != RET_SUCCESS) {
+      return NULL;
+    }
 
-    if (TsmRun(devices[dev_index], bootpm_dev) != RET_SUCCESS) {
-      printf("TsmRun bootpm_dev failed.\n");
-      return NULL;
-    }
+    if (TsmRun(devices[dev_index], bootpm_dev) != RET_SUCCESS) {
+      printf("TsmRun bootpm_dev failed.\n");
+      return NULL;
+    }
 
-    // Unload the kernel
-    TsmUnloadKernel(devices[dev_index], kmodel_vec);
+    // Unload the kernel
+    TsmUnloadKernel(devices[dev_index], kmodel_vec);
 
-    // Fetch the output data and process it
-    printf("[Chip id %u] Copy output from device...\n", chip_id);
-    if (kernel_result_process(devices, dev_index, hostboot, chip_info, bootpm_dev, new_model->case_dir) != RET_SUCCESS) {
-      printf("free dev memory failed.\n");
-      return NULL;
-    }
-    if (freeMemPerStep(chip_id, bootpm_dev) != RET_SUCCESS) {
-      printf("free dev memory failed.\n");
-      return NULL;
-    }
-    // Free the multi-graph TLV buffers
-    if (TsmDeviceFree(dev_kernel_head_ptr) != RET_SUCCESS) {
-      printf("free dev_kernel_head_ptr failed.\n");
-      return NULL;
-    }
-    if (TsmDeviceFree(dev_kernel_param_ptr) != RET_SUCCESS) {
-      printf("free dev_kernel_param_ptr failed.\n");
-      return NULL;
-    }
+    // Fetch the output data and process it
+    printf("[Chip id %u] Copy output from device...\n", chip_id);
+    if (kernel_result_process(devices, dev_index, hostboot, chip_info,
+                              bootpm_dev, new_model->case_dir) != RET_SUCCESS) {
+      printf("free dev memory failed.\n");
+      return NULL;
+    }
+    if (freeMemPerStep(chip_id, bootpm_dev) != RET_SUCCESS) {
+      printf("free dev memory failed.\n");
+      return NULL;
+    }
+    // Free the multi-graph TLV buffers
+    if (TsmDeviceFree(dev_kernel_head_ptr) != RET_SUCCESS) {
+      printf("free dev_kernel_head_ptr failed.\n");
+      return NULL;
+    }
+    if (TsmDeviceFree(dev_kernel_param_ptr) != RET_SUCCESS) {
+      printf("free dev_kernel_param_ptr failed.\n");
+      return NULL;
+    }
 
-    if (TsmDeviceFree(dev_dyn_mods_ptr) != RET_SUCCESS) {
-      printf("free dev_dyn_mods_ptr failed.\n");
-      return NULL;
-    }
-    if (TsmDeviceFree(dev_tlv_ptr) != RET_SUCCESS) {
-      printf("free dev_tlv_ptr failed.\n");
-      return NULL;
-    }
+    if (TsmDeviceFree(dev_dyn_mods_ptr) != RET_SUCCESS) {
+      printf("free dev_dyn_mods_ptr failed.\n");
+      return NULL;
+    }
+    if (TsmDeviceFree(dev_tlv_ptr) != RET_SUCCESS) {
+      printf("free dev_tlv_ptr failed.\n");
+      return NULL;
+    }
 
-    printf("[dev_index %u] Set Terminal Info...\n", dev_index);
-    if (TsmSetTerminate(devices[dev_index]) != RET_SUCCESS) {
-      printf("TsmSetTerminate failed.\n");
-      return NULL;
-    }
+    printf("[dev_index %u] Set Terminal Info...\n", dev_index);
+    if (TsmSetTerminate(devices[dev_index]) != RET_SUCCESS) {
+      printf("TsmSetTerminate failed.\n");
+      return NULL;
+    }
 #if 0
     if (freeTensorData(chip_id, chip_info) != RET_SUCCESS) {
      printf("free tensor data dev memory failed.\n");
@@ -593,7 +615,6 @@ static PyObject *launch(PyObject *self, PyObject* args) {
   Py_RETURN_NONE;
 }
 
-
 static PyMethodDef ModuleMethods[] = {
     {"load_binary", loadBinary, METH_VARARGS,
      "Load provided binary into Tx81 driver"},
@@ -604,13 +625,10 @@ static PyMethodDef ModuleMethods[] = {
     {NULL, NULL, 0, NULL} // sentinel
 };
 
-static struct PyModuleDef ModuleDef = {
-    PyModuleDef_HEAD_INIT,
-    "tx81_utils",
-    NULL, // documentation
-    -1,   // size
-    ModuleMethods
-};
+static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "tx81_utils",
+                                       NULL, // documentation
+                                       -1,   // size
+                                       ModuleMethods};
 
 PyMODINIT_FUNC PyInit_tx81_utils(void) {
   PyObject *m = PyModule_Create(&ModuleDef);
@@ -621,4 +639,4 @@ PyMODINIT_FUNC PyInit_tx81_utils(void) {
 
   PyModule_AddFunctions(m, ModuleMethods);
   return m;
-}
\ No newline at end of file
+}
diff --git a/third_party/tsingmicro/backend/driver.py b/third_party/tsingmicro/backend/driver.py
index bafa51d28..071b4a1b0 100644
--- a/third_party/tsingmicro/backend/driver.py
+++ b/third_party/tsingmicro/backend/driver.py
@@ -18,15 +18,17 @@ from triton.backends.compiler import GPUTarget
 
 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dirs = [os.path.join(dirname, "include"),
-                os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"),
-                os.path.join(sysconfig.get_path('platlib'), "torch", "include"),
-                os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"),
-                os.path.join(sysconfig.get_path('platlib'), "numpy", "_core", "include")]
-library_dirs = [os.path.join(dirname, "lib"),
-                os.path.join(sysconfig.get_path('platlib'), "torch", "lib")]
+include_dirs = [
+    os.path.join(dirname, "include"),
+    os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"),
+    os.path.join(sysconfig.get_path('platlib'), "torch", "include"),
+    os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"),
+    os.path.join(sysconfig.get_path('platlib'), "numpy", "_core", "include")
+]
+library_dirs = [os.path.join(dirname, "lib"), os.path.join(sysconfig.get_path('platlib'), "torch", "lib")]
 libraries = ['tx8_runtime', 'torch', 'torch_cpu', 'torch_python', 'c10']
 
+
 # Path configuration for cross compilation
 def _get_llvm_bin_path(bin_name: str) -> str:
     path = os.getenv("LLVM_BINARY_DIR", "")
@@ -34,18 +36,21 @@ def _get_llvm_bin_path(bin_name: str) -> str:
         raise Exception("LLVM_BINARY_DIR is not set.")
     return os.path.join(path, bin_name)
 
+
 def _get_libc_root() -> str:
     path = os.getenv("LIB_C_ROOT", "")
     if path == "":
         raise Exception("LIB_C_ROOT is not set.")
     return path
 
+
 def _get_vendor_runtime_path() -> str:
     path = os.getenv("LIB_VENDOR_RUNTIME_PATH", "")
     if path == "":
         raise Exception("LIB_VENDOR_RUNTIME_PATH is not set.")
     return path
 
+
 def _dump_ir_if_needed(files):
     path = os.getenv("ZTC_DUMP_PATH", "")
     if not path:
@@ -55,6 +60,7 @@ def _dump_ir_if_needed(files):
     for f in files:
         shutil.copy(f, os.path.join(path, os.path.basename(f)))
 
+
 # Build a native ELF on the platform running this Python script
 def compile_native(src, name):
     fname = "native_" + name
@@ -77,6 +83,7 @@ def compile_native(src, name):
     spec.loader.exec_module(mod)
     return mod
 
+
 # Build an accelerator controller ELF
 def compile_accelerator(src, name, ext):
     name = "npu_" + name
@@ -95,56 +102,40 @@ def compile_accelerator(src, name, ext):
     _dump_ir_if_needed([src_path])
     clang_path = _get_llvm_bin_path("clang")
     # Compile
-    subprocess.check_call([clang_path, src_path,
-                           "-O2",
-                           "-c",
-                           "-fPIC",
f"-I{libc_inc}", - "--target=riscv64-unknown-elf", - "-march=rv64imafdc", - "-o", - dst_path]) + subprocess.check_call([ + clang_path, src_path, "-O2", "-c", "-fPIC", f"-I{libc_inc}", "--target=riscv64-unknown-elf", + "-march=rv64imafdc", "-o", dst_path + ]) with tempfile.TemporaryDirectory() as tmpdir: # FIXME: Hardcoded path #dst_path = os.path.join(tmpdir, f"{name}.so") dst_path = "/tmp/kernel.so" libc_lib = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "lib", "rv64imafdc", "lp64d") - libcrt_lib = os.path.join(_get_libc_root(), "lib", "gcc", "riscv64-unknown-elf", "15.0.0", "rv64imafdc", "lp64d") + libcrt_lib = os.path.join(_get_libc_root(), "lib", "gcc", "riscv64-unknown-elf", "15.0.0", "rv64imafdc", + "lp64d") libvr_path = _get_vendor_runtime_path() clang_path = _get_llvm_bin_path("clang") # Link wrapper, kernel with Tx81 crt and intrinsics(libkcorert.a) - subprocess.check_call([clang_path, - "-nostdlib", - # FIXME: Hardcoded path - "/tmp/wrapper.o", - "/tmp/kernel.o", - "-O2", - "--target=riscv64-unknown-elf", - "-march=rv64imafdc", - "-fPIC", + subprocess.check_call([ + clang_path, "-nostdlib", + # FIXME: Hardcoded path + "/tmp/wrapper.o", "/tmp/kernel.o", "-O2", "--target=riscv64-unknown-elf", "-march=rv64imafdc", "-fPIC", # "-shared", # ELF toolchain doesn't support -shared - f"-L{libvr_path}", - f"-L{libc_lib}", - f"-L{libcrt_lib}", + f"-L{libvr_path}", f"-L{libc_lib}", f"-L{libcrt_lib}", # Allow libkcorert symbol overwrite libc symbols, libkcorert # should be specified before libc - "-Wl,--allow-multiple-definition", - "-lvr", # Wrapper API of Tx81 intrinsic + "-Wl,--allow-multiple-definition", "-lvr", # Wrapper API of Tx81 intrinsic "-lkcorert", # Tx81 intrinsic API - "-lc", - "-lm", - "-lgcc", - "-T", - f"{libvr_path}/gcc_tx8_smarth.ld", - "-o", - dst_path]) + "-lc", "-lm", "-lgcc", "-T", f"{libvr_path}/gcc_tx8_smarth.ld", "-o", dst_path + ]) _dump_ir_if_needed([dst_path]) with open(dst_path, 'rb') as f: so = f.read() return so + # -------------------- Launcher ---------------------------- def _ty_to_cpp(ty): if ty[0] == '*': @@ -167,27 +158,30 @@ def _ty_to_cpp(ty): "fp64": "double", }[ty] + def _extracted_type(ty): if ty[0] == '*': return "PyObject*" return _ty_to_cpp(ty) + def _format_of(ty): return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "l", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", }[ty] + # This function makes a single kernel invoker which wraps all the input args into # a single input buffer. 
def make_kernel_wrapper_v2(constants, signature, kernel_name): @@ -328,6 +322,7 @@ def make_kernel_wrapper(constants, signature, kernel_name): }} """ + def make_launcher(constants, signature, kernel_name): # Basic declarations arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) @@ -336,7 +331,10 @@ def make_launcher(constants, signature, kernel_name): args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' # Parameters to pass to the kernel function - kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"tx81_ptr{i}, &ptr_arg{i}" for i, ty in signature.items() if i not in constants) + kernel_parameters = ', '.join( + f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"tx81_ptr{i}, &ptr_arg{i}" + for i, ty in signature.items() + if i not in constants) kernel_parameters += ', ' if kernel_parameters else '' return f""" @@ -917,6 +915,7 @@ def make_launcher(constants, signature, kernel_name): }} """ + class CrossUtils(object): def __new__(cls): @@ -931,6 +930,7 @@ def __init__(self): self.load_binary = mod.load_binary self.get_device_properties = mod.get_device_properties + # Launch cross compiled runtime program on controller class CrossLauncher(object): @@ -952,7 +952,6 @@ def __init__(self, src, metadata): mod = compile_native(launcher_src, "__triton_launcher") self.launch = mod.launch - def __call__(self, *args, **kwargs): # args: 0: gridX, 1: gridY, 2: gridZ, # 3: kernel_metadata?, 4: launch_metadata?, diff --git a/third_party/tsingmicro/crt/CMakeLists.txt b/third_party/tsingmicro/crt/CMakeLists.txt index 9470c08db..fa4c0bbb8 100644 --- a/third_party/tsingmicro/crt/CMakeLists.txt +++ b/third_party/tsingmicro/crt/CMakeLists.txt @@ -110,4 +110,4 @@ install(TARGETS ${VENDOR_RUNTIME_LIB} # Install headers (optional) file(GLOB_RECURSE VENDOR_HEADERS Target/lib/${TARGET}/*.h) install(FILES ${VENDOR_HEADERS} DESTINATION include/${TARGET}) -install(FILES Target/lib/${TARGET}/libkcorert.a DESTINATION lib/${TARGET}) \ No newline at end of file +install(FILES Target/lib/${TARGET}/libkcorert.a DESTINATION lib/${TARGET}) diff --git a/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld b/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld index 70943de73..db5c48fc0 100644 --- a/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld +++ b/third_party/tsingmicro/crt/gcc_flash_xiaohui.ld @@ -238,7 +238,7 @@ SECTIONS __bss_end__ = .; __end = . ; end = . ; - } > REGION_BSS + } > REGION_BSS ._user_heap (NOLOAD): { . 
= ALIGN(0x8) ;
 	*(.stack*)
diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h b/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h
index adf93a8e6..306624447 100644
--- a/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h
+++ b/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h
@@ -1,23 +1,24 @@
 #ifndef INSTR_ADAPTER_H
 #define INSTR_ADAPTER_H
 #include
+#include
 #include
 #include
-#include
-//#include "common_base.h"
-#include "instr_def.h"
+// #include "common_base.h"
 #include "instr_adapter_plat.h"
+#include "instr_def.h"
 
 #ifndef USING_RISCV
 #define __CHECK_INSTR__
 #endif
-//#define __PLAT_FREERTOS__
-// #define RECORD_INSTR_INVALID
+// #define __PLAT_FREERTOS__
+// #define RECORD_INSTR_INVALID
 
 #define SPM_LOWER_BOUND 0
 #define SPM_UPPER_BOUND 0x2EFFFF
 #define DDR_LOWER_BOUND 0x280000000
-#define IS_WITHIN_SPM_BOUND(value) (((value) >= SPM_LOWER_BOUND) && ((value) <= SPM_UPPER_BOUND))
+#define IS_WITHIN_SPM_BOUND(value)                                             \
+  (((value) >= SPM_LOWER_BOUND) && ((value) <= SPM_UPPER_BOUND))
 #define IS_WITHIN_DDR_BOUND(value) ((value) >= DDR_LOWER_BOUND)
 // Set times (bits 0-7)
 #define TIMES_INVALID_OFFET 0
@@ -27,11 +28,11 @@
 #define FIRST_INVALID_BARRIER 36
 
 typedef struct InstrInvalidInfo {
-    volatile uint64_t ne_error_info;
-    volatile uint64_t ct_error_info;
-    volatile uint64_t td_error_info;
-    volatile uint64_t rdma_error_info;
-    volatile uint64_t wdma_error_info;
+  volatile uint64_t ne_error_info;
+  volatile uint64_t ct_error_info;
+  volatile uint64_t td_error_info;
+  volatile uint64_t rdma_error_info;
+  volatile uint64_t wdma_error_info;
 } InstrInvalidInfo;
 
 /*
@@ -53,7 +54,6 @@ uint32_t __execute_wdma(TsmWdmaInstr *instr);
 void __execute_sc(SC_Param *instr);
 uint64_t TsmExecute(void *instr);
 
-
 /*=================================debug=================================*/
 void set_device_ddr_base(uint64_t base);
 uint64_t get_device_ddr_base();
diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h b/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h
index 79f8b0165..f01828512 100644
--- a/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h
+++ b/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h
@@ -3,453 +3,744 @@
 
 // You should define something according to your device type
 
-// ==================== if you run in Tx8-simulator =====================================================
+// ==================== if you run in Tx8-simulator
+// =====================================================
 
 #include
-
-//#include "oplib_depend_api.h"
+// #include "oplib_depend_api.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
 typedef struct Data_Shape {
-    uint16_t n;
-    uint16_t h;
-    uint16_t w;
-    uint16_t c;
+  uint16_t n;
+  uint16_t h;
+  uint16_t w;
+  uint16_t c;
 } Data_Shape;
 
 typedef struct St_Elem_Shape {
-    uint32_t elem_count;
-    uint32_t unit_elem_count;
-    uint32_t full_elem_count;
-    uint32_t full_unit_elem_count;
+  uint32_t elem_count;
+  uint32_t unit_elem_count;
+  uint32_t full_elem_count;
+  uint32_t full_unit_elem_count;
 } St_Elem_Shape;
 
 typedef struct St_StrideIteration {
-    uint32_t stride0;
-    uint32_t iteration0;
-    uint32_t stride1;
-    uint32_t iteration1;
-    uint32_t stride2;
-    uint32_t iteration2;
+  uint32_t stride0;
+  uint32_t iteration0;
+  uint32_t stride1;
+  uint32_t iteration1;
+  uint32_t stride2;
+  uint32_t iteration2;
 } St_StrideIteration;
 
 /*=================================C class=================================*/
 typedef struct TsmConv {
-    void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape, Data_Format
fmt); - void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape, Data_Format fmt); - void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr); - void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape, Data_Format fmt); - void (*SetOpType)(TsmNeInstr *instr, uint8_t type); - void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //- negative axis - void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //+ positive axis - void (*SetSparse)(TsmNeInstr *instr, uint8_t sparse_en, uint64_t sparse_addr); - void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, Data_Format fmt); - void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); - void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); - void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, uint32_t Sx, uint32_t Sy); - void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); - void (*EnableRelu)(TsmNeInstr *instr); - void (*EnableLeakyRelu)(TsmNeInstr *instr); - void (*DisableRelu)(TsmNeInstr *instr); - void (*DisableLeakyRelu)(TsmNeInstr *instr); - void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, uint8_t zp_cur); - /* data */ + void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape, + Data_Format fmt); + void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape, + Data_Format fmt); + void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr); + void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape, + Data_Format fmt); + void (*SetOpType)(TsmNeInstr *instr, uint8_t type); + void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, + uint64_t scale_addr); //- negative axis + void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, + uint64_t scale_addr); //+ positive axis + void (*SetSparse)(TsmNeInstr *instr, uint8_t sparse_en, uint64_t sparse_addr); + void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, + Data_Format fmt); + void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, + uint32_t left, uint32_t right); + void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, + uint32_t left, uint32_t right); + void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, + uint32_t Sx, uint32_t Sy); + void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); + void (*EnableRelu)(TsmNeInstr *instr); + void (*EnableLeakyRelu)(TsmNeInstr *instr); + void (*DisableRelu)(TsmNeInstr *instr); + void (*DisableLeakyRelu)(TsmNeInstr *instr); + void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, + uint8_t zp_cur); + /* data */ } TsmConv; typedef struct TsmDepthwiseConv { - void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape, Data_Format fmt); - void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape, Data_Format fmt); - void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr); - void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape, Data_Format fmt); - void (*SetOpType)(TsmNeInstr *instr, uint8_t type); - void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //- negative axis - void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t scale_addr); //+ positive axis - void (*SetSparse)(TsmNeInstr *instr, uint8_t 
sparse_en, uint64_t sparse_addr); - void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); - void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, uint32_t left, uint32_t right); - void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, uint32_t Sx, uint32_t Sy); - void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); - void (*EnableRelu)(TsmNeInstr *instr); - void (*EnableLeakyRelu)(TsmNeInstr *instr); - void (*DisableRelu)(TsmNeInstr *instr); - void (*DisableLeakyRelu)(TsmNeInstr *instr); - void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, uint8_t zp_cur); - /* data */ + void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape, + Data_Format fmt); + void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape, + Data_Format fmt); + void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr); + void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape, + Data_Format fmt); + void (*SetOpType)(TsmNeInstr *instr, uint8_t type); + void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, + uint64_t scale_addr); //- negative axis + void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, + uint64_t scale_addr); //+ positive axis + void (*SetSparse)(TsmNeInstr *instr, uint8_t sparse_en, uint64_t sparse_addr); + void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, + uint32_t left, uint32_t right); + void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, + uint32_t left, uint32_t right); + void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, + uint32_t Sx, uint32_t Sy); + void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); + void (*EnableRelu)(TsmNeInstr *instr); + void (*EnableLeakyRelu)(TsmNeInstr *instr); + void (*DisableRelu)(TsmNeInstr *instr); + void (*DisableLeakyRelu)(TsmNeInstr *instr); + void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, + uint8_t zp_cur); + /* data */ } TsmDepthwiseConv; typedef struct TsmGemm { - void (*AddInput)(TsmNeInstr *instr, uint64_t L_addr, uint64_t R_addr, Data_Format in_fmt); - void (*ConfigMKN)(TsmNeInstr *instr, uint32_t M, uint32_t K, uint32_t N); - void (*ConfigBatch)(TsmNeInstr *instr, uint32_t Left_batch, uint32_t Right_batch); - void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Format Out_fmt); - void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, Data_Format fmt); - void (*SetTransflag)(TsmNeInstr *instr, uint8_t L_trans, uint8_t R_trans); - void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_left, uint8_t zp_right); - void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t addr); - void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t addr); - void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, uint64_t addr); - void (*EnableRelu)(TsmNeInstr *instr); - void (*EnableLeakyRelu)(TsmNeInstr *instr); - void (*DisableRelu)(TsmNeInstr *instr); - void (*DisableLeakyRelu)(TsmNeInstr *instr); - - /* data */ + void (*AddInput)(TsmNeInstr *instr, uint64_t L_addr, uint64_t R_addr, + Data_Format in_fmt); + void (*ConfigMKN)(TsmNeInstr *instr, uint32_t M, uint32_t K, uint32_t N); + void (*ConfigBatch)(TsmNeInstr *instr, uint32_t Left_batch, + uint32_t Right_batch); + void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Format Out_fmt); + void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, + 
                  Data_Format fmt);
+  void (*SetTransflag)(TsmNeInstr *instr, uint8_t L_trans, uint8_t R_trans);
+  void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_left,
+                   uint8_t zp_right);
+  void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t addr);
+  void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en,
+                               uint64_t addr);
+  void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en,
+                               uint64_t addr);
+  void (*EnableRelu)(TsmNeInstr *instr);
+  void (*EnableLeakyRelu)(TsmNeInstr *instr);
+  void (*DisableRelu)(TsmNeInstr *instr);
+  void (*DisableLeakyRelu)(TsmNeInstr *instr);
+
+  /* data */
 } TsmGemm;
 
 typedef struct TsmRdma {
-    void (*AddSrcDst)(TsmRdmaInstr *instr, uint64_t src, uint64_t dst, Data_Format fmt);
-    void (*ConfigStrideIteration)(TsmRdmaInstr *instr, uint32_t elem_count, uint32_t stride0, uint32_t iteration0,
-                                  uint32_t stride1, uint32_t iteration1, uint32_t stride2, uint32_t iteration2);
-    void (*Rdma1d)(TsmRdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count,
-                   uint32_t format); // only stride0 and iteration0 (inner loop); copied once
+  void (*AddSrcDst)(TsmRdmaInstr *instr, uint64_t src, uint64_t dst,
+                    Data_Format fmt);
+  void (*ConfigStrideIteration)(TsmRdmaInstr *instr, uint32_t elem_count,
+                                uint32_t stride0, uint32_t iteration0,
+                                uint32_t stride1, uint32_t iteration1,
+                                uint32_t stride2, uint32_t iteration2);
+  void (*Rdma1d)(
+      TsmRdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count,
+      uint32_t format); // only stride0 and iteration0 (inner loop); copied once
 } TsmRdma;
 
 typedef struct TsmWdma {
-    void (*AddSrcDst)(TsmWdmaInstr *instr, uint64_t src, uint64_t dst, Data_Format fmt);
-    void (*ConfigStrideIteration)(TsmWdmaInstr *instr, uint32_t elem_count, uint32_t stride0, uint32_t iteration0,
-                                  uint32_t stride1, uint32_t iteration1, uint32_t stride2, uint32_t iteration2);
-    void (*Wdma1d)(TsmWdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count,
-                   uint32_t format); // only stride0 and iteration0 (inner loop); copied once
+  void (*AddSrcDst)(TsmWdmaInstr *instr, uint64_t src, uint64_t dst,
+                    Data_Format fmt);
+  void (*ConfigStrideIteration)(TsmWdmaInstr *instr, uint32_t elem_count,
+                                uint32_t stride0, uint32_t iteration0,
+                                uint32_t stride1, uint32_t iteration1,
+                                uint32_t stride2, uint32_t iteration2);
+  void (*Wdma1d)(
+      TsmWdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count,
+      uint32_t format); // only stride0 and iteration0 (inner loop); copied once
 } TsmWdma;
 
-
-
 /*=================================CGRA=================================*/
 typedef struct TsmArith {
-    void(*AbsVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count,
-                 Data_Format fmt);
-    void(*RecipVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count,
-                   Data_Format fmt);
-    void(*SquareVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count,
-                    Data_Format fmt);
-    void(*SqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count,
-                  Data_Format fmt);
-    void(*RsqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count,
-                   Data_Format fmt);
-    void(*NegVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count,
-                 Data_Format fmt);
-    void (*MaxVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr,
-                  uint32_t elem_count, RND_MODE reserved, Data_Format fmt);
-    void (*MaxVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr,
-                  uint32_t elem_count, RND_MODE reserved, Data_Format
fmt); - void (*MaxVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt); - void (*MaxVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE reserved, Data_Format fmt); - void(*MinVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, RND_MODE reserved, Data_Format fmt); - void(*MinVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, - uint32_t elem_count, RND_MODE reserved, Data_Format fmt); - void(*MinVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt); - void(*MinVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE reserved, Data_Format fmt); - void (*AddVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - RND_MODE rnd_mode, Data_Format fmt); - void (*AddVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, - uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*AddVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*AddVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); - void (*SubVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - RND_MODE rnd_mode, Data_Format fmt); - void (*SubVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, - uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*SubVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*SubVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); - void (*MulVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - RND_MODE rnd_mode, Data_Format fmt); - void (*MulVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, - uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*MulVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*MulVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); - void (*DivVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - RND_MODE rnd_mode, Data_Format 
fmt); - void (*DivVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, uint64_t dst_addr, - uint32_t elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*DivVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*DivVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, Data_Format fmt); + void (*AbsVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*RecipVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*SquareVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*SqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*RsqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*NegVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*MaxVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, + uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved, + Data_Format fmt); + void (*MaxVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved, + Data_Format fmt); + void (*MaxVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt); + void (*MaxVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE reserved, + Data_Format fmt); + void (*MinVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved, + Data_Format fmt); + void (*MinVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, + uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved, + Data_Format fmt); + void (*MinVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt); + void (*MinVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE reserved, + Data_Format fmt); + void (*AddVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*AddVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*AddVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*AddVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, + 
Data_Format fmt); + void (*SubVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*SubVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*SubVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*SubVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, + Data_Format fmt); + void (*MulVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*MulVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*MulVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*MulVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, + Data_Format fmt); + void (*DivVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*DivVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, + Data_Format fmt); + void (*DivVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); + void (*DivVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, RND_MODE rnd_mode, + Data_Format fmt); } TsmArith; - typedef struct TsmRelation { - void (*EqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*EqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*EqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*BoolEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*EqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*BoolEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t 
src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*UnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolUnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*UnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolUnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*UnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*BoolUnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*UnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*BoolUnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*GreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolGreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*GreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolGreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*GreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*BoolGreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*GreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*BoolGreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*GreaterVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolGreaterVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*GreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolGreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t 
elem_count, Data_Format fmt); - void (*GreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*BoolGreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*GreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*BoolGreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*LessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolLessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*LessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolLessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*LessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*BoolLessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*LessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*BoolLessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*LessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolLessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*LessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*BoolLessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*LessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*BoolLessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*LessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*BoolLessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint32_t 
elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); + void (*EqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*BoolEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*EqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*BoolEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*EqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, Data_Format fmt); + void (*BoolEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*EqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num, + Data_Format fmt); + void (*BoolEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num, + Data_Format fmt); + + void (*UnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*BoolUnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*UnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*BoolUnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*UnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, Data_Format fmt); + void (*BoolUnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*UnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num, + Data_Format fmt); + void (*BoolUnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*GreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*BoolGreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*GreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*BoolGreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t 
elem_count, Data_Format fmt); + void (*GreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*BoolGreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*GreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*BoolGreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*GreaterVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*BoolGreaterVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*GreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*BoolGreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*GreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, Data_Format fmt); + void (*BoolGreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*GreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num, + Data_Format fmt); + void (*BoolGreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*LessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*BoolLessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*LessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*BoolLessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*LessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*BoolLessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*LessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num, + Data_Format fmt); + void 
(*BoolLessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*LessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*BoolLessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*LessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*BoolLessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t convst_value, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*LessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*BoolLessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + Data_Format fmt); + void (*LessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num, + Data_Format fmt); + void (*BoolLessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); } TsmRelation; typedef struct TsmLogic { - void (*NotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*AndVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*OrVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*XorVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*AndVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*OrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*XorVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count, Data_Format fmt); - void (*AndVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*OrVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*XorVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*BoolNotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count); - void (*BoolAndV)(TsmLogicInstr *instr, uint64_t src0_addr, 
uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count); - void (*BoolOrV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count); - void (*BoolXorV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count); - void (*BoolAndVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count); - void (*BoolOrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count); - void (*BoolXorVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, uint32_t unit_elem_count); - void (*BoolAndVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, uint32_t full_unit_elem_num); - void (*BoolOrVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, uint32_t full_unit_elem_num); - void (*BoolXorVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, uint32_t full_unit_elem_num); + void (*NotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*AndVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*OrVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*XorVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*AndVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, Data_Format fmt); + void (*OrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, Data_Format fmt); + void (*XorVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, Data_Format fmt); + void (*AndVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, + uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*OrVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + void (*XorVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, + uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count, uint32_t full_elem_num, + uint32_t full_unit_elem_num, Data_Format fmt); + + void (*BoolNotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count); + void (*BoolAndV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count); + void (*BoolOrV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count); + void (*BoolXorV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, + uint64_t dst_addr, uint32_t elem_count); + 
void (*BoolAndVuV)(TsmLogicInstr *instr, uint64_t src_addr, + uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count); + void (*BoolOrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, + uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count); + void (*BoolXorVuV)(TsmLogicInstr *instr, uint64_t src_addr, + uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, + uint32_t unit_elem_count); + void (*BoolAndVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, + uint64_t unit_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num); + void (*BoolOrVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, + uint64_t unit_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num); + void (*BoolXorVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, + uint64_t unit_addr, uint64_t dst_addr, + uint32_t elem_count, uint32_t unit_elem_count, + uint32_t full_elem_num, uint32_t full_unit_elem_num); } TsmLogic; typedef struct TsmTranscendental { - void (*Log2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Ln)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Pow2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Exp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Explp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Sin)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Cos)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Log2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Ln)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Pow2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Exp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Explp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Sin)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Cos)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); } TsmTranscendental; typedef struct TsmActivation { - void (*Tanh)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Sigmoid)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Relu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Satrelu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Leakyrelu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Softplus)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Tanh)(TsmActivationInstr 
*instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Sigmoid)(TsmActivationInstr *instr, uint64_t src_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Relu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); + void (*Satrelu)(TsmActivationInstr *instr, uint64_t src_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Leakyrelu)(TsmActivationInstr *instr, uint64_t src_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*Softplus)(TsmActivationInstr *instr, uint64_t src_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); } TsmActivation; typedef struct TsmReduce { - void (*ReduceSum)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, - Data_Format fmt); - void (*ReduceAvg)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, - Data_Format fmt); - void (*ReduceMax)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, - Data_Format fmt); - void (*ReduceMin)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, uint32_t dim, Data_Shape shape, - Data_Format fmt); + void (*ReduceSum)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t dim, Data_Shape shape, Data_Format fmt); + void (*ReduceAvg)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t dim, Data_Shape shape, Data_Format fmt); + void (*ReduceMax)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t dim, Data_Shape shape, Data_Format fmt); + void (*ReduceMin)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, + uint32_t dim, Data_Shape shape, Data_Format fmt); } TsmReduce; typedef struct TsmPool { - void (*MaxPool)(TsmPoolInstr *instr, uint64_t src0, Data_Shape src_shape, uint64_t dst, Data_Shape pad, - Data_Shape swr_shape, Data_Format fmt); - void (*AvgPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, Data_Shape pad, - Data_Shape swr_shape, Data_Format fmt); - void (*SumPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, Data_Shape pad, - Data_Shape swr_shape, Data_Format fmt); - void (*MinPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, Data_Shape pad, - Data_Shape swr_shape, Data_Format fmt); - void (*IndexdMinPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_arg, - uint64_t dst_idx, Data_Shape pad, Data_Shape swr_shape, Data_Format fmt); - void (*IndexdMaxPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_arg, - uint64_t dst_idx, Data_Shape pad, Data_Shape swr_shape, Data_Format fmt); + void (*MaxPool)(TsmPoolInstr *instr, uint64_t src0, Data_Shape src_shape, + uint64_t dst, Data_Shape pad, Data_Shape swr_shape, + Data_Format fmt); + void (*AvgPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, + uint64_t dst_addr, Data_Shape pad, Data_Shape swr_shape, + Data_Format fmt); + void (*SumPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, + uint64_t dst_addr, Data_Shape pad, Data_Shape swr_shape, + Data_Format fmt); + void (*MinPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, + uint64_t dst_addr, Data_Shape pad, Data_Shape swr_shape, + Data_Format fmt); + void (*IndexdMinPool)(TsmPoolInstr *instr, uint64_t src0_addr, + Data_Shape 
src_shape, uint64_t dst_arg, + uint64_t dst_idx, Data_Shape pad, Data_Shape swr_shape, + Data_Format fmt); + void (*IndexdMaxPool)(TsmPoolInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_arg, + uint64_t dst_idx, Data_Shape pad, Data_Shape swr_shape, + Data_Format fmt); } TsmPool; typedef struct TsmUnPool { - void (*Unpool)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, Data_Format fmt); - void (*UnpoolAvg)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, Data_Format fmt); - void (*UnpoolIdx)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, Data_Format fmt); + void (*Unpool)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, + uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, + Data_Format fmt); + void (*UnpoolAvg)(TsmUnPoolInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, Data_Shape dst_shape, + Data_Shape swr_shape, Data_Format fmt); + void (*UnpoolIdx)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, + uint64_t dst_addr, Data_Shape dst_shape, + Data_Shape swr_shape, Data_Format fmt); } TsmUnPool; typedef struct TsmMaskDataMove { - void (*MaskMove)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, uint32_t mask, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*MaskGather)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, uint32_t index, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*MaskGather_bV)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, uint32_t bitindex, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*MaskMove)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, + uint32_t mask, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*MaskGather)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, + uint32_t index, uint64_t dst_addr, uint32_t elem_count, + Data_Format fmt); + void (*MaskGather_bV)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, + uint32_t bitindex, uint64_t dst_addr, + uint32_t elem_count, Data_Format fmt); } TsmMaskDataMove; - typedef struct TsmConvert { - void (*INT8_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); // Data_Format fmt is INT8 - void (*INT8_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); - void (*INT8_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); - void (*INT8_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, uint64_t dst_addr, uint32_t elem_count); - void (*INT16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); // INT16 - void (*INT16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - - void (*INT32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, 
uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - - void (*BF16_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); - void (*BF16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*BF16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*BF16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); - void (*BF16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); - void (*BF16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); - - void (*FP16_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); // rnd_mode 0~4 - void (*FP16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); - void (*FP16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); - - void (*FP32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); // rnd_mode 0~4 - void (*FP32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - - void (*TF32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*TF32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*TF32_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*TF32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); - void (*TF32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode);// rnd_mode 0~4 - void (*TF32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count); + void (*INT8_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, + uint64_t dst_addr, + uint32_t elem_count); // Data_Format fmt is INT8 + void (*INT8_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, + uint64_t dst_addr, uint32_t elem_count); + void (*INT8_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, + uint64_t dst_addr, uint32_t elem_count); + void (*INT8_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, + uint64_t 
dst_addr, uint32_t elem_count); + void (*INT16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); // INT16 + void (*INT16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + + void (*INT32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*INT32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + + void (*BF16_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); + void (*BF16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*BF16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*BF16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); + void (*BF16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); + void (*BF16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); + + void (*FP16_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, + RND_MODE rnd_mode); // rnd_mode 0~4 + void (*FP16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); + void (*FP16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); + + void (*FP32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, + RND_MODE rnd_mode); // rnd_mode 0~4 + void (*FP32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*FP32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + + void (*TF32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*TF32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*TF32_INT32)(TsmConvertInstr *instr, uint64_t 
src0_addr, + uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + void (*TF32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); + void (*TF32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, + RND_MODE rnd_mode); // rnd_mode 0~4 + void (*TF32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count); } TsmConvert; typedef struct TsmPeripheral { - void (*Count)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt, uint64_t *wb_data0, uint64_t *wb_data1); - void (*Memset)(TsmDataMoveInstr *instr, uint64_t dst_addr, uint32_t value, uint32_t elem_count, - St_StrideIteration *si, Data_Format fmt); // si.stride is byte size. but ele_count is only element count - void (*Bit2Fp)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*ArgMax)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint32_t elem_count, Data_Format fmt); - void (*ArgMin)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint32_t elem_count, Data_Format fmt); - void (*Bilinear)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst0_addr, Data_Shape src_shape, - Data_Shape dst_shape, int32_t scale_w, int32_t scale_h, Data_Format fmt); - void (*Lut16)(TsmPeripheralInstr *instr, uint64_t src1_addr, uint64_t dst0_addr, uint64_t lut16_addr, - uint32_t src_elem_count, uint32_t lut_elem_count); - void (*Lut32)(TsmPeripheralInstr *instr, uint64_t src1_addr, uint64_t dst0_addr, uint64_t lut32_addr, - uint32_t src_elem_count, uint32_t lut_elem_count); - void (*RandGen)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t src1_addr, uint64_t dst_addr, - uint64_t dst1_addr, uint64_t dst2_addr, uint32_t src_elem_num, Data_Format fmt); - void (*Factorize)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint64_t dst1_addr, uint64_t dst2_addr, - uint32_t src_elem_num); - void (*ElemMask)(TsmPeripheralInstr *instr, uint64_t src0_addr, uint32_t scale, uint64_t dst_addr, uint32_t src_elem_num, Data_Format fmt, - uint32_t prob, RND_MODE rnd_mode); + void (*Count)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt, + uint64_t *wb_data0, uint64_t *wb_data1); + void (*Memset)(TsmDataMoveInstr *instr, uint64_t dst_addr, uint32_t value, + uint32_t elem_count, St_StrideIteration *si, + Data_Format fmt); // si.stride is byte size. 
but ele_count is + // only element count + void (*Bit2Fp)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); + void (*ArgMax)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint32_t elem_count, Data_Format fmt); + void (*ArgMin)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint32_t elem_count, Data_Format fmt); + void (*Bilinear)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint64_t dst0_addr, Data_Shape src_shape, + Data_Shape dst_shape, int32_t scale_w, int32_t scale_h, + Data_Format fmt); + void (*Lut16)(TsmPeripheralInstr *instr, uint64_t src1_addr, + uint64_t dst0_addr, uint64_t lut16_addr, + uint32_t src_elem_count, uint32_t lut_elem_count); + void (*Lut32)(TsmPeripheralInstr *instr, uint64_t src1_addr, + uint64_t dst0_addr, uint64_t lut32_addr, + uint32_t src_elem_count, uint32_t lut_elem_count); + void (*RandGen)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint64_t src1_addr, uint64_t dst_addr, uint64_t dst1_addr, + uint64_t dst2_addr, uint32_t src_elem_num, Data_Format fmt); + void (*Factorize)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint64_t dst1_addr, uint64_t dst2_addr, + uint32_t src_elem_num); + void (*ElemMask)(TsmPeripheralInstr *instr, uint64_t src0_addr, + uint32_t scale, uint64_t dst_addr, uint32_t src_elem_num, + Data_Format fmt, uint32_t prob, RND_MODE rnd_mode); } TsmPeripheral; typedef struct TsmDataMove { - void (*Mirror)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + void (*Mirror)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, Data_Shape dst_shape, + Data_Format fmt); + void (*Transpose)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, Data_Shape dst_shape, Data_Format fmt); - void (*Transpose)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, Data_Format fmt); - void (*Rotate90)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, + void (*Rotate90)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Rotate180)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, Data_Shape dst_shape, Data_Format fmt); - void (*Rotate180)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, Data_Format fmt); - void (*Rotate270)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, Data_Format fmt); - void (*Nchw2nhwc)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, Data_Format fmt); - void (*Nhwc2nchw)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, Data_Format fmt); - void (*Concat)(TsmMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape0, uint64_t src1_addr, - Data_Shape src_shape1, uint64_t dst_addr, Data_Shape dst_shape, uint32_t dims, Data_Format fmt); - void (*Pad)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, Data_Shape pad, Data_Format fmt); - void (*Img2col)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, uint64_t src_elem_num, uint64_t dst_elem_num, Data_Shape swr, Data_Shape pdr, 
- Data_Format fmt); - void (*TensorNom)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, uint64_t dst_addr, - Data_Shape dst_shape, Data_Format fmt); - void (*GatherScatter)(TsmDataMoveInstr *instr, uint64_t src0_addr, uint64_t dst_addr, uint32_t size, - St_StrideIteration *src_si, St_StrideIteration *dst_si); + void (*Rotate270)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Nchw2nhwc)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Nhwc2nchw)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*Concat)(TsmMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape0, + uint64_t src1_addr, Data_Shape src_shape1, uint64_t dst_addr, + Data_Shape dst_shape, uint32_t dims, Data_Format fmt); + void (*Pad)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape, + uint64_t dst_addr, Data_Shape dst_shape, Data_Shape pad, + Data_Format fmt); + void (*Img2col)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, Data_Shape dst_shape, + uint64_t src_elem_num, uint64_t dst_elem_num, Data_Shape swr, + Data_Shape pdr, Data_Format fmt); + void (*TensorNom)(TsmDataMoveInstr *instr, uint64_t src0_addr, + Data_Shape src_shape, uint64_t dst_addr, + Data_Shape dst_shape, Data_Format fmt); + void (*GatherScatter)(TsmDataMoveInstr *instr, uint64_t src0_addr, + uint64_t dst_addr, uint32_t size, + St_StrideIteration *src_si, St_StrideIteration *dst_si); } TsmDataMove; TsmConv *TsmNewConv(); @@ -489,19 +780,31 @@ void TsmDeletePeripheral(TsmPeripheral *obj); void TsmDeleteDataMove(TsmDataMove *obj); /*=================================STREAM=================================*/ typedef struct TsmStream { - uint32_t (*OnlineStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); - uint32_t (*OfflineStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); - uint32_t (*WaitStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); - uint32_t (*ReqStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); - uint32_t (*PushStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); - uint32_t (*PopStream)(uint32_t core_id_this, uint32_t tile_id, uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, uint64_t stream_id, uint64_t stream_addr); - uint8_t (*wait_finish)(); + uint32_t (*OnlineStream)(uint32_t core_id_this, uint32_t tile_id, + uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, + uint64_t stream_id, uint64_t stream_addr); + uint32_t (*OfflineStream)(uint32_t core_id_this, uint32_t tile_id, + uint32_t core_id, uint32_t channel_id, + uint32_t remote, uint32_t stream_type, + uint64_t stream_id, uint64_t stream_addr); + uint32_t (*WaitStream)(uint32_t core_id_this, uint32_t 
tile_id,
+                         uint32_t core_id, uint32_t channel_id, uint32_t remote,
+                         uint32_t stream_type, uint64_t stream_id,
+                         uint64_t stream_addr);
+  uint32_t (*ReqStream)(uint32_t core_id_this, uint32_t tile_id,
+                        uint32_t core_id, uint32_t channel_id, uint32_t remote,
+                        uint32_t stream_type, uint64_t stream_id,
+                        uint64_t stream_addr);
+  uint32_t (*PushStream)(uint32_t core_id_this, uint32_t tile_id,
+                         uint32_t core_id, uint32_t channel_id, uint32_t remote,
+                         uint32_t stream_type, uint64_t stream_id,
+                         uint64_t stream_addr);
+  uint32_t (*PopStream)(uint32_t core_id_this, uint32_t tile_id,
+                        uint32_t core_id, uint32_t channel_id, uint32_t remote,
+                        uint32_t stream_type, uint64_t stream_id,
+                        uint64_t stream_addr);
+  uint8_t (*wait_finish)();
} TsmStream;
TsmStream *TsmNewStream();
void TsmDeleteStream(TsmStream *obj);
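To make the stream API concrete: a minimal round trip through the TsmStream table above, as a sketch only. It assumes tx81.h is the umbrella header for these declarations, that wait_finish() returns non-zero once the engine has drained, and uses placeholder argument values throughout:

    #include "tx81.h"

    /* Submit one stream descriptor and block until the engine reports done. */
    static void stream_roundtrip(uint64_t stream_addr) {
      TsmStream *s = TsmNewStream();
      s->OnlineStream(/*core_id_this=*/0, /*tile_id=*/0, /*core_id=*/1,
                      /*channel_id=*/0, /*remote=*/0, /*stream_type=*/0,
                      /*stream_id=*/0, stream_addr);
      while (!s->wait_finish()) {
        /* poll; a real integration would yield or sleep here */
      }
      TsmDeleteStream(s);
    }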
@@ -516,16 +819,16 @@ uint8_t TsmWaitfinish_bywork(size_t workerid);
}
#endif
-// ==================== if you will use Tx8-Oplib ======================================================
-// #define LOG_PRINT(...)
+// ==================== if you will use Tx8-Oplib ====================
+// #define LOG_PRINT(...)
// #define LOG_ERR(fmt, args...)
// #define TSM_FREE free
// #define TSM_MALLOC malloc
// extern void setreg(int index, uint64_t value);
// extern uint64_t getreg(int index);
-// ==================== if you run in SOC-freerots/zebu ================================================
-// #include "rce_log.h"
+// ==================== if you run in SOC-freertos/zebu ====================
+// #include "rce_log.h"
// #include "csi_kernel.h"
// #include "rce_pal.h"
// #define LOG_PRINT(fmt, args...) vdk_printf(fmt, ##args)
@@ -535,35 +838,36 @@ uint8_t TsmWaitfinish_bywork(size_t workerid);
// #define NCC_ADDR 0x01000000
// #define setreg(ADDR, VALUE)
// do {
-// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X, value=0x%lX(%lu).\n", ADDR, VALUE, VALUE);
+// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X,
+//           value=0x%lX(%lu).\n", ADDR, VALUE, VALUE);
// *((volatile uint64_t *)(ADDR + NCC_ADDR)) = VALUE;
// } while (0)
-// ==================== if you run in kernel-rt(use tile-sim) ==========================================
-// #define LOG_PRINT(fmt, args...) printf(fmt, ##args)
-// #define LOG_ERR(fmt, args...) printf(fmt, ##args)
-// #define TSM_FREE free
-// #define TSM_MALLOC malloc
-// #include "rce_pal_port.h"
-// void setreg(int index, uint64_t value)
+// ==================== if you run in kernel-rt (use tile-sim) ====================
+// #define LOG_PRINT(fmt, args...) printf(fmt, ##args)
+// #define LOG_ERR(fmt, args...) printf(fmt, ##args)
+// #define TSM_FREE free
+// #define TSM_MALLOC malloc
+// #include "rce_pal_port.h"
+// void setreg(int index, uint64_t value)
// {
-// LOG_PRINT("setreg param: GR: index=0x%X, value=0x%lX(%lu).\n", index, value, value);
-// rce_tx_pal_setreg(index, value);
+// LOG_PRINT("setreg param: GR: index=0x%X, value=0x%lX(%lu).\n",
+//           index, value, value);
+// rce_tx_pal_setreg(index, value);
// }
-// ====================if you run in kernel-rt(use riscv) ===============================================
-// #define LOG_PRINT(...)
+// ==================== if you run in kernel-rt (use riscv) ====================
+// #define LOG_PRINT(...)
// #define LOG_ERR(fmt, args...)
// #define TSM_FREE free
// #define TSM_MALLOC malloc
// #define NCC_ADDR 0x01000000
// #define setreg(ADDR, VALUE)
// do {
-// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X, value=0x%lX(%lu).\n", ADDR, VALUE, VALUE);
+// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X,
+//           value=0x%lX(%lu).\n", ADDR, VALUE, VALUE);
// *((volatile uint64_t *)(ADDR + NCC_ADDR)) = VALUE;
// } while (0)
-// ====================if you do not need Log ==========================================================
-//#define LOG_PRINT(...)
-//#define LOG_ERR(fmt, args...)
+// ==================== if you do not need Log ====================
+// #define LOG_PRINT(...)
+// #define LOG_ERR(fmt, args...)
#endif
diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_def.h b/third_party/tsingmicro/crt/include/Tx81/instr_def.h
index 96c8eda33..c999997c0 100644
--- a/third_party/tsingmicro/crt/include/Tx81/instr_def.h
+++ b/third_party/tsingmicro/crt/include/Tx81/instr_def.h
@@ -281,73 +281,76 @@
// DTE PMU end
typedef enum OP_INSTR_TYPE {
-    I_CGRA,
-    I_NEUR,
-    I_RDMA,
-    I_WDMA,
-    I_TDMA,
-    I_SCALAR,
-    I_DTE,
-    I_CSR,
+  I_CGRA,
+  I_NEUR,
+  I_RDMA,
+  I_WDMA,
+  I_TDMA,
+  I_SCALAR,
+  I_DTE,
+  I_CSR,
} OP_INSTR_TYPE;
// instr_type = I_CGRA | I_WORKER1
typedef enum OP_INSTR_WORKER {
-    I_WORKER0 = 0x0000,
-    I_WORKER1 = 0x0100,
-    I_WORKER2 = 0x0200,
+  I_WORKER0 = 0x0000,
+  I_WORKER1 = 0x0100,
+  I_WORKER2 = 0x0200,
} OP_INSTR_WORKER;
typedef enum RND_MODE {
-    RND_NEAREST_EVEN,
-    RND_ZERO,
-    RND_POS_INF,
-    RND_NEG_INF,
-    RND_STOCHASTIC
+  RND_NEAREST_EVEN,
+  RND_ZERO,
+  RND_POS_INF,
+  RND_NEG_INF,
+  RND_STOCHASTIC
} RND_MODE;
-
typedef struct Ncc_CT_GR_Ctl_Regs {
-    uint8_t cmd_valid; // self clear
-    uint8_t rnd_mode;  // 0: round to nearest even, 1: round to zero, 2: round to positive infinity, 3: round to
-                       // negative infinity, 4: stochastic round
-    uint8_t src0_format; // for the CGRATensor_PeriOp_V_V_bit2fp instruction this field is used as dst_format
-    uint8_t opcode;      // see the CGRATensor instruction OPcode.v for details
+  uint8_t cmd_valid; // self clear
+  uint8_t rnd_mode;  // 0: round to nearest even, 1: round to zero, 2: round to
+                     // positive infinity, 3: round to negative infinity,
+                     // 4: stochastic round
+  uint8_t src0_format; // for the CGRATensor_PeriOp_V_V_bit2fp instruction this
+                       // field is used as dst_format
+  uint8_t opcode;      // see the CGRATensor instruction OPcode.v for details
} Ncc_CT_GR_Ctl_Regs;
typedef struct Ncc_CT_GR_Param_Regs {
-    uint32_t src0; // SPM address
-    uint32_t src1;
-    uint32_t dst0;
-    uint32_t dst1;
-    uint32_t dst2;     // SPM address
-    uint64_t src0_tfr; // nhwc
-    uint64_t dst_tfr;  // nhwc
-    uint64_t pdr; // TOP, BOTTOM, LEFT, RIGHT (number of padded rows/columns on the top/bottom/left/right)
-    uint64_t swr; // kernel Kx (size in x), Ky, Sx (stride in x), Sy
-    uint64_t elem_count;      // number of elements of the vector operation
-    uint64_t unit_elem_count; // number of elements of the short vector in the vector operation (at most 64)
-    uint64_t int8_scale_val0; // bilinear interpolation scale factor in x (input_w/output_w)
-    uint64_t int8_scale_val1; // bilinear interpolation scale factor in y (input_h/output_h)
-    uint64_t int8_quant;   // deprecated
-    uint32_t int8_bn_bias; // deprecated
-    uint32_t full_elem_count;      // sum of several src_elem_num values
-    uint32_t full_unit_elem_count; // sum of several src_uint_elem_num values
-    uint64_t wb_data0; // The pointer of the return value. [32] DATA_VALID, [31:0] data,
-                       // with a single return value the returned data is written to this register
-    uint64_t wb_data1; // The pointer of the return value. [32] DATA_VALID, [31:0] data,
-                       // with two return values the second one is written to this register; invalid when there is only one
-
-    uint32_t src0_end; // SPM address (end address of src0), xxx_end = src/dst + storage extent of the operand in SPM
-    uint32_t src1_end;
-    uint32_t dst0_end;
-    uint32_t dst1_end;
-    uint32_t dst2_end;
-    uint8_t dims; // 000:C 001:W 010:H 011:N 100:HW 101:HWC
+  uint32_t src0; // SPM address
+  uint32_t src1;
+  uint32_t dst0;
+  uint32_t dst1;
+  uint32_t dst2;     // SPM address
+  uint64_t src0_tfr; // nhwc
+  uint64_t dst_tfr;  // nhwc
+  uint64_t pdr; // TOP, BOTTOM, LEFT, RIGHT (number of padded rows/columns on
+                // the top/bottom/left/right)
+  uint64_t swr; // kernel Kx (size in x), Ky, Sx (stride in x), Sy
+  uint64_t elem_count;      // number of elements of the vector operation
+  uint64_t unit_elem_count; // number of elements of the short vector in the
+                            // vector operation (at most 64)
+  uint64_t int8_scale_val0; // bilinear interpolation scale factor in x
+                            // (input_w/output_w)
+  uint64_t int8_scale_val1; // bilinear interpolation scale factor in y
+                            // (input_h/output_h)
+  uint64_t int8_quant;   // deprecated
+  uint32_t int8_bn_bias; // deprecated
+  uint32_t full_elem_count;      // sum of several src_elem_num values
+  uint32_t full_unit_elem_count; // sum of several src_uint_elem_num values
+  uint64_t wb_data0; // The pointer of the return value. [32] DATA_VALID,
+                     // [31:0] data; with a single return value the returned
+                     // data is written to this register
+  uint64_t wb_data1; // The pointer of the return value. [32] DATA_VALID,
+                     // [31:0] data; with two return values the second one is
+                     // written to this register, invalid when there is only one
+  uint32_t src0_end; // SPM address (end address of src0), xxx_end = src/dst +
+                     // storage extent of the operand in SPM
+  uint32_t src1_end;
+  uint32_t dst0_end;
+  uint32_t dst1_end;
+  uint32_t dst2_end;
+  uint8_t dims; // 000:C 001:W 010:H 011:N 100:HW 101:HWC
} Ncc_CT_GR_Param_Regs;
typedef struct CT_Param {
-    uint32_t inter_type;
-    Ncc_CT_GR_Ctl_Regs ctrl;
-    Ncc_CT_GR_Param_Regs param;
+  uint32_t inter_type;
+  Ncc_CT_GR_Ctl_Regs ctrl;
+  Ncc_CT_GR_Param_Regs param;
} CT_Param;
#define TsmArithInstr CT_Param
@@ -364,161 +367,168 @@ typedef struct CT_Param {
#define TsmReduceInstr CT_Param
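Every TsmXXXInstr alias (TsmArithInstr here through TsmReduceInstr) is the same CT_Param record, so one zero-initialized descriptor can be handed to any of the op tables from the adapter header. A minimal sketch of a vector subtract, assuming a TsmNewArith()/TsmDeleteArith() pair exists by analogy with TsmNewConv()/TsmNewStream(), and with FMT_FP16 standing in for a real Data_Format code:

    #include "tx81.h"

    /* c = a - b over n fp16 elements already resident in SPM. */
    static void spm_sub_fp16(uint64_t a, uint64_t b, uint64_t c, uint32_t n) {
      TsmArithInstr instr = {0};       /* CT_Param; populated by the call */
      TsmArith *arith = TsmNewArith(); /* assumed constructor */
      arith->SubVV(&instr, a, b, c, n, RND_NEAREST_EVEN,
                   FMT_FP16 /* placeholder format code */);
      TsmDeleteArith(arith);           /* assumed destructor */
    }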
uint8_t gemm_l_trs; // 左矩阵转置 - uint8_t gemm_r_trs; // 右矩阵转置 - /* - Quant formula----A_int8:Left input, B_int8: Right input - Left input 8bit to 9bit: A_int9 = A_int8 - ZP_A_int8 - Left input 8bit to 9bit: A_int9 = A_int8 - ZP_A_int8 - do conv : O_int32 = Sum_{A_int9 * B_int9} - do scale : O_int16 = Clip_int16(O_int32 >> q1) - do scale : O_int9 = Clip_int9((O_int16 * S_int16) >> q2) - out 9bit to 8bit : O_int8 = O_int9 + ZP_O_int8 - */ - uint8_t quant_zp_cur; // 输出零点(0-255). [39:32] - uint8_t quant_reserved; // (0-255). [31:24] conv:unused gemm:right_zp - uint8_t quant_zp_pre; // 输入零点(0-255), [23:16] conv:act_zp gemm:left_zp(范围:0-255) - uint8_t quant_q1; // q1, (范围:0-31),[15:8] - uint8_t quant_q0; // q2, (范围:0-31),[7:0] - - uint32_t sparse_index; // spm地址(稀疏化索引) - uint32_t srca_end; // xxx_end = src/dst + 对应操作数在spm中存储范围 - uint32_t srcw_end; - uint32_t psum_end; - uint32_t bias_end; - uint32_t scale_p_end; - uint32_t scale_n_end; - uint32_t out_end; - uint32_t sparse_end; + uint32_t src_a; // spm地址(激活/左矩阵) + uint32_t src_w; // spm地址(权重/右矩阵) + uint32_t psum; // spm地址(输入psum) + uint32_t bias; // spm地址(bias) + uint32_t scale_p; // spm地址(正轴scale) + uint32_t scale_n; // spm地址(负轴scale) + uint32_t out; // spm地址(输出psum) + uint64_t tfr_0; // src0 nhwc, [15:0]tensor + // batch/h/w(范围1~4096);tensor通道数(范围1~16384) + uint64_t tfr_1; // conv: out nhwc, 同上tfr_0 + uint64_t pdr; // pad [15:0]top bottom left right, + // 分别是上下左右pad的行/列数(范围0~1023) + uint64_t unpdr; // unpad [15:0]top bottom left right + uint64_t + swr; // [15:0]Kx(范围1~255) Ky(范围1~255) Sx(范围1~1023) Sy(范围1~1023) + uint64_t dilation; // [15:0]空洞卷积的x方向大小(范围1-1023), + // [15:0]空洞卷积的y方向大小(范围1-1023) + + uint16_t gemm_lb; // [15:0]左矩阵batch(范围:1~4096) + uint16_t gemm_rb; // [15:0]左矩阵batch(范围:1~4096) + uint16_t gemm_n; // 矩阵运算的矩阵大小参数 + uint16_t gemm_m; // mk*kn---->mn + uint16_t gemm_k; // (范围:1~16384) + uint8_t gemm_l_trs; // 左矩阵转置 + uint8_t gemm_r_trs; // 右矩阵转置 + /* + Quant formula----A_int8:Left input, B_int8: Right input + Left input 8bit to 9bit: A_int9 = A_int8 - ZP_A_int8 + Left input 8bit to 9bit: A_int9 = A_int8 - ZP_A_int8 + do conv : O_int32 = Sum_{A_int9 * B_int9} + do scale : O_int16 = Clip_int16(O_int32 >> q1) + do scale : O_int9 = Clip_int9((O_int16 * S_int16) >> q2) + out 9bit to 8bit : O_int8 = O_int9 + ZP_O_int8 + */ + uint8_t quant_zp_cur; // 输出零点(0-255). [39:32] + uint8_t quant_reserved; // (0-255). 
+  uint8_t quant_zp_pre;   // input zero point (0-255), [23:16] conv:act_zp
+                          // gemm:left_zp (range 0-255)
+  uint8_t quant_q1;       // q1, (range 0-31), [15:8]
+  uint8_t quant_q0;       // q2, (range 0-31), [7:0]
+
+  uint32_t sparse_index; // SPM address (sparsity index)
+  uint32_t srca_end;     // xxx_end = src/dst + the operand's storage extent in SPM
+  uint32_t srcw_end;
+  uint32_t psum_end;
+  uint32_t bias_end;
+  uint32_t scale_p_end;
+  uint32_t scale_n_end;
+  uint32_t out_end;
+  uint32_t sparse_end;
 } Ncc_NE_GR_Param_Regs;
 
 typedef struct TsmNeInstr {
-    uint32_t inter_type;
-    Ncc_NE_GR_Ctl_Regs ctrl;
-    Ncc_NE_GR_Param_Regs param;
+  uint32_t inter_type;
+  Ncc_NE_GR_Ctl_Regs ctrl;
+  Ncc_NE_GR_Param_Regs param;
 } TsmNeInstr;
 
 // RDMA / WDMA
 typedef struct Ncc_DMA_GR_Ctl_Regs {
-    uint8_t cmd_valid;
+  uint8_t cmd_valid;
 } Ncc_DMA_GR_Ctl_Regs;
 
 typedef struct Ncc_DMA_GR_Param_Regs {
-    uint64_t dst; // ddr地址
-    uint64_t src; // spm地址
-    /*
-    for(i = 0; i < itera2; i++)
-      for(j = 0; j < itera1; j++)
-        for(k = 0; k < itera0; k++)
-          for(l = 0; l < elem_count; l++)
-            dst[l + elem_coun * k + elem_coun * src_itera0 * j + elem_coun * src_itera0 * src_itera1 * i] =
-            \ src[l + k * src_stride0 + j * src_stride1 + i * src_stride2];
-    */
-    uint32_t stride0;    //地址步长
-    uint32_t iteration0; // 数据块个数
-    uint32_t stride1;
-    uint32_t iteration1;
-    uint32_t stride2;
-    uint32_t iteration2;
-    uint32_t elem_count; // 最里面维度单次搬运的元素个数
-    uint8_t format;      // 数据类型
-    uint64_t src_end;    // src_end = src + ddr中数据存储长度
-    uint64_t dst_end;    // dst_end = dst + spm中数据存储长度
+  uint64_t dst; // DDR address
+  uint64_t src; // SPM address
+  /*
+  for(i = 0; i < itera2; i++)
+    for(j = 0; j < itera1; j++)
+      for(k = 0; k < itera0; k++)
+        for(l = 0; l < elem_count; l++)
+          dst[l + elem_count * k + elem_count * src_itera0 * j +
+              elem_count * src_itera0 * src_itera1 * i] =
+              src[l + k * src_stride0 + j * src_stride1 + i * src_stride2];
+  */
+  uint32_t stride0;    // address stride
+  uint32_t iteration0; // number of data blocks
+  uint32_t stride1;
+  uint32_t iteration1;
+  uint32_t stride2;
+  uint32_t iteration2;
+  uint32_t elem_count; // elements moved per innermost-dimension transfer
+  uint8_t format;      // data type
+  uint64_t src_end;    // src_end = src + data length stored in DDR
+  uint64_t dst_end;    // dst_end = dst + data length stored in SPM
 } Ncc_DMA_GR_Param_Regs;
 
 typedef struct DMA_Param {
-    uint32_t inter_type;
-    Ncc_DMA_GR_Ctl_Regs ctrl;
-    Ncc_DMA_GR_Param_Regs param;
+  uint32_t inter_type;
+  Ncc_DMA_GR_Ctl_Regs ctrl;
+  Ncc_DMA_GR_Param_Regs param;
 } DMA_Param;
 
 #define TsmRdmaInstr DMA_Param
 #define TsmWdmaInstr DMA_Param
 
 typedef struct Ncc_TDMA_GR_Ctl_Regs {
-    uint8_t cmd_valid;   // [12]
-    uint8_t src0_format; // [11:8]
-    uint8_t opcode;      //[7:0]
+  uint8_t cmd_valid;   // [12]
+  uint8_t src0_format; // [11:8]
+  uint8_t opcode;      //[7:0]
 } Ncc_TDMA_GR_Ctl_Regs;
 
 typedef struct Ncc_TDMA_GR_Param_Regs {
-    uint32_t src0;
-    uint32_t src1;
-    uint32_t dst;
-    uint64_t src0_tfr;   // nhwc c:15~0
-    uint64_t dst_tfr;    // nhwc
-    uint64_t pdr;        // top bottom left right
-    uint64_t swr;        // kx ky sx sy
-    uint32_t elem_count; // vector操作的元素个数. memset、gatherscatter指令中代表byte number
-    /*
-    for(i=0;igather/scatter, unicast, 1-> scatter, broadcast, 3-> shuffle(3D gather). [8:8] sg_flag:
-                  // 0->scatter, 1->gather, [16:16] dim_flag: only unicast mode(=0), 0->1D transport, 1->2D transport
-    uint32_t length;     // count data bytes
-    uint8_t dest_num;    // if mode[0:0] is 0, then it's value is 1; otherwise it's value is between 1 and 31.
-    uint32_t stride0;    // if mode[0:0] is 0, then stride can be setted, unit is byte.
-    uint32_t iteration0; // 0: means 1 section; 1: means 2 sectons, and so on.
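// Host-side reference for the RDMA/WDMA address pattern described in the
// Ncc_DMA_GR_Param_Regs comment above ("elem_coun" in the original read as
// elem_count; the function itself is illustrative, not part of the runtime).
static void dma_copy_ref(uint8_t *dst, const uint8_t *src, uint32_t elem_count,
                         uint32_t itera0, uint32_t stride0, uint32_t itera1,
                         uint32_t stride1, uint32_t itera2, uint32_t stride2) {
  for (uint32_t i = 0; i < itera2; i++)
    for (uint32_t j = 0; j < itera1; j++)
      for (uint32_t k = 0; k < itera0; k++)
        for (uint32_t l = 0; l < elem_count; l++)
          // destination is packed; source walks the three strides
          dst[l + elem_count * (k + itera0 * (j + itera1 * i))] =
              src[l + k * stride0 + j * stride1 + i * stride2];
}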
-    uint32_t stride1;
-    uint32_t iteration1;
-    uint32_t stride2;
-    uint32_t iteration2;
-    uint16_t max_axi_num; // [7:0] axi_write_outstanding, [15:8] aix_read_outstanding
-    uint8_t cmd_valid;    // 1: activate dma, 0: no action.
-    // uint8_t dma_status; // [0:0] 0->unfinished, 1->finished. [8:8] 0/1, record the error of AXI bus or other DMA
-    // transmission.
-    uint16_t
-        mem_burstlen; // [7:0] mem_burst_len_write, default value: 0x10; [15:8] mem_burst_len_read, default value: 0x10
-    uint8_t mem_backpressure; // 0x1
-    uint8_t mem_read_turbo;   // [1:0], 0~2, default value: 0, only block0 valid.
+  uint8_t channel;      // value: 0~3
+  uint8_t block;        // value: 0~1
+  uint64_t dst[32];     // idx: 0~31, [0:31] low bit, [32:39] high bit
+  uint16_t user_id[32]; // idx: 0~31, [0:15]
+  uint64_t src;         // [31~0]: src_addr_lo, [32~63]: src_addr_hi
+  uint32_t mode;        // [1:0] mode: 0-> gather/scatter, unicast, 1-> scatter,
+                        // broadcast, 3-> shuffle (3D gather). [8:8] sg_flag:
+                        // 0->scatter, 1->gather, [16:16] dim_flag: only unicast
+                        // mode(=0), 0->1D transport, 1->2D transport
+  uint32_t length;   // count data bytes
+  uint8_t dest_num;  // if mode[0:0] is 0, then its value is 1; otherwise its
+                     // value is between 1 and 31.
+  uint32_t
+      stride0; // if mode[0:0] is 0, then stride can be set, unit is byte.
+  uint32_t iteration0; // 0: means 1 section; 1: means 2 sections, and so on.
+  uint32_t stride1;
+  uint32_t iteration1;
+  uint32_t stride2;
+  uint32_t iteration2;
+  uint16_t
+      max_axi_num; // [7:0] axi_write_outstanding, [15:8] axi_read_outstanding
+  uint8_t cmd_valid; // 1: activate dma, 0: no action.
+  // uint8_t dma_status; // [0:0] 0->unfinished, 1->finished. [8:8] 0/1, record
+  // the error of AXI bus or other DMA transmission.
+  uint16_t mem_burstlen; // [7:0] mem_burst_len_write, default value: 0x10;
+                         // [15:8] mem_burst_len_read, default value: 0x10
+  uint8_t mem_backpressure; // 0x1
+  uint8_t mem_read_turbo; // [1:0], 0~2, default value: 0, only block0 valid.
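// Decoding the DTE mode word per the field comments above (sketch only;
// the helper names are hypothetical).
static inline unsigned dte_mode(uint32_t m)     { return m & 0x3; }         // [1:0]
static inline unsigned dte_sg_flag(uint32_t m)  { return (m >> 8) & 0x1; }  // 0: scatter, 1: gather
static inline unsigned dte_dim_flag(uint32_t m) { return (m >> 16) & 0x1; } // 0: 1D, 1: 2D (unicast only)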
} Ncc_DTE_GR_Param_Regs; typedef enum OP_FUNC_CGRA { - // Arithmetic Operators - OP_FUNC_CGRATensor_ArithOp_V_V_abs = 0, - OP_FUNC_CGRATensor_ArithOp_V_V_recip = 1, - OP_FUNC_CGRATensor_ArithOp_V_V_square = 2, - OP_FUNC_CGRATensor_ArithOp_V_V_sqrt = 3, - OP_FUNC_CGRATensor_ArithOp_V_V_rsqrt = 4, - OP_FUNC_CGRATensor_ArithOp_V_V_neg = 5, - OP_FUNC_CGRATensor_ArithOp_V_VV_max = 6, - OP_FUNC_CGRATensor_ArithOp_V_VS_max = 7, - OP_FUNC_CGRATensor_ArithOp_V_VuV_max = 8, - OP_FUNC_CGRATensor_ArithOp_V_VuV_max_loop = 9, - OP_FUNC_CGRATensor_ArithOp_V_VV_min = 10, - OP_FUNC_CGRATensor_ArithOp_V_VS_min = 11, - OP_FUNC_CGRATensor_ArithOp_V_VuV_min = 12, - OP_FUNC_CGRATensor_ArithOp_V_VuV_min_loop = 13, - OP_FUNC_CGRATensor_ArithOp_V_VV_add = 14, - OP_FUNC_CGRATensor_ArithOp_V_VS_add = 15, - OP_FUNC_CGRATensor_ArithOp_V_VuV_add = 16, - OP_FUNC_CGRATensor_ArithOp_V_VuV_add_loop = 17, - OP_FUNC_CGRATensor_ArithOp_V_VV_sub = 18, - OP_FUNC_CGRATensor_ArithOp_V_VS_sub = 19, - OP_FUNC_CGRATensor_ArithOp_V_VuV_sub = 20, - OP_FUNC_CGRATensor_ArithOp_V_VuV_sub_loop = 21, - OP_FUNC_CGRATensor_ArithOp_V_VV_mul = 22, - OP_FUNC_CGRATensor_ArithOp_V_VS_mul = 23, - OP_FUNC_CGRATensor_ArithOp_V_VuV_mul = 24, - OP_FUNC_CGRATensor_ArithOp_V_VuV_mul_loop = 25, - OP_FUNC_CGRATensor_ArithOp_V_VV_div = 26, - OP_FUNC_CGRATensor_ArithOp_V_VS_div = 27, - OP_FUNC_CGRATensor_ArithOp_V_VuV_div = 28, - OP_FUNC_CGRATensor_ArithOp_V_VuV_div_loop = 29, - - // Relational Operators - OP_FUNC_CGRATensor_RelaOp_V_VV_eq = 30, - OP_FUNC_CGRATensor_RelaOp_bV_VV_eq = 31, - OP_FUNC_CGRATensor_RelaOp_V_VS_eq = 32, - OP_FUNC_CGRATensor_RelaOp_bV_VS_eq = 33, - OP_FUNC_CGRATensor_RelaOp_V_VuV_eq = 34, - OP_FUNC_CGRATensor_RelaOp_V_VuV_eq_loop = 35, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq = 36, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq_loop = 37, - - OP_FUNC_CGRATensor_RelaOp_V_VV_ne = 38, - OP_FUNC_CGRATensor_RelaOp_bV_VV_ne = 39, - OP_FUNC_CGRATensor_RelaOp_V_VS_ne = 40, - OP_FUNC_CGRATensor_RelaOp_bV_VS_ne = 41, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ne = 42, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ne_loop = 43, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne = 44, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne_loop = 45, - - OP_FUNC_CGRATensor_RelaOp_V_VV_ge = 46, - OP_FUNC_CGRATensor_RelaOp_bV_VV_ge = 47, - OP_FUNC_CGRATensor_RelaOp_V_VS_ge = 48, - OP_FUNC_CGRATensor_RelaOp_bV_VS_ge = 49, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ge = 50, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ge_loop = 51, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge = 52, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge_loop = 53, - - OP_FUNC_CGRATensor_RelaOp_V_VV_gt = 54, - OP_FUNC_CGRATensor_RelaOp_bV_VV_gt = 55, - OP_FUNC_CGRATensor_RelaOp_V_VS_gt = 56, - OP_FUNC_CGRATensor_RelaOp_bV_VS_gt = 57, - OP_FUNC_CGRATensor_RelaOp_V_VuV_gt = 58, - OP_FUNC_CGRATensor_RelaOp_V_VuV_gt_loop = 59, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt = 60, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt_loop = 61, - - OP_FUNC_CGRATensor_RelaOp_V_VV_le = 62, - OP_FUNC_CGRATensor_RelaOp_bV_VV_le = 63, - OP_FUNC_CGRATensor_RelaOp_V_VS_le = 64, - OP_FUNC_CGRATensor_RelaOp_bV_VS_le = 65, - OP_FUNC_CGRATensor_RelaOp_V_VuV_le = 66, - OP_FUNC_CGRATensor_RelaOp_V_VuV_le_loop = 67, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_le = 68, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_le_loop = 69, - - OP_FUNC_CGRATensor_RelaOp_V_VV_lt = 70, - OP_FUNC_CGRATensor_RelaOp_bV_VV_lt = 71, - OP_FUNC_CGRATensor_RelaOp_V_VS_lt = 72, - OP_FUNC_CGRATensor_RelaOp_bV_VS_lt = 73, - OP_FUNC_CGRATensor_RelaOp_V_VuV_lt = 74, - OP_FUNC_CGRATensor_RelaOp_V_VuV_lt_loop = 75, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt = 76, - 
OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt_loop = 77, - - OP_FUNC_CGRATensor_LogicOp_V_V_not = 78, - OP_FUNC_CGRATensor_LogicOp_V_VV_and = 79, - OP_FUNC_CGRATensor_LogicOp_V_VV_or = 80, - OP_FUNC_CGRATensor_LogicOp_V_VV_xor = 81, - OP_FUNC_CGRATensor_LogicOp_V_VuV_and = 82, - OP_FUNC_CGRATensor_LogicOp_V_VuV_or = 83, - OP_FUNC_CGRATensor_LogicOp_V_VuV_xor = 84, - OP_FUNC_CGRATensor_LogicOp_V_VuV_and_loop = 85, - OP_FUNC_CGRATensor_LogicOp_V_VuV_or_loop = 86, - OP_FUNC_CGRATensor_LogicOp_V_VuV_xor_loop = 87, - - OP_FUNC_CGRATensor_LogicOp_bV_bV_not = 88, - OP_FUNC_CGRATensor_LogicOp_bV_bVbV_and = 89, - OP_FUNC_CGRATensor_LogicOp_bV_bVbV_or = 90, - OP_FUNC_CGRATensor_LogicOp_bV_bVbV_xor = 91, - OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and = 92, - OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or = 93, - OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor = 94, - OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and_loop = 95, - OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or_loop = 96, - OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor_loop = 97, - - // Transcendental Operator - OP_FUNC_CGRATensor_TransOp_V_V_log2 = 98, - OP_FUNC_CGRATensor_TransOp_V_V_ln = 99, - OP_FUNC_CGRATensor_TransOp_V_V_pow2 = 100, - OP_FUNC_CGRATensor_TransOp_V_V_exp = 101, - OP_FUNC_CGRATensor_TransOp_V_V_exp_lp = 102, - OP_FUNC_CGRATensor_TransOp_V_V_sin = 103, - OP_FUNC_CGRATensor_TransOp_V_V_cos = 104, - - // Activation Operator - OP_FUNC_CGRATensor_ActOp_V_V_tanh = 105, - OP_FUNC_CGRATensor_ActOp_V_V_sigmoid = 106, - OP_FUNC_CGRATensor_ActOp_V_V_relu = 107, - OP_FUNC_CGRATensor_ActOp_V_V_satrelu = 108, - OP_FUNC_CGRATensor_ActOp_V_V_leakyrelu = 109, - OP_FUNC_CGRATensor_ActOp_V_V_softplus = 110, - - // Reduce Operator - OP_FUNC_CGRATensor_ReduceOp_T_T_sum = 111, - OP_FUNC_CGRATensor_ReduceOp_T_T_avg = 112, - OP_FUNC_CGRATensor_ReduceOp_T_T_max = 113, - OP_FUNC_CGRATensor_ReduceOp_T_T_min = 114, - - // Pool Operator - OP_FUNC_CGRATensor_PoolOp_T_T_avg = 115, - OP_FUNC_CGRATensor_PoolOp_T_T_sum = 116, - OP_FUNC_CGRATensor_PoolOp_T_T_max = 117, - OP_FUNC_CGRATensor_PoolOp_T_T_indexedmax = 118, - OP_FUNC_CGRATensor_PoolOp_T_T_min = 119, - OP_FUNC_CGRATensor_PoolOp_T_T_indexedmin = 120, - - // DataMove - OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool = 121, - OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool_avg = 122, - OP_FUNC_CGRATensor_DataMoveOp_T_T_maskunpool = 123, - // reshape - OP_FUNC_CGRATensor_DataMoveOp_T_T_mirror = 124, - OP_FUNC_CGRATensor_DataMoveOp_T_T_transpose = 125, - OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate90 = 126, - OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate180 = 127, - OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate270 = 128, - OP_FUNC_CGRATensor_DataMoveOp_T_T_nchw2nhwc = 129, - OP_FUNC_CGRATensor_DataMoveOp_T_T_nhwc2nchw = 130, - OP_FUNC_CGRATensor_DataMoveOp_T_T_concat = 131, - OP_FUNC_CGRATensor_DataMoveOp_T_T_pad = 132, - OP_FUNC_CGRATensor_DataMoveOp_T_T_channelnorm = 133, - // datamove - OP_FUNC_CGRATensor_DataMoveOp_V_V_maskmove = 134, - OP_FUNC_CGRATensor_DataMoveOp_T_T_gatherscatter = 135, - OP_FUNC_CGRATensor_DataMoveOp_V_V_maskgather = 136, - OP_FUNC_CGRATensor_DataMoveOp_V_bV_maskgather = 137, - OP_FUNC_CGRATensor_DataMoveOp_T_T_img2col = 138, - - // Conver Operator - OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp16 = 139, - OP_FUNC_CGRATensor_ConvertOp_V_V_int8_bf16 = 140, - OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp32 = 141, - OP_FUNC_CGRATensor_ConvertOp_V_V_int8_tf32 = 142, - OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp16 = 143, - OP_FUNC_CGRATensor_ConvertOp_V_V_int16_bf16 = 144, - OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp32 = 145, - 
OP_FUNC_CGRATensor_ConvertOp_V_V_int16_tf32 = 146, - OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp16 = 147, - OP_FUNC_CGRATensor_ConvertOp_V_V_int32_bf16 = 148, - OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp32 = 149, - OP_FUNC_CGRATensor_ConvertOp_V_V_int32_tf32 = 150, - OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int8 = 151, - OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int16 = 152, - OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int32 = 153, - OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp16 = 154, - OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp32 = 155, - OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_tf32 = 156, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int8 = 157, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int16 = 158, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int32 = 159, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_bf16 = 160, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_fp32 = 161, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_tf32 = 162, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int8 = 163, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int16 = 164, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int32 = 165, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_fp16 = 166, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_bf16 = 167, - OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_tf32 = 168, - OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int8 = 169, - OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int16 = 170, - OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int32 = 171, - OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp16 = 172, - OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_bf16 = 173, - OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp32 = 174, - - // Peripheral Operator - OP_FUNC_CGRATensor_PeriOp_S_V_count = 175, - OP_FUNC_CGRATensor_PeriOp_S_bV_bitcount = 176, - OP_FUNC_CGRATensor_PeriOp_V_V_argmax = 177, - OP_FUNC_CGRATensor_PeriOp_V_V_argmin = 178, - OP_FUNC_CGRATensor_PeriOp_T_memset = 179, - OP_FUNC_CGRATensor_PeriOp_V_V_fp32_factorize = 180, - OP_FUNC_CGRATensor_PeriOp_V_V_bit2fp = 181, - OP_FUNC_CGRATensor_PeriOp_T_T_bilinear = 182, - OP_FUNC_CGRATensor_PeriOp_V_V_lut16 = 183, - OP_FUNC_CGRATensor_PeriOp_V_V_lut32 = 184, - OP_FUNC_CGRATensor_PeriOp_V_rand_gen = 185, - OP_FUNC_CGRATensor_PeriOp_V_V_elem_mask = 186, + // Arithmetic Operators + OP_FUNC_CGRATensor_ArithOp_V_V_abs = 0, + OP_FUNC_CGRATensor_ArithOp_V_V_recip = 1, + OP_FUNC_CGRATensor_ArithOp_V_V_square = 2, + OP_FUNC_CGRATensor_ArithOp_V_V_sqrt = 3, + OP_FUNC_CGRATensor_ArithOp_V_V_rsqrt = 4, + OP_FUNC_CGRATensor_ArithOp_V_V_neg = 5, + OP_FUNC_CGRATensor_ArithOp_V_VV_max = 6, + OP_FUNC_CGRATensor_ArithOp_V_VS_max = 7, + OP_FUNC_CGRATensor_ArithOp_V_VuV_max = 8, + OP_FUNC_CGRATensor_ArithOp_V_VuV_max_loop = 9, + OP_FUNC_CGRATensor_ArithOp_V_VV_min = 10, + OP_FUNC_CGRATensor_ArithOp_V_VS_min = 11, + OP_FUNC_CGRATensor_ArithOp_V_VuV_min = 12, + OP_FUNC_CGRATensor_ArithOp_V_VuV_min_loop = 13, + OP_FUNC_CGRATensor_ArithOp_V_VV_add = 14, + OP_FUNC_CGRATensor_ArithOp_V_VS_add = 15, + OP_FUNC_CGRATensor_ArithOp_V_VuV_add = 16, + OP_FUNC_CGRATensor_ArithOp_V_VuV_add_loop = 17, + OP_FUNC_CGRATensor_ArithOp_V_VV_sub = 18, + OP_FUNC_CGRATensor_ArithOp_V_VS_sub = 19, + OP_FUNC_CGRATensor_ArithOp_V_VuV_sub = 20, + OP_FUNC_CGRATensor_ArithOp_V_VuV_sub_loop = 21, + OP_FUNC_CGRATensor_ArithOp_V_VV_mul = 22, + OP_FUNC_CGRATensor_ArithOp_V_VS_mul = 23, + OP_FUNC_CGRATensor_ArithOp_V_VuV_mul = 24, + OP_FUNC_CGRATensor_ArithOp_V_VuV_mul_loop = 25, + OP_FUNC_CGRATensor_ArithOp_V_VV_div = 26, + OP_FUNC_CGRATensor_ArithOp_V_VS_div = 27, + OP_FUNC_CGRATensor_ArithOp_V_VuV_div = 28, + OP_FUNC_CGRATensor_ArithOp_V_VuV_div_loop = 29, + + // Relational Operators + OP_FUNC_CGRATensor_RelaOp_V_VV_eq = 
30, + OP_FUNC_CGRATensor_RelaOp_bV_VV_eq = 31, + OP_FUNC_CGRATensor_RelaOp_V_VS_eq = 32, + OP_FUNC_CGRATensor_RelaOp_bV_VS_eq = 33, + OP_FUNC_CGRATensor_RelaOp_V_VuV_eq = 34, + OP_FUNC_CGRATensor_RelaOp_V_VuV_eq_loop = 35, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq = 36, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq_loop = 37, + + OP_FUNC_CGRATensor_RelaOp_V_VV_ne = 38, + OP_FUNC_CGRATensor_RelaOp_bV_VV_ne = 39, + OP_FUNC_CGRATensor_RelaOp_V_VS_ne = 40, + OP_FUNC_CGRATensor_RelaOp_bV_VS_ne = 41, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ne = 42, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ne_loop = 43, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne = 44, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne_loop = 45, + + OP_FUNC_CGRATensor_RelaOp_V_VV_ge = 46, + OP_FUNC_CGRATensor_RelaOp_bV_VV_ge = 47, + OP_FUNC_CGRATensor_RelaOp_V_VS_ge = 48, + OP_FUNC_CGRATensor_RelaOp_bV_VS_ge = 49, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ge = 50, + OP_FUNC_CGRATensor_RelaOp_V_VuV_ge_loop = 51, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge = 52, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge_loop = 53, + + OP_FUNC_CGRATensor_RelaOp_V_VV_gt = 54, + OP_FUNC_CGRATensor_RelaOp_bV_VV_gt = 55, + OP_FUNC_CGRATensor_RelaOp_V_VS_gt = 56, + OP_FUNC_CGRATensor_RelaOp_bV_VS_gt = 57, + OP_FUNC_CGRATensor_RelaOp_V_VuV_gt = 58, + OP_FUNC_CGRATensor_RelaOp_V_VuV_gt_loop = 59, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt = 60, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt_loop = 61, + + OP_FUNC_CGRATensor_RelaOp_V_VV_le = 62, + OP_FUNC_CGRATensor_RelaOp_bV_VV_le = 63, + OP_FUNC_CGRATensor_RelaOp_V_VS_le = 64, + OP_FUNC_CGRATensor_RelaOp_bV_VS_le = 65, + OP_FUNC_CGRATensor_RelaOp_V_VuV_le = 66, + OP_FUNC_CGRATensor_RelaOp_V_VuV_le_loop = 67, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_le = 68, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_le_loop = 69, + + OP_FUNC_CGRATensor_RelaOp_V_VV_lt = 70, + OP_FUNC_CGRATensor_RelaOp_bV_VV_lt = 71, + OP_FUNC_CGRATensor_RelaOp_V_VS_lt = 72, + OP_FUNC_CGRATensor_RelaOp_bV_VS_lt = 73, + OP_FUNC_CGRATensor_RelaOp_V_VuV_lt = 74, + OP_FUNC_CGRATensor_RelaOp_V_VuV_lt_loop = 75, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt = 76, + OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt_loop = 77, + + OP_FUNC_CGRATensor_LogicOp_V_V_not = 78, + OP_FUNC_CGRATensor_LogicOp_V_VV_and = 79, + OP_FUNC_CGRATensor_LogicOp_V_VV_or = 80, + OP_FUNC_CGRATensor_LogicOp_V_VV_xor = 81, + OP_FUNC_CGRATensor_LogicOp_V_VuV_and = 82, + OP_FUNC_CGRATensor_LogicOp_V_VuV_or = 83, + OP_FUNC_CGRATensor_LogicOp_V_VuV_xor = 84, + OP_FUNC_CGRATensor_LogicOp_V_VuV_and_loop = 85, + OP_FUNC_CGRATensor_LogicOp_V_VuV_or_loop = 86, + OP_FUNC_CGRATensor_LogicOp_V_VuV_xor_loop = 87, + + OP_FUNC_CGRATensor_LogicOp_bV_bV_not = 88, + OP_FUNC_CGRATensor_LogicOp_bV_bVbV_and = 89, + OP_FUNC_CGRATensor_LogicOp_bV_bVbV_or = 90, + OP_FUNC_CGRATensor_LogicOp_bV_bVbV_xor = 91, + OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and = 92, + OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or = 93, + OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor = 94, + OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and_loop = 95, + OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or_loop = 96, + OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor_loop = 97, + + // Transcendental Operator + OP_FUNC_CGRATensor_TransOp_V_V_log2 = 98, + OP_FUNC_CGRATensor_TransOp_V_V_ln = 99, + OP_FUNC_CGRATensor_TransOp_V_V_pow2 = 100, + OP_FUNC_CGRATensor_TransOp_V_V_exp = 101, + OP_FUNC_CGRATensor_TransOp_V_V_exp_lp = 102, + OP_FUNC_CGRATensor_TransOp_V_V_sin = 103, + OP_FUNC_CGRATensor_TransOp_V_V_cos = 104, + + // Activation Operator + OP_FUNC_CGRATensor_ActOp_V_V_tanh = 105, + OP_FUNC_CGRATensor_ActOp_V_V_sigmoid = 106, + OP_FUNC_CGRATensor_ActOp_V_V_relu = 107, + 
OP_FUNC_CGRATensor_ActOp_V_V_satrelu = 108,
+  OP_FUNC_CGRATensor_ActOp_V_V_leakyrelu = 109,
+  OP_FUNC_CGRATensor_ActOp_V_V_softplus = 110,
+
+  // Reduce Operator
+  OP_FUNC_CGRATensor_ReduceOp_T_T_sum = 111,
+  OP_FUNC_CGRATensor_ReduceOp_T_T_avg = 112,
+  OP_FUNC_CGRATensor_ReduceOp_T_T_max = 113,
+  OP_FUNC_CGRATensor_ReduceOp_T_T_min = 114,
+
+  // Pool Operator
+  OP_FUNC_CGRATensor_PoolOp_T_T_avg = 115,
+  OP_FUNC_CGRATensor_PoolOp_T_T_sum = 116,
+  OP_FUNC_CGRATensor_PoolOp_T_T_max = 117,
+  OP_FUNC_CGRATensor_PoolOp_T_T_indexedmax = 118,
+  OP_FUNC_CGRATensor_PoolOp_T_T_min = 119,
+  OP_FUNC_CGRATensor_PoolOp_T_T_indexedmin = 120,
+
+  // DataMove
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool = 121,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool_avg = 122,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_maskunpool = 123,
+  // reshape
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_mirror = 124,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_transpose = 125,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate90 = 126,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate180 = 127,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate270 = 128,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_nchw2nhwc = 129,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_nhwc2nchw = 130,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_concat = 131,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_pad = 132,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_channelnorm = 133,
+  // datamove
+  OP_FUNC_CGRATensor_DataMoveOp_V_V_maskmove = 134,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_gatherscatter = 135,
+  OP_FUNC_CGRATensor_DataMoveOp_V_V_maskgather = 136,
+  OP_FUNC_CGRATensor_DataMoveOp_V_bV_maskgather = 137,
+  OP_FUNC_CGRATensor_DataMoveOp_T_T_img2col = 138,
+
+  // Convert Operator
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp16 = 139,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_bf16 = 140,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp32 = 141,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_tf32 = 142,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp16 = 143,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_bf16 = 144,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp32 = 145,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_tf32 = 146,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp16 = 147,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_bf16 = 148,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp32 = 149,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_tf32 = 150,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int8 = 151,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int16 = 152,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int32 = 153,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp16 = 154,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp32 = 155,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_tf32 = 156,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int8 = 157,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int16 = 158,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int32 = 159,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_bf16 = 160,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_fp32 = 161,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_tf32 = 162,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int8 = 163,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int16 = 164,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int32 = 165,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_fp16 = 166,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_bf16 = 167,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_tf32 = 168,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int8 = 169,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int16 = 170,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int32 = 171,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp16 = 172,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_bf16 = 173,
+  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp32 = 174,
+
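// The suffixes in these names appear to encode the operand signature (our
// reading, not stated by the header: V vector, S scalar, bV bit vector,
// T tensor; e.g. V_VS_add produces a vector from a vector and a scalar).
// Under that reading, an opfunc_cgra_info table entry (defined below) can
// be filled straight from the enum, for example:
//   {"CGRATensor_ConvertOp_V_V_fp32_bf16",
//    OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_bf16, CGRA_INSTR_TYPE0}
// (the type field value here is an assumption).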
+  // Peripheral Operator
+  OP_FUNC_CGRATensor_PeriOp_S_V_count = 175,
+  OP_FUNC_CGRATensor_PeriOp_S_bV_bitcount = 176,
+  OP_FUNC_CGRATensor_PeriOp_V_V_argmax = 177,
+  OP_FUNC_CGRATensor_PeriOp_V_V_argmin = 178,
+  OP_FUNC_CGRATensor_PeriOp_T_memset = 179,
+  OP_FUNC_CGRATensor_PeriOp_V_V_fp32_factorize = 180,
+  OP_FUNC_CGRATensor_PeriOp_V_V_bit2fp = 181,
+  OP_FUNC_CGRATensor_PeriOp_T_T_bilinear = 182,
+  OP_FUNC_CGRATensor_PeriOp_V_V_lut16 = 183,
+  OP_FUNC_CGRATensor_PeriOp_V_V_lut32 = 184,
+  OP_FUNC_CGRATensor_PeriOp_V_rand_gen = 185,
+  OP_FUNC_CGRATensor_PeriOp_V_V_elem_mask = 186,
 } OP_FUNC_CGRA;
 
 typedef enum CGRA_INSTR_TYPE {
-    CGRA_INSTR_TYPE0,
-    CGRA_INSTR_TYPE1,
-    CGRA_INSTR_TYPE2,
-    CGRA_INSTR_TYPE3,
+  CGRA_INSTR_TYPE0,
+  CGRA_INSTR_TYPE1,
+  CGRA_INSTR_TYPE2,
+  CGRA_INSTR_TYPE3,
 } CGRA_INSTR_TYPE;
 
 typedef struct Op_fu_head {
-    uint8_t fu;
-    uint8_t opcode;
+  uint8_t fu;
+  uint8_t opcode;
 } Op_fu_head;
 
 typedef struct FU_gemm_head {
-    uint8_t fu;
-    uint8_t gemm;
+  uint8_t fu;
+  uint8_t gemm;
 } FU_gemm_head;
 
 typedef struct opfunc_cgra_info {
-    char name[64];  // CGRATensor_ArithOp_V_V_abs
-    int32_t opcode; // 8'b0000_0000
-    int32_t type;   // CGRA_Tensor_type0
+  char name[64];  // CGRATensor_ArithOp_V_V_abs
+  int32_t opcode; // 8'b0000_0000
+  int32_t type;   // CGRA_Tensor_type0
 } opfunc_cgra_info;
 
 // Neural
 typedef enum Data_Format {
-    Fmt_INT8,
-    Fmt_INT16,
-    Fmt_FP16,
-    Fmt_BF16,
-    Fmt_INT32,
-    Fmt_FP32,
-    Fmt_TF32,
-    Fmt_BOOL, // 1/8 BYTE
-    Fmt_UINT8,
-    Fmt_UINT16,
-    Fmt_UINT32,
-    Fmt_INT64,
-    Fmt_UINT64,
-    Fmt_UNUSED,
+  Fmt_INT8,
+  Fmt_INT16,
+  Fmt_FP16,
+  Fmt_BF16,
+  Fmt_INT32,
+  Fmt_FP32,
+  Fmt_TF32,
+  Fmt_BOOL, // 1/8 BYTE
+  Fmt_UINT8,
+  Fmt_UINT16,
+  Fmt_UINT32,
+  Fmt_INT64,
+  Fmt_UINT64,
+  Fmt_UNUSED,
 } Data_Format;
 
 typedef enum Tensor_Fmt {
-    T_GemmM = 0, /*M K*/
-    T_ConvA = 1, /*H W C*/
-    T_ConvW = 2, /*Kx Ky F C*/
-    T_Vec = 3,
-    T_ConvNA = 4,
-    T_ConvNW = 5,
+  T_GemmM = 0, /*M K*/
+  T_ConvA = 1, /*H W C*/
+  T_ConvW = 2, /*Kx Ky F C*/
+  T_Vec = 3,
+  T_ConvNA = 4,
+  T_ConvNW = 5,
 } Tensor_Fmt;
 
 /*
@@ -901,24 +920,30 @@ typedef enum Tensor_Fmt {
   Reduction along HW: the result is (H=1)(W=1)C, dim=4
 */
 typedef enum Reduce_Dim {
-    Reduce_C = 0,
-    Reduce_W = 1,
-    Reduce_H = 2,
-    Reduce_HW = 4,
+  Reduce_C = 0,
+  Reduce_W = 1,
+  Reduce_H = 2,
+  Reduce_HW = 4,
 } Reduce_Dim;
 
 typedef struct NCC_CSR {
-    uint64_t ib_status; //[7:0]IB_COUNTER: 指令buffer剩余指令数目, [8]TASK_DONE, 1:task执行结束, 0:task 正在执行,
-                        //[63:9]Reserved
-    uint64_t exception; //[7:0]SCALAR_EXCEPTION, [15:8]CT_EXCEPTION, [23:16]NE_EXCEPTION, [31:24]RDMA_EXCEPTION,
-                        //[39:32]WDMA_EXCEPTION, [47:40]TDMA_EXCEPTION, [63:48]Reserved
-    uint64_t priority; //[7:0]PRIORITY,当前worker的优先级, [63:8]Reserved
-    uint64_t exception_mask; //[47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE, [49]EXCEPTION_CLEAR, [63:49]Reserved
-    uint64_t serial_mode; //[0]SERIAL_MODE, 1:串行模式,0:并行模式, [63:1]Reserved
+  uint64_t ib_status; //[7:0]IB_COUNTER: instructions left in the instruction
+                      // buffer, [8]TASK_DONE, 1: task finished, 0: task still
+                      // running, [63:9]Reserved
+  uint64_t
+      exception; //[7:0]SCALAR_EXCEPTION, [15:8]CT_EXCEPTION,
+                 //[23:16]NE_EXCEPTION, [31:24]RDMA_EXCEPTION,
+                 //[39:32]WDMA_EXCEPTION, [47:40]TDMA_EXCEPTION, [63:48]Reserved
+  uint64_t priority; //[7:0]PRIORITY, priority of the current worker, [63:8]Reserved
+  uint64_t exception_mask; //[47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE,
                           //[49]EXCEPTION_CLEAR, [63:49]Reserved
+  uint64_t
+      serial_mode; //[0]SERIAL_MODE, 1: serial mode, 0: parallel mode, [63:1]Reserved
 } NCC_CSR;
 
 typedef struct EXCEP_SERI {
-    uint64_t exception_mask; //[47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE,
[63:49]Reserved - uint64_t serial_mode; //[0]SERIAL_MODE, 1:串行模式,0:并行模式, [63:1]Reserved + uint64_t exception_mask; //[47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE, + //[49]EXCEPTION_CLEAR, [63:49]Reserved + uint64_t + serial_mode; //[0]SERIAL_MODE, 1:串行模式,0:并行模式, [63:1]Reserved } EXCEP_SERI; #endif diff --git a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h index c1bf4bb2f..9aed4ed61 100644 --- a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h +++ b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h @@ -15,460 +15,474 @@ * */ - #ifndef __HOST_RUNTIME_COM_H__ - #define __HOST_RUNTIME_COM_H__ - - #include - #include - #include - #include - #include - #include - #include - #include - - #ifndef MAX_SHAPE_DIM - #define MAX_SHAPE_DIM 6 - #endif - - #ifndef MAX_MODEL_NUM - #define MAX_MODEL_NUM 32 - #endif - - typedef uint64_t TsmDevicePtr; - typedef uint64_t TsmHostPtr; - - #define CHIP_MAX_NUM 32 - #define TILE_MAX_NUM 16 - #define CACHE_ALIGN_4k 4096 - - typedef void *(*THREAD_PROC_FUNC)(void *); - - enum TSM_RETCODE { - RET_SUCCESS, - RET_ERROR, - RET_PARAM1_ERROR, - RET_PARAM2_ERROR, - RET_PARAM3_ERROR, - RET_DEVICE_OFFLINE, - RET_DEVICE_NOMEM, - RET_DEVICE_IN_IDLE, - RET_DEVICE_IN_ATTACH, - RET_DEVICE_ATTACH_SUCCESS, - RET_DEVICE_ATTACH_READY, - RET_DEVICE_LOSE_CONNECT, - RET_ENV_CLEAN_UP, - }; - - typedef enum HostLogLevel { LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL, LOG_MAX } HostLogLevel; - - typedef enum TsmModuleType { - TSM_RUNTIME, - TSM_XLA, // 前端 - TSM_TXNN, // 推理引擎 - TSM_ENGTEST, // 板端测试套件 - TSM_HOSTSIM, // 模拟器测试套件 - TSM_CMODEL, // 模拟器API - TSM_RT_TEST, // runtime组件测试套件 - } TsmModuleType; - - typedef enum TsmProfAction { TSM_PROF_START, TSM_PROF_STOP } TsmProfAction; - constexpr uint16_t PROF_TYPE_NCC = 0x1; - constexpr uint16_t PROF_TYPE_SPM = 0x2; - constexpr uint16_t PROF_TYPE_DTE = 0x4; - - typedef enum DTYPE { - FMT_INT8, - FMT_INT16, - FMT_FP16, - FMT_BF16, - FMT_INT32, - FMT_FP32, - FMT_TF32, - FMT_BOOL, // 1/8 BYTE - FMT_UINT8, - FMT_UINT16, - FMT_UINT32, - FMT_INT64, - FMT_UINT64, - FMT_UNUSED, - } DTYPE; - - uint8_t hrt_get_dtype_size(DTYPE dtype); - - enum DynDataType { - PKT_FINAL_TYPE = 0, - CFG_PMU_TYPE, - KCORE_CFG_TYPE, - EXPORT_SPM_TYPE, - DISABLE_CALC_TYPE, - PROF_CFG_TYPE, - DYNLIB_LOAD, - DYNLIB_RUN, - DYNLIB_UNLOAD, - MEMCPY_D2D, - P2P_SEND, - P2P_RECV, - DATA_TYPE_MAX, - }; - - typedef struct DynTLV_Terminate { - uint32_t type; // DynDataType - uint32_t len; - uint64_t is_final; - } DynTLV_Terminate; - - typedef struct DynTLV { - uint32_t type; // DynDataType - uint32_t len; - } DynTLV; - - typedef struct Cfg_Pmu_Info { - uint32_t tile_bitmap[16]; - uint32_t mac_use_rate; - uint32_t chip_id; - uint32_t cycles; - uint64_t in_ddr; - uint64_t param_ddr; - uint64_t out_ddr; - uint32_t reserved; - } Cfg_Pmu_Info; - - typedef struct DynTLV_Cfgpmu { - uint32_t type; // DynDataType - uint32_t len; - Cfg_Pmu_Info cfg_pmu; - } DynTLV_Cfgpmu; - - typedef struct DynTLV_KcoreCfg { - uint32_t type; - uint32_t len; - uint64_t snap_addr[TILE_MAX_NUM]; - uint64_t console_addr[TILE_MAX_NUM]; - uint64_t spm_dump_addr[TILE_MAX_NUM]; - uint64_t spm_dump_size; - uint32_t log_level; - uint32_t enable_monitor; - } DynTLV_KcoreCfg; - - typedef struct DynTLV_KcoreCalc { - uint32_t type; - uint32_t len; - uint32_t disable_kcore_calc; - } DynTLV_KcoreCalc; - - typedef struct DynTLV_ProfCfg { - uint32_t type; - uint32_t len; - uint64_t addrs[TILE_MAX_NUM]; - uint32_t size; 
- uint16_t enable; - uint16_t prof_type; - } DynTLV_ProfCfg; - - // #define TILE_NUM 16 - typedef struct DynModule { - char module_name[128]; - char module_symbol[128]; // typedef void (*entry_func_t)(voicd *): - uint32_t module_size[TILE_MAX_NUM]; - uint64_t module_addr[TILE_MAX_NUM]; // dev地址 - } DynModule; - - typedef struct DynMods { - uint16_t module_num; - struct DynModule modules[0]; - } DynMods; // host共用结构,传过来这个首地址 - - typedef struct DynTLV_DynMods { - uint32_t type; // DynDataType - uint32_t len; - uint64_t ext_addr; - uint64_t dyn_mods_addr; //指向DynMods - } DynTLV_DynMods; - - typedef struct TileDteCfg { - uint16_t status; // 该tile是否参与搬运工作 - uint16_t remote_tile_id; // 对端tile_id - uint32_t element_count; // 单次搬运cache_line大小,默认4k - uint32_t stride; // 步长 - uint32_t left_element_count; // 搬完cache_line后,剩余的搬运的长度 - uint64_t iteration; // 搬运cache_line的次数 - uint64_t src_addr; // 搬运cache_line的源地址 - 物理 - uint64_t dst_addr; // 搬运cache_line的目的地址 - 物理 - uint64_t left_src_addr; // 搬运余数的源地址 - 物理 - uint64_t left_dst_addr; // 搬运余数的目的地址 - 物理 - } TileDteCfg; - typedef struct DynTLV_DteCfg { - uint32_t type; - uint32_t len; - TileDteCfg tile_dte_cfg[TILE_MAX_NUM]; - uint64_t barrier_addr; - uint32_t row_card_num; - uint32_t reserved; - } DynTLV_DteCfg; - - enum Tensor_Type { - INPUT_DATA, - OUTPUT_DATA, - PARAM_DATA, - CHACHE_DATA, - DEV_DDR_DATA, - }; - - typedef struct tensor_info { - int32_t inplace; - uint32_t dim; - uint32_t dtype; - uint32_t layout; - uint32_t shape[MAX_SHAPE_DIM]; - } tensor_info_t; - - typedef struct Json_common_info_t { - uint32_t input_num; - uint32_t output_num; - uint32_t param_num; - uint32_t tile_num; - - std::string case_name; - std::string card_name; - - std::vector> input; - std::vector> output; - - std::vector input_file; - std::vector output_file; - std::vector param_file; - - std::vector input_size; - std::vector output_size; - std::vector param_size; - uint64_t imm_size; - - } Json_common_info_t; - - typedef struct chip_common_info { - uint32_t input_num; - uint32_t output_num; - uint32_t param_num; - uint32_t tile_num; - uint32_t tile_x; - uint32_t tile_y; - std::vector> input; - std::vector> output; - - // char card_name[100]; - std::string card_name; - std::vector input_file; - std::vector output_file; - std::vector output_ref_file; - std::vector param_file; - - std::vector input_size; - std::vector output_size; - std::vector param_size; - - std::vector input_host_addr; - std::vector input_dev_addr; - std::vector output_host_addr; - std::vector output_dev_addr; - std::vector param_host_addr; - std::vector param_dev_addr; - - uint64_t imm_size; - } chip_common_info_t; - - typedef struct json_common_info_multi_card { - uint32_t chip_num; - uint32_t chip_x; - uint32_t chip_y; - std::string case_name; - uint32_t loop_num; - std::vector> chip_infos; - } json_common_info_multi_card_t; - - typedef struct CompileOption { - bool comp_enable = false; - std::string rtt_tool_path; - std::string compile_path; - bool check_enable = false; - uint32_t chip_x; - uint32_t chip_y; - bool enable_kcore_bin; - bool enable_kcore_so; - } CompileOption; - - // Boot Param Table - typedef struct BootParamHead { - uint32_t MaxLen; // BootParamHead + n * BootParamDyninfo, n = inputnum + outputnum + paramnum - uint32_t LdmemLen; - uint32_t InputNum; - uint32_t OutputNum; - uint32_t ParamNum; - uint32_t reserved; - uint64_t CacheMemLen; - uint64_t CacheMemAddr; // device - uint32_t Datalen; - uint32_t reserved1; - uint64_t DataAddr; // device - } BootParamHead; - - typedef struct 
BootParamDyninfo { - uint64_t addr; // device - uint64_t size; - uint32_t dtype; - uint32_t dim; - uint32_t shape[6]; //#define MAX_SHAPE_DIM 6 //n, h, w, c, x, x - } BootParamDyninfo; - - class HrtBootParam { - public: - HrtBootParam(uint32_t i_num, uint32_t o_num, uint32_t p_num) : i_num(i_num), o_num(o_num), p_num(p_num) { - uint32_t bufsize = (sizeof(BootParamHead) + (i_num + o_num + 1) * sizeof(BootParamDyninfo)); - buffer = (void *)malloc(bufsize); - memset(buffer, 0, bufsize); - BootParamHead *head = static_cast(buffer); - head->MaxLen = bufsize; - head->LdmemLen = 0x200000; - head->InputNum = i_num; - head->OutputNum = o_num; - head->ParamNum = p_num; - } - ~HrtBootParam() { - if (buffer != nullptr) { - free(buffer); - } - } - std::vector dyninfo; - uint32_t get_maxlen(); - void *get_bootpmbuffer(); - BootParamHead *get_headptr(); - BootParamDyninfo *get_inputptr(uint32_t index); - BootParamDyninfo *get_outputptr(uint32_t index); - BootParamDyninfo *get_paramptr(uint32_t index); - void set_dev_cache(uint64_t dev_addr, uint64_t size); - void set_dev_cache_mem_addr(uint64_t dev_addr, uint64_t size); - void set_dev_dyndata(uint64_t dev_addr, uint32_t size); - void set_dev_dyndata_mem_addr(uint64_t dev_addr, uint32_t size); - void set_dev_input(uint32_t idx, uint64_t dev_addr, uint64_t size); - void set_dev_input_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size); - void set_dev_input_tensor(uint32_t idx, std::shared_ptr tensor); - void set_dev_output(uint32_t idx, uint64_t dev_addr, uint64_t size); - void set_dev_output_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size); - void set_dev_param(uint32_t idx, uint64_t dev_addr, uint64_t size); - void set_dev_param_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size); - std::shared_ptr get_dev_output_tensor_after_run(uint32_t idx); - - private: - uint32_t i_num; - uint32_t o_num; - uint32_t p_num; - void *buffer; - }; - /* 启动参数end */ - - /* compiler后生成的存储elf和param地址的对象 */ - class HostParamElem { - public: - HostParamElem() : dataPtr(nullptr), size(0) {} - ~HostParamElem(); - //模拟器:从文件中加载一个bin - HostParamElem(const std::string &filepath); - - uint8_t *loadBinaryFile(const std::string filepath, uint64_t &fsize); - uint8_t *dataPtr; // host - uint64_t size; // byte - }; - - class ChipModelInfo { - public: - ChipModelInfo(); - ChipModelInfo(uint32_t id); - ~ChipModelInfo(); - - uint32_t getChipId() { return chip_id; } - // support multi chip - std::vector> elfs; //编译出的elf文件 - std::vector> bins; //编译出的bin文件 - std::vector> params; - - private: - uint32_t chip_id; - }; - - /* - * compiler后生成的模型对象,在launch的时候会将elf/bin的指针传入soc的接口, - * (PCIE搬运时,如果空间不连续会触发多次搬运,因此交由SOC组装连续空间。) - */ - class TsmModel { - public: - TsmModel(); // org_model - ~TsmModel(); - TsmModel(const std::string &filepath); - - std::vector> chip_infos; - THREAD_PROC_FUNC proc_func; - std::string case_name; - std::string case_dir; - std::shared_ptr so_list[MAX_MODEL_NUM][TILE_MAX_NUM]; // 编译出的so文件 - std::string module_name; - struct txmodel *model[MAX_MODEL_NUM]; - }; - - typedef struct TsmDevice { - char res_path[128]; - uint32_t chip_id; - uint32_t tile_num = 16; - void *soc_device; - } TsmDevice_t; - - class TsmTensorData { - public: - TsmTensorData() : host_addr(0), device_addr(0), length(0) {} - ~TsmTensorData(){}; - - TsmHostPtr host_addr; - TsmDevicePtr device_addr; - uint64_t length; - uint32_t data_type; - Tensor_Type tensor_type; - }; - - typedef void *tsmStream_t; - typedef void *tsmEvent_t; - typedef struct txcclComm* txcclComm_t; - typedef enum { - 
txcclDataDefault = 0
- } txcclDataType_t; // 预留,待讨论
-
- enum device_status {
-     FULLGOOD = 0,
-     PARTIALGOOD = 1,
- };
-
- constexpr uint32_t PARTIALGOOD_NUM = 8;
- constexpr uint32_t FULLGOOD_NUM = 16;
-
- struct CardComputeInfo {
-     uint32_t card_id;
-     enum device_status device_status;
-     uint32_t all_tile_num;
-     double all_tile_compute;
-     uint32_t left_tile_num;
-     double left_tile_compute;
- };
-
- struct TsmDeviceInfo {
-     uint32_t card_num;
-     uint32_t card_x;
-     uint32_t card_y;
-     CardComputeInfo card_compute_info[CHIP_MAX_NUM];
- };
-
- int32_t readDataFromFile(uint8_t *buffer, std::string file, uint32_t size);
- uint8_t *read_file_data(std::string file, uint64_t &size);
-
- std::shared_ptr get_multi_card_common_info_from_file(std::string file);
- std::string get_docker_verison();
- TSM_RETCODE set_multi_graph(TsmModel *&kmodel, std::shared_ptr &hostboot,
-                             const TsmDevicePtr &dev_dyn_mods_ptr, const TsmDevicePtr &dev_tlv_ptr, TsmDevicePtr ext_ptr);
- #endif /* __HOST_RUNTIME_COM_H__ */
+#ifndef __HOST_RUNTIME_COM_H__
+#define __HOST_RUNTIME_COM_H__
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifndef MAX_SHAPE_DIM
+#define MAX_SHAPE_DIM 6
+#endif
+
+#ifndef MAX_MODEL_NUM
+#define MAX_MODEL_NUM 32
+#endif
+
+typedef uint64_t TsmDevicePtr;
+typedef uint64_t TsmHostPtr;
+
+#define CHIP_MAX_NUM 32
+#define TILE_MAX_NUM 16
+#define CACHE_ALIGN_4k 4096
+
+typedef void *(*THREAD_PROC_FUNC)(void *);
+
+enum TSM_RETCODE {
+  RET_SUCCESS,
+  RET_ERROR,
+  RET_PARAM1_ERROR,
+  RET_PARAM2_ERROR,
+  RET_PARAM3_ERROR,
+  RET_DEVICE_OFFLINE,
+  RET_DEVICE_NOMEM,
+  RET_DEVICE_IN_IDLE,
+  RET_DEVICE_IN_ATTACH,
+  RET_DEVICE_ATTACH_SUCCESS,
+  RET_DEVICE_ATTACH_READY,
+  RET_DEVICE_LOSE_CONNECT,
+  RET_ENV_CLEAN_UP,
+};
+
+typedef enum HostLogLevel {
+  LOG_DEBUG,
+  LOG_INFO,
+  LOG_WARNING,
+  LOG_ERROR,
+  LOG_FATAL,
+  LOG_MAX
+} HostLogLevel;
+
+typedef enum TsmModuleType {
+  TSM_RUNTIME,
+  TSM_XLA,     // frontend
+  TSM_TXNN,    // inference engine
+  TSM_ENGTEST, // on-board test suite
+  TSM_HOSTSIM, // simulator test suite
+  TSM_CMODEL,  // simulator API
+  TSM_RT_TEST, // runtime component test suite
+} TsmModuleType;
+
+typedef enum TsmProfAction { TSM_PROF_START, TSM_PROF_STOP } TsmProfAction;
+constexpr uint16_t PROF_TYPE_NCC = 0x1;
+constexpr uint16_t PROF_TYPE_SPM = 0x2;
+constexpr uint16_t PROF_TYPE_DTE = 0x4;
+
+typedef enum DTYPE {
+  FMT_INT8,
+  FMT_INT16,
+  FMT_FP16,
+  FMT_BF16,
+  FMT_INT32,
+  FMT_FP32,
+  FMT_TF32,
+  FMT_BOOL, // 1/8 BYTE
+  FMT_UINT8,
+  FMT_UINT16,
+  FMT_UINT32,
+  FMT_INT64,
+  FMT_UINT64,
+  FMT_UNUSED,
+} DTYPE;
+
+uint8_t hrt_get_dtype_size(DTYPE dtype);
+
+enum DynDataType {
+  PKT_FINAL_TYPE = 0,
+  CFG_PMU_TYPE,
+  KCORE_CFG_TYPE,
+  EXPORT_SPM_TYPE,
+  DISABLE_CALC_TYPE,
+  PROF_CFG_TYPE,
+  DYNLIB_LOAD,
+  DYNLIB_RUN,
+  DYNLIB_UNLOAD,
+  MEMCPY_D2D,
+  P2P_SEND,
+  P2P_RECV,
+  DATA_TYPE_MAX,
+};
+
+typedef struct DynTLV_Terminate {
+  uint32_t type; // DynDataType
+  uint32_t len;
+  uint64_t is_final;
+} DynTLV_Terminate;
+
+typedef struct DynTLV {
+  uint32_t type; // DynDataType
+  uint32_t len;
+} DynTLV;
+
+typedef struct Cfg_Pmu_Info {
+  uint32_t tile_bitmap[16];
+  uint32_t mac_use_rate;
+  uint32_t chip_id;
+  uint32_t cycles;
+  uint64_t in_ddr;
+  uint64_t param_ddr;
+  uint64_t out_ddr;
+  uint32_t reserved;
+} Cfg_Pmu_Info;
+
+typedef struct DynTLV_Cfgpmu {
+  uint32_t type; // DynDataType
+  uint32_t len;
+  Cfg_Pmu_Info cfg_pmu;
+} DynTLV_Cfgpmu;
+
+typedef struct DynTLV_KcoreCfg {
+  uint32_t type;
+  uint32_t len;
+  uint64_t snap_addr[TILE_MAX_NUM];
+  uint64_t console_addr[TILE_MAX_NUM];
+  uint64_t spm_dump_addr[TILE_MAX_NUM];
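// Sketch of walking a dynamic-data buffer as a sequence of DynTLV records
// until the PKT_FINAL_TYPE terminator. It assumes len counts only the
// payload bytes after the 8-byte header, which this header alone does not
// confirm.
static inline const DynTLV *dyn_tlv_next(const DynTLV *t) {
  return (const DynTLV *)((const uint8_t *)t + sizeof(DynTLV) + t->len);
}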
+  uint64_t spm_dump_size;
+  uint32_t log_level;
+  uint32_t enable_monitor;
+} DynTLV_KcoreCfg;
+
+typedef struct DynTLV_KcoreCalc {
+  uint32_t type;
+  uint32_t len;
+  uint32_t disable_kcore_calc;
+} DynTLV_KcoreCalc;
+
+typedef struct DynTLV_ProfCfg {
+  uint32_t type;
+  uint32_t len;
+  uint64_t addrs[TILE_MAX_NUM];
+  uint32_t size;
+  uint16_t enable;
+  uint16_t prof_type;
+} DynTLV_ProfCfg;
+
+// #define TILE_NUM 16
+typedef struct DynModule {
+  char module_name[128];
+  char module_symbol[128]; // typedef void (*entry_func_t)(void *);
+  uint32_t module_size[TILE_MAX_NUM];
+  uint64_t module_addr[TILE_MAX_NUM]; // device address
+} DynModule;
+
+typedef struct DynMods {
+  uint16_t module_num;
+  struct DynModule modules[0];
+} DynMods; // shared with the host; its base address is what gets passed over
+
+typedef struct DynTLV_DynMods {
+  uint32_t type; // DynDataType
+  uint32_t len;
+  uint64_t ext_addr;
+  uint64_t dyn_mods_addr; // points to DynMods
+} DynTLV_DynMods;
+
+typedef struct TileDteCfg {
+  uint16_t status;             // whether this tile takes part in the transfer
+  uint16_t remote_tile_id;     // peer tile_id
+  uint32_t element_count;      // cache-line size per transfer, 4k by default
+  uint32_t stride;             // stride
+  uint32_t left_element_count; // remaining length after whole cache lines are moved
+  uint64_t iteration;          // number of cache-line transfers
+  uint64_t src_addr;           // source address of the cache-line transfer (physical)
+  uint64_t dst_addr;           // destination address of the cache-line transfer (physical)
+  uint64_t left_src_addr;      // source address of the remainder (physical)
+  uint64_t left_dst_addr;      // destination address of the remainder (physical)
+} TileDteCfg;
+typedef struct DynTLV_DteCfg {
+  uint32_t type;
+  uint32_t len;
+  TileDteCfg tile_dte_cfg[TILE_MAX_NUM];
+  uint64_t barrier_addr;
+  uint32_t row_card_num;
+  uint32_t reserved;
+} DynTLV_DteCfg;
+
+enum Tensor_Type {
+  INPUT_DATA,
+  OUTPUT_DATA,
+  PARAM_DATA,
+  CHACHE_DATA,
+  DEV_DDR_DATA,
+};
+
+typedef struct tensor_info {
+  int32_t inplace;
+  uint32_t dim;
+  uint32_t dtype;
+  uint32_t layout;
+  uint32_t shape[MAX_SHAPE_DIM];
+} tensor_info_t;
+
+typedef struct Json_common_info_t {
+  uint32_t input_num;
+  uint32_t output_num;
+  uint32_t param_num;
+  uint32_t tile_num;
+
+  std::string case_name;
+  std::string card_name;
+
+  std::vector> input;
+  std::vector> output;
+
+  std::vector input_file;
+  std::vector output_file;
+  std::vector param_file;
+
+  std::vector input_size;
+  std::vector output_size;
+  std::vector param_size;
+  uint64_t imm_size;
+
+} Json_common_info_t;
+
+typedef struct chip_common_info {
+  uint32_t input_num;
+  uint32_t output_num;
+  uint32_t param_num;
+  uint32_t tile_num;
+  uint32_t tile_x;
+  uint32_t tile_y;
+  std::vector> input;
+  std::vector> output;
+
+  // char card_name[100];
+  std::string card_name;
+  std::vector input_file;
+  std::vector output_file;
+  std::vector output_ref_file;
+  std::vector param_file;
+
+  std::vector input_size;
+  std::vector output_size;
+  std::vector param_size;
+
+  std::vector input_host_addr;
+  std::vector input_dev_addr;
+  std::vector output_host_addr;
+  std::vector output_dev_addr;
+  std::vector param_host_addr;
+  std::vector param_dev_addr;
+
+  uint64_t imm_size;
+} chip_common_info_t;
+
+typedef struct json_common_info_multi_card {
+  uint32_t chip_num;
+  uint32_t chip_x;
+  uint32_t chip_y;
+  std::string case_name;
+  uint32_t loop_num;
+  std::vector> chip_infos;
+} json_common_info_multi_card_t;
+
+typedef struct CompileOption {
+  bool comp_enable = false;
+  std::string rtt_tool_path;
+  std::string compile_path;
+  bool check_enable = false;
+  uint32_t chip_x;
+  uint32_t chip_y;
+  bool enable_kcore_bin;
+  bool enable_kcore_so;
+} CompileOption;
+
+// Boot Param Table
+typedef struct BootParamHead {
+  uint32_t MaxLen; // BootParamHead + n * BootParamDyninfo, n = inputnum + outputnum + paramnum
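// Sketch: how a transfer could be split into TileDteCfg's whole cache-line
// part plus remainder, using the default 4 KiB line (CACHE_ALIGN_4k). The
// mapping onto iteration/left_element_count is our assumption.
static inline void dte_split(uint64_t total_bytes, uint64_t *iteration,
                             uint32_t *left_element_count) {
  *iteration = total_bytes / CACHE_ALIGN_4k;           // whole cache lines
  *left_element_count = (uint32_t)(total_bytes % CACHE_ALIGN_4k); // tail bytes
}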
+  uint32_t LdmemLen;
+  uint32_t InputNum;
+  uint32_t OutputNum;
+  uint32_t ParamNum;
+  uint32_t reserved;
+  uint64_t CacheMemLen;
+  uint64_t CacheMemAddr; // device
+  uint32_t Datalen;
+  uint32_t reserved1;
+  uint64_t DataAddr; // device
+} BootParamHead;
+
+typedef struct BootParamDyninfo {
+  uint64_t addr; // device
+  uint64_t size;
+  uint32_t dtype;
+  uint32_t dim;
+  uint32_t shape[6]; // #define MAX_SHAPE_DIM 6 // n, h, w, c, x, x
+} BootParamDyninfo;
+
+class HrtBootParam {
+public:
+  HrtBootParam(uint32_t i_num, uint32_t o_num, uint32_t p_num)
+      : i_num(i_num), o_num(o_num), p_num(p_num) {
+    uint32_t bufsize = (sizeof(BootParamHead) +
+                        (i_num + o_num + 1) * sizeof(BootParamDyninfo));
+    buffer = (void *)malloc(bufsize);
+    memset(buffer, 0, bufsize);
+    BootParamHead *head = static_cast<BootParamHead *>(buffer);
+    head->MaxLen = bufsize;
+    head->LdmemLen = 0x200000;
+    head->InputNum = i_num;
+    head->OutputNum = o_num;
+    head->ParamNum = p_num;
+  }
+  ~HrtBootParam() {
+    if (buffer != nullptr) {
+      free(buffer);
+    }
+  }
+  std::vector dyninfo;
+  uint32_t get_maxlen();
+  void *get_bootpmbuffer();
+  BootParamHead *get_headptr();
+  BootParamDyninfo *get_inputptr(uint32_t index);
+  BootParamDyninfo *get_outputptr(uint32_t index);
+  BootParamDyninfo *get_paramptr(uint32_t index);
+  void set_dev_cache(uint64_t dev_addr, uint64_t size);
+  void set_dev_cache_mem_addr(uint64_t dev_addr, uint64_t size);
+  void set_dev_dyndata(uint64_t dev_addr, uint32_t size);
+  void set_dev_dyndata_mem_addr(uint64_t dev_addr, uint32_t size);
+  void set_dev_input(uint32_t idx, uint64_t dev_addr, uint64_t size);
+  void set_dev_input_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
+  void set_dev_input_tensor(uint32_t idx,
+                            std::shared_ptr tensor);
+  void set_dev_output(uint32_t idx, uint64_t dev_addr, uint64_t size);
+  void set_dev_output_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
+  void set_dev_param(uint32_t idx, uint64_t dev_addr, uint64_t size);
+  void set_dev_param_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
+  std::shared_ptr get_dev_output_tensor_after_run(uint32_t idx);
+
+private:
+  uint32_t i_num;
+  uint32_t o_num;
+  uint32_t p_num;
+  void *buffer;
+};
+/* End of boot parameters */
+
+/* Object produced after compilation that stores the elf and param addresses */
+class HostParamElem {
+public:
+  HostParamElem() : dataPtr(nullptr), size(0) {}
+  ~HostParamElem();
+  // Simulator: load a single bin from file
+  HostParamElem(const std::string &filepath);
+
+  uint8_t *loadBinaryFile(const std::string filepath, uint64_t &fsize);
+  uint8_t *dataPtr; // host
+  uint64_t size;    // byte
+};
+
+class ChipModelInfo {
+public:
+  ChipModelInfo();
+  ChipModelInfo(uint32_t id);
+  ~ChipModelInfo();
+
+  uint32_t getChipId() { return chip_id; }
+  // support multi chip
+  std::vector> elfs; // compiled elf files
+  std::vector> bins; // compiled bin files
+  std::vector> params;
+
+private:
+  uint32_t chip_id;
+};
+
+/*
+ * Model object produced after compilation; at launch the elf/bin pointers are
+ * passed to the SoC interface. (A PCIe transfer over non-contiguous memory
+ * triggers multiple transfers, so the SoC assembles a contiguous region.)
+ */
+class TsmModel {
+public:
+  TsmModel(); // org_model
+  ~TsmModel();
+  TsmModel(const std::string &filepath);
+
+  std::vector> chip_infos;
+  THREAD_PROC_FUNC proc_func;
+  std::string case_name;
+  std::string case_dir;
+  std::shared_ptr so_list[MAX_MODEL_NUM]
+                         [TILE_MAX_NUM]; // compiled so files
+  std::string module_name;
+  struct txmodel *model[MAX_MODEL_NUM];
+};
+
+typedef struct TsmDevice {
+  char res_path[128];
+  uint32_t chip_id;
+  uint32_t tile_num = 16;
+  void *soc_device;
+} TsmDevice_t;
+
+class TsmTensorData {
+public:
+  TsmTensorData() : host_addr(0), device_addr(0), length(0) {}
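// Given the constructor above, the BootParamDyninfo records follow the head
// contiguously, so get_inputptr/get_outputptr plausibly reduce to index
// arithmetic like this (the record ordering is an assumption):
static inline BootParamDyninfo *bootpm_dyninfo_at(void *bootpm, uint32_t idx) {
  return (BootParamDyninfo *)((uint8_t *)bootpm + sizeof(BootParamHead)) + idx;
}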
+  ~TsmTensorData(){};
+
+  TsmHostPtr host_addr;
+  TsmDevicePtr device_addr;
+  uint64_t length;
+  uint32_t data_type;
+  Tensor_Type tensor_type;
+};
+
+typedef void *tsmStream_t;
+typedef void *tsmEvent_t;
+typedef struct txcclComm *txcclComm_t;
+typedef enum { txcclDataDefault = 0 } txcclDataType_t; // reserved, to be discussed
+
+enum device_status {
+  FULLGOOD = 0,
+  PARTIALGOOD = 1,
+};
+
+constexpr uint32_t PARTIALGOOD_NUM = 8;
+constexpr uint32_t FULLGOOD_NUM = 16;
+
+struct CardComputeInfo {
+  uint32_t card_id;
+  enum device_status device_status;
+  uint32_t all_tile_num;
+  double all_tile_compute;
+  uint32_t left_tile_num;
+  double left_tile_compute;
+};
+
+struct TsmDeviceInfo {
+  uint32_t card_num;
+  uint32_t card_x;
+  uint32_t card_y;
+  CardComputeInfo card_compute_info[CHIP_MAX_NUM];
+};
+
+int32_t readDataFromFile(uint8_t *buffer, std::string file, uint32_t size);
+uint8_t *read_file_data(std::string file, uint64_t &size);
+
+std::shared_ptr
+get_multi_card_common_info_from_file(std::string file);
+std::string get_docker_verison();
+TSM_RETCODE set_multi_graph(TsmModel *&kmodel,
+                            std::shared_ptr &hostboot,
+                            const TsmDevicePtr &dev_dyn_mods_ptr,
+                            const TsmDevicePtr &dev_tlv_ptr,
+                            TsmDevicePtr ext_ptr);
+#endif /* __HOST_RUNTIME_COM_H__ */
diff --git a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h
index 6c0fdd07d..ce726beca 100644
--- a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h
+++ b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h
@@ -29,34 +29,52 @@
 TSM_RETCODE TsmInitRuntime(void);
 TSM_RETCODE TsmDeInitRuntime(void);
 TSM_RETCODE TsmDeInitRuntimeLegacy(void);
-TSM_RETCODE TsmSetDevice(uint32_t first_phy_id, uint32_t card_x, uint32_t card_y, std::vector &devs);
-TSM_RETCODE TsmSetDeviceOld(uint32_t chip_id, TsmDevice *dev); /* 该接口为提供给MLIR的过度版本,其他组件不要调用 */
+TSM_RETCODE TsmSetDevice(uint32_t first_phy_id, uint32_t card_x,
+                         uint32_t card_y, std::vector &devs);
+TSM_RETCODE TsmSetDeviceOld(
+    uint32_t chip_id,
+    TsmDevice *dev); /* Transitional interface provided for MLIR; other
+                        components must not call it */
 TSM_RETCODE TsmDeviceMalloc(TsmDevice *dev, TsmDevicePtr &ptr, uint64_t size);
 TSM_RETCODE TsmDeviceMemset(TsmDevicePtr &ptr, uint32_t ch, uint64_t size);
 TSM_RETCODE TsmDeviceFree(TsmDevicePtr ptr);
 TSM_RETCODE TsmDeviceSynchronize(TsmDevice *dev);
 TSM_RETCODE TsmInitDevice(TsmDevice *dev);
-TSM_RETCODE TsmCompile(std::vector devs, TsmModel &kmodel, std::string option, CompileOption compl_op);
-TSM_RETCODE TsmCompileMultiGraph(std::vector devs, TsmModel &kmodel, std::string option,
                                 CompileOption compl_op);
+TSM_RETCODE TsmCompile(std::vector devs, TsmModel &kmodel,
+                       std::string option, CompileOption compl_op);
+TSM_RETCODE TsmCompileMultiGraph(std::vector devs,
+                                 TsmModel &kmodel, std::string option,
+                                 CompileOption compl_op);
 TSM_RETCODE TsmLaunch(TsmDevice *dev, TsmModel &kmodel);
-TSM_RETCODE TsmLoadKernel(TsmDevice *dev, std::vector &kmodel_vec, char *module_symbol);
-TSM_RETCODE TsmUnloadKernel(TsmDevice *dev, std::vector &kmodel_vec);
+TSM_RETCODE TsmLoadKernel(TsmDevice *dev, std::vector &kmodel_vec,
+                          char *module_symbol);
+TSM_RETCODE TsmUnloadKernel(TsmDevice *dev,
+                            std::vector &kmodel_vec);
 TSM_RETCODE TsmRun(TsmDevice *dev, TsmDevicePtr bootpm_dev);
-TSM_RETCODE TsmAsyncRun(tsmStream_t stream, TsmDevice *dev, TsmDevicePtr bootpm_dev);
+TSM_RETCODE TsmAsyncRun(tsmStream_t stream, TsmDevice *dev,
+                        TsmDevicePtr bootpm_dev);
 TSM_RETCODE TsmSetTerminate(TsmDevice *dev, tsmStream_t stream = nullptr);
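// An illustrative host round trip over the declarations above (error
// handling elided; a real flow also needs device setup via TsmSetDevice
// and a prepared boot-param table before TsmRun).
void example_device_roundtrip(TsmDevice *dev, const void *in, void *out,
                              uint64_t nbytes) {
  TsmDevicePtr dptr = 0;
  TsmDeviceMalloc(dev, dptr, nbytes); // allocate a device-side buffer
  TsmMemcpyH2D(dptr, in, nbytes);     // host -> device
  // ... TsmLaunch/TsmRun would execute a compiled TsmModel here ...
  TsmMemcpyD2H(out, dptr, nbytes);    // device -> host
  TsmDeviceFree(dptr);
}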
TSM_RETCODE TsmGetDeviceInfo(TsmDeviceInfo *info); TSM_RETCODE TsmTerminate(TsmDevice *dev, TsmDevicePtr bootpm_dev); -TSM_RETCODE TsmMemcpyH2D(TsmDevicePtr dst, const void *src, uint64_t byte_count); -TSM_RETCODE TsmMemcpyD2H(const void *dst, TsmDevicePtr src, uint64_t byte_count); -TSM_RETCODE TsmMemcpyOffsetH2D(TsmDevicePtr dst, const void *src, uint64_t offset, uint64_t byte_count); -TSM_RETCODE TsmMemcpyOffsetD2H(const void *dst, TsmDevicePtr src, uint64_t offset, uint64_t byte_count); -TSM_RETCODE TsmMemcpyD2D(const void *dst, TsmDevice *dst_dev, const void *src, TsmDevice *src_dev, uint64_t byte_count); -TSM_RETCODE TsmSend(const void* sendbuff, size_t count, txcclDataType_t datatype, TsmDevice *dev, int peer, txcclComm_t comm, tsmStream_t stream); -TSM_RETCODE TsmRecv(void* recvbuff, size_t count, txcclDataType_t datatype, TsmDevice *dev, int peer, txcclComm_t comm, tsmStream_t stream); +TSM_RETCODE TsmMemcpyH2D(TsmDevicePtr dst, const void *src, + uint64_t byte_count); +TSM_RETCODE TsmMemcpyD2H(const void *dst, TsmDevicePtr src, + uint64_t byte_count); +TSM_RETCODE TsmMemcpyOffsetH2D(TsmDevicePtr dst, const void *src, + uint64_t offset, uint64_t byte_count); +TSM_RETCODE TsmMemcpyOffsetD2H(const void *dst, TsmDevicePtr src, + uint64_t offset, uint64_t byte_count); +TSM_RETCODE TsmMemcpyD2D(const void *dst, TsmDevice *dst_dev, const void *src, + TsmDevice *src_dev, uint64_t byte_count); +TSM_RETCODE TsmSend(const void *sendbuff, size_t count, + txcclDataType_t datatype, TsmDevice *dev, int peer, + txcclComm_t comm, tsmStream_t stream); +TSM_RETCODE TsmRecv(void *recvbuff, size_t count, txcclDataType_t datatype, + TsmDevice *dev, int peer, txcclComm_t comm, + tsmStream_t stream); TSM_RETCODE TsmResetDevice(TsmDevice *dev); TSM_RETCODE TsmReleaseDevice(TsmDevice *dev); -TSM_RETCODE TsmMemGetInfo(TsmDevicePtr ptr, uint32_t &card_id, uint64_t &addr, uint64_t &size); +TSM_RETCODE TsmMemGetInfo(TsmDevicePtr ptr, uint32_t &card_id, uint64_t &addr, + uint64_t &size); TSM_RETCODE TsmEventCreate(tsmEvent_t *pEvent); TSM_RETCODE TsmEventDestroy(tsmEvent_t event); TSM_RETCODE TsmEventRecord(tsmEvent_t event, tsmStream_t stream); @@ -64,12 +82,16 @@ TSM_RETCODE TsmEventWait(tsmEvent_t event, tsmStream_t stream); TSM_RETCODE TsmStreamCreate(tsmStream_t *pStream, TsmDevice *dev); TSM_RETCODE TsmStreamSynchronize(tsmStream_t stream); TSM_RETCODE TsmStreamDestroy(tsmStream_t stream); -TSM_RETCODE TsmDeviceSerialize(const TsmDevice *const &dev, void *&buffer, size_t &size); +TSM_RETCODE TsmDeviceSerialize(const TsmDevice *const &dev, void *&buffer, + size_t &size); TSM_RETCODE TsmDeviceDeSerialize(TsmDevice *&dev, const void *const &buffer); TSM_RETCODE TsmSetMonitorInfo(TsmDevice *dev); -TSM_RETCODE TsmProcessProfData(TsmDevice *dev, TsmProfAction prof_action, uint16_t prof_type); -TSM_RETCODE TsmHostH2D(TsmDevice *dev, uint64_t input_host_addr, uint64_t input_size, int32_t index); -TSM_RETCODE TsmHostFlush(TsmDevice *dev, uint64_t boot_param_ptr, uint8_t *host_buffer, size_t size); +TSM_RETCODE TsmProcessProfData(TsmDevice *dev, TsmProfAction prof_action, + uint16_t prof_type); +TSM_RETCODE TsmHostH2D(TsmDevice *dev, uint64_t input_host_addr, + uint64_t input_size, int32_t index); +TSM_RETCODE TsmHostFlush(TsmDevice *dev, uint64_t boot_param_ptr, + uint8_t *host_buffer, size_t size); TSM_RETCODE TsmSetRankSize(uint32_t x_size, uint32_t y_size); TSM_RETCODE TsmSetRankId(uint32_t x, uint32_t y); TSM_RETCODE TsmGetPhyRankId(uint32_t *x, uint32_t *y); @@ -81,7 +103,8 @@ TSM_RETCODE 
TsmGetDeviceNum(uint32_t &dev_num);
 
 /*
  * To keep host-side logs uniform, the runtime provides a single logging
  * interface; each component uses it as follows:
- * #define rt_log(level, format, ...) tsm_log(__FILE__, __func__, __LINE__, TSM_RUNTIME, level, format, ##__VA_ARGS__)
+ * #define rt_log(level, format, ...) tsm_log(__FILE__, __func__, __LINE__,
+ *                                            TSM_RUNTIME, level, format, ##__VA_ARGS__)
  *
  * void func() {
  *   rt_log(LOG_DEBUG, "....\n");
@@ -89,10 +112,11 @@ TSM_RETCODE TsmGetDeviceNum(uint32_t &dev_num);
  *   rt_log(LOG_WARNING, "....\n");
  *   rt_log(LOG_ERROR, "....\n");
  * }
- * 默认日志级别为INFO,通过设置 HOST_LOG_LEVEL 更改日志级别,一般就设置成INFO和DEBUG。
- * 注意:
+ * The default log level is INFO; set HOST_LOG_LEVEL to change it, normally
+ * to INFO or DEBUG. Note:
 * rt_log is a per-component name -- never reuse one. TSM_RUNTIME is the module ID; each module finds its own macro in hrt_common.h, and asks the runtime team to add one if it is missing.
 */
-void tsm_log(const char *file_name, const char *func_name, uint32_t line_number, TsmModuleType module_type,
-             HostLogLevel level, const char *format, ...);
-#endif
\ No newline at end of file
+void tsm_log(const char *file_name, const char *func_name, uint32_t line_number,
+             TsmModuleType module_type, HostLogLevel level, const char *format,
+             ...);
+#endif
diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmax.c b/third_party/tsingmicro/crt/lib/Tx81/argmax.c
index 412e9fa24..a982f8ff9 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/argmax.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/argmax.c
@@ -14,9 +14,16 @@ void __ArgMax(uint64_t *src, uint32_t elem_count, uint16_t fmt) {
   // Create command buffer.
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
-  cmd->ArgMax(&inst, (uint64_t) src, elem_count, (Data_Format) fmt);
+  cmd->ArgMax(&inst, (uint64_t)src, elem_count, (Data_Format)fmt);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/argmin.c b/third_party/tsingmicro/crt/lib/Tx81/argmin.c
index e79223b5a..856854d3b 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/argmin.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/argmin.c
@@ -14,9 +14,16 @@ void __ArgMin(uint64_t *src, uint32_t elem_count, uint16_t fmt) {
   // Create command buffer.
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
-  cmd->ArgMin(&inst, (uint64_t) src, elem_count, (Data_Format) fmt);
+  cmd->ArgMin(&inst, (uint64_t)src, elem_count, (Data_Format)fmt);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/arith.c b/third_party/tsingmicro/crt/lib/Tx81/arith.c
index 4fc7120ee..9eb7ed3b8 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/arith.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/arith.c
@@ -11,14 +11,20 @@
 
 #include "tx81.h"
 
-void __AddVV(uint64_t *src0, uint64_t *src1, uint64_t *dst,
-             uint32_t elem_count, RND_MODE round, uint16_t fmt) {
+void __AddVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count,
+             RND_MODE round, uint16_t fmt) {
   // Create command buffer.
TsmArith *cmd = TsmNewArith(); - TsmArithInstr inst = {I_CGRA, {0,}, {0,}}; + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->AddVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, - round, (Data_Format) fmt); + round, (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); @@ -27,14 +33,20 @@ void __AddVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, TsmDeleteArith(cmd); } -void __SubVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, - uint32_t elem_count, RND_MODE round, uint16_t fmt) { +void __SubVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { // Create command buffer. TsmArith *cmd = TsmNewArith(); - TsmArithInstr inst = {I_CGRA, {0,}, {0,}}; + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->SubVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, - round, (Data_Format) fmt); + round, (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); @@ -43,14 +55,20 @@ void __SubVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, TsmDeleteArith(cmd); } -void __MulVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, - uint32_t elem_count, RND_MODE round, uint16_t fmt) { +void __MulVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { // Create command buffer. TsmArith *cmd = TsmNewArith(); - TsmArithInstr inst = {I_CGRA, {0,}, {0,}}; + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->MulVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, - round, (Data_Format) fmt); + round, (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); @@ -59,12 +77,17 @@ void __MulVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, TsmDeleteArith(cmd); } - -void __DivVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, - uint32_t elem_count, RND_MODE round, uint16_t fmt) { +void __DivVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE round, uint16_t fmt) { // Create command buffer. TsmArith *cmd = TsmNewArith(); - TsmArithInstr inst = {I_CGRA, {0,}, {0,}}; + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->DivVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, round, (Data_Format)fmt); diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c index 68df7b31e..eef93082f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c @@ -13,9 +13,15 @@ void __BF16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - cmd->BF16_FP16(&inst, (uint64_t) src, (uint64_t) dst, elem_count); + cmd->BF16_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c index c3063cdcf..957fb60c4 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp32.c @@ -13,7 +13,13 @@ void __BF16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Create command buffer. 
TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->BF16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c index a42e3614d..930c3f395 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c @@ -11,10 +11,16 @@ #include "tx81.h" void __BF16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->BF16_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c index 4af892a73..b9b5b38da 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c @@ -11,10 +11,16 @@ #include "tx81.h" void __BF16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->BF16_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c index 4286f22fa..194c104b4 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c @@ -13,7 +13,13 @@ void __BF16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->BF16_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c index 9220132cc..e5a60ca55 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c @@ -13,7 +13,13 @@ void __BF16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->BF16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/bilinear.c b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c index 23ff84532..43ddf791c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bilinear.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c @@ -12,17 +12,24 @@ #include "tx81.h" void __Bilinear(uint64_t *src, uint64_t *dst, uint16_t src_n, uint16_t src_h, - uint16_t src_w, uint16_t src_c, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + uint16_t src_w, uint16_t src_c, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { // Create command buffer. 
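+  // The two scalar arguments passed to Bilinear below are the H/W resampling
+  // ratios, computed as integer quotients (src - 1) / (dst - 1). That the
+  // ratios are integral, and that dst_w/dst_h exceed 1 (otherwise this
+  // divides by zero), appear to be unchecked preconditions of this wrapper.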
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
-  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
-  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
-  cmd->Bilinear(&inst, (uint64_t) src, (uint64_t) dst, shape1, shape2,
-                (src_w - 1) / (dst_w - 1), (src_h - 1) / (dst_h - 1),
-                (Data_Format) fmt);
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c};
+  cmd->Bilinear(&inst, (uint64_t)src, (uint64_t)dst, shape1, shape2,
+                (src_w - 1) / (dst_w - 1), (src_h - 1) / (dst_h - 1),
+                (Data_Format)fmt);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c
index 7b818b1df..0a9344213 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c
@@ -14,9 +14,17 @@ void __Bit2Fp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
 
   // Create command buffer.
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
-  cmd->Bit2Fp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format) fmt);
+  cmd->Bit2Fp(&inst, (uint64_t)src, (uint64_t)dst, elem_count,
+              (Data_Format)fmt);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/common.c b/third_party/tsingmicro/crt/lib/Tx81/common.c
index 680e5a367..c1b6085d9 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/common.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/common.c
@@ -12,14 +12,8 @@
 #include "tx81.h"
 
 // WORKAROUND for undefined symbols in libkcorert.a
-int main(int argc, char** argv) {
-  return 0;
-}
+int main(int argc, char **argv) { return 0; }
 
-int get_app_version() {
-  return 1;
-}
+int get_app_version() { return 1; }
 
-int nvram_get_val() {
-  return 1;
-}
\ No newline at end of file
+int nvram_get_val() { return 1; }
diff --git a/third_party/tsingmicro/crt/lib/Tx81/concat.c b/third_party/tsingmicro/crt/lib/Tx81/concat.c
index a7cc0da55..8bd489bff 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/concat.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/concat.c
@@ -12,19 +12,25 @@
 #include "tx81.h"
 
 void __Concat(uint64_t *src1, uint16_t src1_n, uint16_t src1_h, uint16_t src1_w,
-        uint16_t src1_c, uint64_t *src2, uint16_t src2_n, uint16_t src2_h,
-        uint16_t src2_w, uint16_t src2_c, uint64_t *dst, uint16_t dst_n,
-        uint16_t dst_h, uint16_t dst_w, uint16_t dst_c, uint32_t dim,
-        uint16_t fmt) {
+              uint16_t src1_c, uint64_t *src2, uint16_t src2_n, uint16_t src2_h,
+              uint16_t src2_w, uint16_t src2_c, uint64_t *dst, uint16_t dst_n,
+              uint16_t dst_h, uint16_t dst_w, uint16_t dst_c, uint32_t dim,
+              uint16_t fmt) {
   // Create command buffer.
TsmDataMove *cmd = TsmNewDataMove(); - TsmMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src1_n, src1_h, src1_w, src1_c }; - Data_Shape shape2 = { src2_n, src2_h, src2_w, src2_c }; - Data_Shape shape3 = { dst_n, dst_h, dst_w, dst_c }; + Data_Shape shape1 = {src1_n, src1_h, src1_w, src1_c}; + Data_Shape shape2 = {src2_n, src2_h, src2_w, src2_c}; + Data_Shape shape3 = {dst_n, dst_h, dst_w, dst_c}; cmd->Concat(&inst, (uint64_t)src1, shape1, (uint64_t)src2, shape2, - (uint64_t) dst, shape3, dim, (Data_Format) fmt); + (uint64_t)dst, shape3, dim, (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/conv.c b/third_party/tsingmicro/crt/lib/Tx81/conv.c index 3f05b2ee8..d4f7dcae2 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/conv.c +++ b/third_party/tsingmicro/crt/lib/Tx81/conv.c @@ -12,17 +12,22 @@ #include "tx81.h" // The arguments list is aligned with TsmConv in Tx81Ops.td -void __Conv(int64_t opType, int64_t* srcAct, int64_t* srcActDims, int64_t* weight, - int64_t* weightDims, bool enBias, int64_t* bias, bool enNegScale, - int64_t* negScale, bool enPosScale, int64_t* posScale, bool enSparse, - int64_t* sparse, bool enPsum, int64_t* psum, int64_t* pads, - int64_t* unpads, int64_t* strides, int64_t* dilations, - bool enLeakyRelu, int64_t srcActFmt, int64_t weightFmt, int64_t dstFmt, - int64_t* dst, int64_t* dstDims) -{ +void __Conv(int64_t opType, int64_t *srcAct, int64_t *srcActDims, + int64_t *weight, int64_t *weightDims, bool enBias, int64_t *bias, + bool enNegScale, int64_t *negScale, bool enPosScale, + int64_t *posScale, bool enSparse, int64_t *sparse, bool enPsum, + int64_t *psum, int64_t *pads, int64_t *unpads, int64_t *strides, + int64_t *dilations, bool enLeakyRelu, int64_t srcActFmt, + int64_t weightFmt, int64_t dstFmt, int64_t *dst, int64_t *dstDims) { // Create convolution command buffer. TsmConv *conv = TsmNewConv(); - TsmNeInstr inst = {I_NEUR, {0,}, {0,}}; + TsmNeInstr inst = {I_NEUR, + { + 0, + }, + { + 0, + }}; // Convert to nhwc format Data_Shape shape = {(uint16_t)srcActDims[0], (uint16_t)srcActDims[1], @@ -34,16 +39,16 @@ void __Conv(int64_t opType, int64_t* srcAct, int64_t* srcActDims, int64_t* weigh Data_Shape dstShape = {(uint16_t)dstDims[0], (uint16_t)dstDims[1], (uint16_t)dstDims[2], (uint16_t)dstDims[3]}; - conv->AddInput(&inst, (int64_t) srcAct, shape, (Data_Format)srcActFmt); - conv->AddWeight(&inst, (uint64_t) weight, wshape, (Data_Format)weightFmt); - conv->AddBias(&inst, enBias, (uint64_t) bias); - conv->AddOutput(&inst, (uint64_t) dst, dstShape, (Data_Format)dstFmt); + conv->AddInput(&inst, (int64_t)srcAct, shape, (Data_Format)srcActFmt); + conv->AddWeight(&inst, (uint64_t)weight, wshape, (Data_Format)weightFmt); + conv->AddBias(&inst, enBias, (uint64_t)bias); + conv->AddOutput(&inst, (uint64_t)dst, dstShape, (Data_Format)dstFmt); conv->SetOpType(&inst, opType); - conv->SetNegativeAxisScale(&inst, enNegScale, (uint64_t) negScale); - conv->SetPositiveAxisScale(&inst, enPosScale, (uint64_t) posScale); - conv->SetSparse(&inst, enSparse, (uint64_t) sparse); + conv->SetNegativeAxisScale(&inst, enNegScale, (uint64_t)negScale); + conv->SetPositiveAxisScale(&inst, enPosScale, (uint64_t)posScale); + conv->SetSparse(&inst, enSparse, (uint64_t)sparse); // FIXME: Should we have psum format instead? 
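+  // psum reads as a partial-sum accumulator surface, enabled by enPsum and
+  // described with the destination format for now; this is inferred from the
+  // names, not from ISA documentation.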
-  conv->SetPsum(&inst, enPsum, (uint64_t) psum, (Data_Format) dstFmt);
+  conv->SetPsum(&inst, enPsum, (uint64_t)psum, (Data_Format)dstFmt);
   conv->SetPads(&inst, pads[0], pads[1], pads[2], pads[3]);
   conv->SetUnPads(&inst, unpads[0], unpads[1], unpads[2], unpads[3]);
   conv->SetKernelStrides(&inst, strides[0], strides[1], strides[2], strides[3]);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/cos.c b/third_party/tsingmicro/crt/lib/Tx81/cos.c
index 23c965a2e..0ea6f096b 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/cos.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/cos.c
@@ -14,9 +14,15 @@ void __Cos(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
 
   // Create command buffer.
   TsmTranscendental *cmd = TsmNewTranscendental();
-  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmTranscendentalInstr inst = {I_CGRA,
+                                 {
+                                     0,
+                                 },
+                                 {
+                                     0,
+                                 }};
 
-  cmd->Cos(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format) fmt);
+  cmd->Cos(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/count.c b/third_party/tsingmicro/crt/lib/Tx81/count.c
index 9582070f7..855ed95e7 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/count.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/count.c
@@ -12,13 +12,20 @@
 #include "tx81.h"
 
 void __Count(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt,
-        uint64_t *p_wb_data0, uint64_t *p_wb_data1) {
+             uint64_t *p_wb_data0, uint64_t *p_wb_data1) {
   // Create command buffer.
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
-  cmd->Count(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt, p_wb_data0,
-             p_wb_data1);
+  cmd->Count(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt,
+             p_wb_data0, p_wb_data1);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/exp.c b/third_party/tsingmicro/crt/lib/Tx81/exp.c
index 472df41b0..37b52b1f0 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/exp.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/exp.c
@@ -14,7 +14,13 @@ void __Exp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
 
   // Create command buffer.
   TsmTranscendental *cmd = TsmNewTranscendental();
-  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmTranscendentalInstr inst = {I_CGRA,
+                                 {
+                                     0,
+                                 },
+                                 {
+                                     0,
+                                 }};
 
   cmd->Exp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/explp.c b/third_party/tsingmicro/crt/lib/Tx81/explp.c
index c88d2ae68..e917ae61e 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/explp.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/explp.c
@@ -14,7 +14,13 @@ void __Explp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
 
   // Create command buffer.
   TsmTranscendental *cmd = TsmNewTranscendental();
-  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmTranscendentalInstr inst = {I_CGRA,
+                                 {
+                                     0,
+                                 },
+                                 {
+                                     0,
+                                 }};
 
   cmd->Explp(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c
index a5ba62869..2fe6b4d6d 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c
@@ -11,10 +11,16 @@
 #include "tx81.h"
 
 void __FP16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
-        RND_MODE round) {
+                 RND_MODE round) {
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
 
   cmd->FP16_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c
index 2c6732f0f..8493d00ee 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_fp32.c
@@ -13,7 +13,13 @@ void __FP16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) {
 
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
 
   cmd->FP16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c
index 5a161b4e9..f6a34594b 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c
@@ -14,7 +14,13 @@ void __FP16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
                   RND_MODE round) {
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
 
   cmd->FP16_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c
index cd64ac37a..9cd30e5ef 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c
@@ -11,10 +11,16 @@
 #include "tx81.h"
 
 void __FP16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
-        RND_MODE round) {
+                  RND_MODE round) {
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
 
-  cmd->INT16_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
+  // Was INT16_BF16, an apparent copy/paste slip in a file that converts
+  // FP16 to INT32; FP16_INT32 is assumed to exist alongside FP16_INT16/INT8.
+  cmd->FP16_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c
index 4d1356aa2..a099d8948 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c
@@ -11,10 +11,16 @@
 #include "tx81.h"
 
 void __FP16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count,
-        RND_MODE round) {
+                 RND_MODE round) {
   // Create command buffer.
TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP16_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c index d2d3163ad..f36fbe943 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c @@ -13,7 +13,13 @@ void __FP16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c index 0cdf0c995..c18f95f7c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c @@ -11,10 +11,16 @@ #include "tx81.h" void __FP32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c index ffd647e35..0c27fba95 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_fp16.c @@ -11,10 +11,16 @@ #include "tx81.h" void __FP32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c index 95201fa00..5ec2e93c4 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c @@ -11,10 +11,16 @@ #include "tx81.h" void __FP32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP32_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c index bdaa4a9ad..9d8e0622c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int32.c @@ -11,10 +11,16 @@ #include "tx81.h" void __FP32_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. 
TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP32_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c index b82017ae0..23d4306f4 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c @@ -11,10 +11,16 @@ #include "tx81.h" void __FP32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP32_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c index 8eec01b48..ecf2436ae 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_tf32.c @@ -11,10 +11,16 @@ #include "tx81.h" void __FP32_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->FP32_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c index ef4e72c95..032290c1f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c +++ b/third_party/tsingmicro/crt/lib/Tx81/gatherscatter.c @@ -11,19 +11,26 @@ #include "tx81.h" -void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t size, uint32_t src_s0, - uint32_t src_i0, uint32_t src_s1, uint32_t src_i1, - uint32_t src_s2, uint32_t src_i2, uint32_t dst_s0, - uint32_t dst_i0, uint32_t dst_s1, uint32_t dst_i1, - uint32_t dst_s2, uint32_t dst_i2) { +void __GatherScatter(uint64_t *src, uint64_t *dst, uint32_t size, + uint32_t src_s0, uint32_t src_i0, uint32_t src_s1, + uint32_t src_i1, uint32_t src_s2, uint32_t src_i2, + uint32_t dst_s0, uint32_t dst_i0, uint32_t dst_s1, + uint32_t dst_i1, uint32_t dst_s2, uint32_t dst_i2) { // Create command buffer. 
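+  // Each (sN, iN) pair below is one stride/iteration level, i.e. src and dst
+  // are walked as three nested strided loops; inferred from how
+  // St_StrideIteration is populated, not from separate documentation.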
   TsmDataMove *cmd = TsmNewDataMove();
-  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmDataMoveInstr inst = {I_CGRA,
+                           {
+                               0,
+                           },
+                           {
+                               0,
+                           }};
 
   St_StrideIteration src_si = {src_s0, src_i0, src_s1, src_i1, src_s2, src_i2};
   St_StrideIteration dst_si = {dst_s0, dst_i0, dst_s1, dst_i1, dst_s2, dst_i2};
-  cmd->GatherScatter(&inst, (uint64_t)src, (uint64_t)dst, size, &src_si, &dst_si);
+  cmd->GatherScatter(&inst, (uint64_t)src, (uint64_t)dst, size, &src_si,
+                     &dst_si);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/gemm.c b/third_party/tsingmicro/crt/lib/Tx81/gemm.c
index 09b8c18ff..a0f14d83a 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/gemm.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/gemm.c
@@ -12,28 +12,34 @@
 #include "tx81.h"
 
-// The arguments list is aligned with TsmConv in Tx81Ops.td
-void __Gemm(int64_t* srcA, int64_t *srcB, int64_t * srcBias, int64_t *zeros,
-            int64_t *dims, bool enPsum, int64_t *psum, bool enTransA, bool enTransB,
-            int64_t batchSizeA, int64_t batchSizeB, bool enLeakyRelu, bool enBias,
-            bool enNegScale, int64_t *negScale, bool enPosScale, int64_t *posScale,
-            int64_t srcFmt, int64_t dstFmt, int64_t* dst)
-{
+// The arguments list is aligned with the GEMM op in Tx81Ops.td
+void __Gemm(int64_t *srcA, int64_t *srcB, int64_t *srcBias, int64_t *zeros,
+            int64_t *dims, bool enPsum, int64_t *psum, bool enTransA,
+            bool enTransB, int64_t batchSizeA, int64_t batchSizeB,
+            bool enLeakyRelu, bool enBias, bool enNegScale, int64_t *negScale,
+            bool enPosScale, int64_t *posScale, int64_t srcFmt, int64_t dstFmt,
+            int64_t *dst) {
   // Create gemm command buffer.
   TsmGemm *gemm = TsmNewGemm();
-  TsmNeInstr inst = {I_NEUR, {0,}, {0,}};
+  TsmNeInstr inst = {I_NEUR,
+                     {
+                         0,
+                     },
+                     {
+                         0,
+                     }};
 
-  gemm->AddInput(&inst, (uint64_t) srcA, (uint64_t) srcB, (Data_Format) srcFmt);
-  gemm->ConfigMKN(&inst, (uint32_t) dims[0], (uint32_t) dims[1],
-                  (uint32_t) dims[2]);
-  gemm->AddOutput(&inst, (uint64_t) dst, (Data_Format) dstFmt);
-  gemm->SetPsum(&inst, enPsum, (uint64_t) psum, (Data_Format) dstFmt);
-  gemm->SetTransflag(&inst, (uint8_t) enTransA, (uint8_t) enTransB);
+  gemm->AddInput(&inst, (uint64_t)srcA, (uint64_t)srcB, (Data_Format)srcFmt);
+  gemm->ConfigMKN(&inst, (uint32_t)dims[0], (uint32_t)dims[1],
+                  (uint32_t)dims[2]);
+  gemm->AddOutput(&inst, (uint64_t)dst, (Data_Format)dstFmt);
+  gemm->SetPsum(&inst, enPsum, (uint64_t)psum, (Data_Format)dstFmt);
+  gemm->SetTransflag(&inst, (uint8_t)enTransA, (uint8_t)enTransB);
   // TODO:
   // gemm->SetQuant();
-  gemm->ConfigBatch(&inst, (uint32_t) batchSizeA, (uint32_t) batchSizeB);
-  gemm->AddBias(&inst, enBias, (uint64_t) srcBias);
-  gemm->SetNegativeAxisScale(&inst, enNegScale, (uint64_t) negScale);
-  gemm->SetPositiveAxisScale(&inst, enPosScale, (uint64_t) posScale);
+  gemm->ConfigBatch(&inst, (uint32_t)batchSizeA, (uint32_t)batchSizeB);
+  gemm->AddBias(&inst, enBias, (uint64_t)srcBias);
+  gemm->SetNegativeAxisScale(&inst, enNegScale, (uint64_t)negScale);
+  gemm->SetPositiveAxisScale(&inst, enPosScale, (uint64_t)posScale);
   if (enLeakyRelu)
     gemm->EnableLeakyRelu(&inst);
   else
diff --git a/third_party/tsingmicro/crt/lib/Tx81/img2col.c b/third_party/tsingmicro/crt/lib/Tx81/img2col.c
index 3d2bba633..a578e1351 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/img2col.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/img2col.c
@@ -12,20 +12,26 @@
 #include "tx81.h"
 
 void __Img2col(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
-        uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
-        uint16_t dst_w, uint16_t dst_c, uint64_t src_elem_num,
-        uint64_t dst_elem_num, uint16_t swr_n, uint16_t swr_h,
-        uint16_t swr_w, uint16_t swr_c, uint16_t pdr_n, uint16_t pdr_h,
-        uint16_t pdr_w, uint16_t pdr_c, uint16_t fmt) {
+               uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+               uint16_t dst_w, uint16_t dst_c, uint64_t src_elem_num,
+               uint64_t dst_elem_num, uint16_t swr_n, uint16_t swr_h,
+               uint16_t swr_w, uint16_t swr_c, uint16_t pdr_n, uint16_t pdr_h,
+               uint16_t pdr_w, uint16_t pdr_c, uint16_t fmt) {
   // Create command buffer.
   TsmDataMove *cmd = TsmNewDataMove();
-  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmDataMoveInstr inst = {I_CGRA,
+                           {
+                               0,
+                           },
+                           {
+                               0,
+                           }};
 
   Data_Shape shape1 = {src_n, src_h, src_w, src_c};
   Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c};
   Data_Shape shape3 = {swr_n, swr_h, swr_w, swr_c};
   Data_Shape shape4 = {pdr_n, pdr_h, pdr_w, pdr_c};
-  cmd->Img2col(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
+  cmd->Img2col(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2,
                src_elem_num, dst_elem_num, shape3, shape4, (Data_Format)fmt);
 
   // Dispatch the command to accelerator
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c
index f06d8bec9..213681aad 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c
@@ -11,10 +11,16 @@
 #include "tx81.h"
 
 void __INT16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count,
-        RND_MODE round) {
+                  RND_MODE round) {
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
 
   cmd->INT16_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c
index 0486a4834..a23297033 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c
@@ -13,7 +13,13 @@ void __INT16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) {
 
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
 
   cmd->INT16_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c
index 670f3a9f5..9e5ba8e3f 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c
@@ -11,10 +11,16 @@
 #include "tx81.h"
 
 void __INT16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
-        RND_MODE round) {
+                  RND_MODE round) {
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};
 
   cmd->INT16_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c
index 61022964b..1a08f227f 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c
@@ -11,10 +11,16 @@
 #include "tx81.h"
 
 void __INT16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count,
-        RND_MODE round) {
+                  RND_MODE round) {
   // Create command buffer.
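+  // TF32 here presumably denotes the 19-bit tensor-float layout (1 sign,
+  // 8 exponent, 10 mantissa bits); the Tx81 headers are not quoted on this,
+  // so treat the reading as an assumption.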
TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT16_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c index 261140c0f..5b9949719 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c @@ -11,10 +11,16 @@ #include "tx81.h" void __INT32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c index 30169b1ed..e9c9f14ee 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c @@ -12,10 +12,16 @@ #include "tx81.h" void __INT32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c index b56cb3821..6fb7778f4 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_fp32.c @@ -11,10 +11,16 @@ #include "tx81.h" void __INT32_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT32_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c index f0dc2c69a..6c65087da 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c @@ -11,10 +11,16 @@ #include "tx81.h" void __INT32_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT32_TF32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c index 5f2253093..bce9dfa27 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c @@ -10,10 +10,17 @@ #include "tx81.h" -void __INT8_BF16(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) { +void __INT8_BF16(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { // Create command buffer. 
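+  // zp is presumably the INT8 quantization zero-point folded into the
+  // widening conversion (an assumption from the name; not documented here).
+  // The same parameter appears in the other int8_* converters below.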
TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT8_BF16(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c index 9166fa050..94061ebc4 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c @@ -10,10 +10,17 @@ #include "tx81.h" -void __INT8_FP16(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) { +void __INT8_FP16(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT8_FP16(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c index 853ddd14a..b2be9df3c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_fp32.c @@ -10,10 +10,17 @@ #include "tx81.h" -void __INT8_FP32(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) { +void __INT8_FP32(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT8_FP32(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c index 7fe060ab4..3fb5fcfae 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c @@ -10,10 +10,17 @@ #include "tx81.h" -void __INT8_TF32(uint64_t *src, uint32_t zp, uint64_t *dst, uint32_t elem_count) { +void __INT8_TF32(uint64_t *src, uint32_t zp, uint64_t *dst, + uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->INT8_TF32(&inst, (uint64_t)src, zp, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c index 2e59054b2..c1cdb81f0 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c +++ b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c @@ -12,12 +12,19 @@ #include "tx81.h" void __Leakyrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, - uint16_t fmt) { + uint16_t fmt) { // Create command buffer. TsmActivation *cmd = TsmNewActivation(); - TsmActivationInstr inst = {I_CGRA, {0,}, {0,}}; + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - cmd->Leakyrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + cmd->Leakyrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/ln.c b/third_party/tsingmicro/crt/lib/Tx81/ln.c index 41c528316..01776e243 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/ln.c +++ b/third_party/tsingmicro/crt/lib/Tx81/ln.c @@ -14,7 +14,13 @@ void __Ln(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Create command buffer. 
   TsmTranscendental *cmd = TsmNewTranscendental();
-  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmTranscendentalInstr inst = {I_CGRA,
+                                 {
+                                     0,
+                                 },
+                                 {
+                                     0,
+                                 }};
 
   cmd->Ln(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/log2.c b/third_party/tsingmicro/crt/lib/Tx81/log2.c
index b012fdc11..8dcdfc3e8 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/log2.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/log2.c
@@ -14,7 +14,13 @@ void __Log2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
 
   // Create command buffer.
   TsmTranscendental *cmd = TsmNewTranscendental();
-  TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmTranscendentalInstr inst = {I_CGRA,
+                                 {
+                                     0,
+                                 },
+                                 {
+                                     0,
+                                 }};
 
   cmd->Log2(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut16.c b/third_party/tsingmicro/crt/lib/Tx81/lut16.c
index eea3a71d8..d4e8bea10 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/lut16.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/lut16.c
@@ -12,12 +12,20 @@
 #include "tx81.h"
 
 void __Lut16(uint64_t *src, uint64_t *dst, uint64_t *lut16,
-        uint32_t src_elem_count, uint32_t lut_elem_count) {
+             uint32_t src_elem_count, uint32_t lut_elem_count) {
   // Create command buffer.
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
-  cmd->Lut16(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut16, src_elem_count, lut_elem_count);
+  cmd->Lut16(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut16,
+             src_elem_count, lut_elem_count);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut32.c b/third_party/tsingmicro/crt/lib/Tx81/lut32.c
index c3ca38f23..16f6df9ba 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/lut32.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/lut32.c
@@ -12,12 +12,20 @@
 #include "tx81.h"
 
 void __Lut32(uint64_t *src, uint64_t *dst, uint64_t *lut32,
-        uint32_t src_elem_count, uint32_t lut_elem_count) {
+             uint32_t src_elem_count, uint32_t lut_elem_count) {
   // Create command buffer.
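+  // Lut16/Lut32 appear to map each of the src_elem_count inputs through a
+  // table of lut_elem_count entries (16- vs 32-bit entries); inferred from
+  // the signatures, not from the ISA manual.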
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
-  cmd->Lut32(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut32, src_elem_count, lut_elem_count);
+  cmd->Lut32(&inst, (uint64_t)src, (uint64_t)dst, (uint64_t)lut32,
+             src_elem_count, lut_elem_count);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/mask_move.c b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c
index cbeef4726..05d7989e7 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/mask_move.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c
@@ -12,9 +12,15 @@
 #include "tx81.h"
 
 void __MaskMove(uint64_t *src, uint64_t *target, uint32_t elem_count,
-        uint64_t * mask, int32_t fmt) {
+                uint64_t *mask, int32_t fmt) {
   TsmMaskDataMove *move = TsmNewMaskDataMove();
-  TsmMaskDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmMaskDataMoveInstr inst = {I_CGRA,
+                               {
+                                   0,
+                               },
+                               {
+                                   0,
+                               }};
 
   move->MaskMove(&inst, (uint64_t)src, (uint64_t)mask, (uint64_t)target,
                  elem_count, (Data_Format)fmt);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/memset.c b/third_party/tsingmicro/crt/lib/Tx81/memset.c
index 7c412e496..86edde788 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/memset.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/memset.c
@@ -12,11 +12,17 @@
 #include "tx81.h"
 
 void __Memset(uint64_t *dst, uint32_t value, uint32_t elem_count, uint32_t s0,
-        uint32_t i0, uint32_t s1, uint32_t i1, uint32_t s2, uint32_t i2,
-        uint16_t fmt) {
+              uint32_t i0, uint32_t s1, uint32_t i1, uint32_t s2, uint32_t i2,
+              uint16_t fmt) {
   // Create command buffer.
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmDataMoveInstr inst = {I_CGRA,
+                           {
+                               0,
+                           },
+                           {
+                               0,
+                           }};
 
   St_StrideIteration si = {s0, i0, s1, i1, s2, i2};
   cmd->Memset(&inst, (uint64_t)dst, value, elem_count, &si, (Data_Format)fmt);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/mirror.c b/third_party/tsingmicro/crt/lib/Tx81/mirror.c
index db113e5fc..543fec808 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/mirror.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/mirror.c
@@ -12,16 +12,22 @@
 #include "tx81.h"
 
 void __Mirror(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
-        uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
-        uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+              uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+              uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
   // Create command buffer.
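+  // Data_Shape is ordered {N, H, W, C} across these data-move wrappers, as
+  // the src_n/src_h/src_w/src_c naming suggests; cf. the nchw2nhwc and
+  // nhwc2nchw helpers later in this directory.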
TsmDataMove *cmd = TsmNewDataMove(); - TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src_n, src_h, src_w, src_c }; - Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; - cmd->Mirror(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, - (Data_Format)fmt); + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Mirror(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c index cfca054f7..ed916cb60 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c +++ b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c @@ -12,16 +12,22 @@ #include "tx81.h" void __Nchw2nhwc(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { // Create command buffer. TsmDataMove *cmd = TsmNewDataMove(); - TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src_n, src_h, src_w, src_c }; - Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; - cmd->Nchw2nhwc(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, - (Data_Format)fmt); + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Nchw2nhwc(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c index e9ccdcf0e..932b71599 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c +++ b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c @@ -12,16 +12,22 @@ #include "tx81.h" void __Nhwc2nchw(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { // Create command buffer. 
TsmDataMove *cmd = TsmNewDataMove(); - TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src_n, src_h, src_w, src_c }; - Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; - cmd->Nhwc2nchw(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, - (Data_Format)fmt); + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Nhwc2nchw(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/pad.c b/third_party/tsingmicro/crt/lib/Tx81/pad.c index e8d3caf72..3ccde1221 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/pad.c +++ b/third_party/tsingmicro/crt/lib/Tx81/pad.c @@ -12,18 +12,24 @@ #include "tx81.h" void __Pad(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t pad_n, uint16_t pad_h, - uint16_t pad_w, uint16_t pad_c, uint16_t fmt) { + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t pad_n, uint16_t pad_h, + uint16_t pad_w, uint16_t pad_c, uint16_t fmt) { // Create command buffer. TsmDataMove *cmd = TsmNewDataMove(); - TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src_n, src_h, src_w, src_c }; - Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; - Data_Shape shape3 = { pad_n, pad_h, pad_w, pad_c }; - cmd->Pad(&inst, (uint64_t) src, shape1, (uint64_t) dst, - shape2, shape3, (Data_Format)fmt); + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + Data_Shape shape3 = {pad_n, pad_h, pad_w, pad_c}; + cmd->Pad(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, shape3, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/pow2.c b/third_party/tsingmicro/crt/lib/Tx81/pow2.c index 060edf08c..9ed3fa0ae 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/pow2.c +++ b/third_party/tsingmicro/crt/lib/Tx81/pow2.c @@ -14,7 +14,13 @@ void __Pow2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Create command buffer. TsmTranscendental *cmd = TsmNewTranscendental(); - TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}}; + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->Pow2(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); diff --git a/third_party/tsingmicro/crt/lib/Tx81/randgen.c b/third_party/tsingmicro/crt/lib/Tx81/randgen.c index d390a2e6f..85382f17d 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/randgen.c +++ b/third_party/tsingmicro/crt/lib/Tx81/randgen.c @@ -12,13 +12,20 @@ #include "tx81.h" void __RandGen(uint64_t *src0, uint64_t *src1, uint64_t *dst0, uint64_t *dst1, - uint64_t *dst2, uint32_t src_elem_num, uint16_t fmt) { + uint64_t *dst2, uint32_t src_elem_num, uint16_t fmt) { // Create command buffer. 
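+  // Unlike the other wrappers, RandGen dereferences its pointer arguments:
+  // *src0 et al. are passed to the encoder by value below, which suggests
+  // they carry seed/state words rather than buffer addresses (assumption).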
   TsmPeripheral *cmd = TsmNewPeripheral();
-  TsmPeripheralInstr inst = {I_CGRA, {0,}, {0,}};;
+  TsmPeripheralInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
   cmd->RandGen(&inst, *src0, *src1, *dst0, *dst1, *dst2, src_elem_num,
-              (Data_Format)fmt);
+               (Data_Format)fmt);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/relu.c b/third_party/tsingmicro/crt/lib/Tx81/relu.c
index 90bddef53..ccaf77ec7 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/relu.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/relu.c
@@ -14,7 +14,13 @@ void __Relu(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
 
   // Create command buffer.
   TsmActivation *cmd = TsmNewActivation();
-  TsmActivationInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmActivationInstr inst = {I_CGRA,
+                             {
+                                 0,
+                             },
+                             {
+                                 0,
+                             }};
 
   cmd->Relu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate180.c b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c
index b5d78ae0b..1b068458b 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/rotate180.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c
@@ -12,16 +12,22 @@
 #include "tx81.h"
 
 void __Rotate180(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
-        uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
-        uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+                 uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+                 uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
   // Create command buffer.
   TsmDataMove *cmd = TsmNewDataMove();
-  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmDataMoveInstr inst = {I_CGRA,
+                           {
+                               0,
+                           },
+                           {
+                               0,
+                           }};
 
-  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
-  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
-  cmd->Rotate180(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
-                 (Data_Format)fmt);
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c};
+  cmd->Rotate180(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2,
+                 (Data_Format)fmt);
 
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate270.c b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c
index f82561830..15d84f28f 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/rotate270.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c
@@ -12,16 +12,22 @@
 #include "tx81.h"
 
 void __Rotate270(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
-        uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
-        uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+                 uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+                 uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
   // Create command buffer.
TsmDataMove *cmd = TsmNewDataMove(); - TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src_n, src_h, src_w, src_c }; - Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; - cmd->Rotate270(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, - (Data_Format)fmt); + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Rotate270(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate90.c b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c index 9e8480470..15c87d429 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/rotate90.c +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c @@ -12,15 +12,22 @@ #include "tx81.h" void __Rotate90(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { // Create command buffer. TsmDataMove *cmd = TsmNewDataMove(); - TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src_n, src_h, src_w, src_c }; - Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; - cmd->Rotate90(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, (Data_Format)fmt); + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->Rotate90(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/satrelu.c b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c index 338f2852d..e9d67dfee 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/satrelu.c +++ b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c @@ -11,12 +11,20 @@ #include "tx81.h" -void __Satrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { +void __Satrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { // Create command buffer. TsmActivation *cmd = TsmNewActivation(); - TsmActivationInstr inst = {I_CGRA, {0,}, {0,}}; + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - cmd->Satrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + cmd->Satrelu(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c index 03761fcd6..d92a42e22 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c +++ b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c @@ -11,12 +11,20 @@ #include "tx81.h" -void __Sigmoid(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { +void __Sigmoid(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { // Create command buffer. 
TsmActivation *cmd = TsmNewActivation(); - TsmActivationInstr inst = {I_CGRA, {0,}, {0,}}; + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - cmd->Sigmoid(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + cmd->Sigmoid(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/sin.c b/third_party/tsingmicro/crt/lib/Tx81/sin.c index 065f57e85..7bb37c6d2 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/sin.c +++ b/third_party/tsingmicro/crt/lib/Tx81/sin.c @@ -14,7 +14,13 @@ void __Sin(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Create command buffer. TsmTranscendental *cmd = TsmNewTranscendental(); - TsmTranscendentalInstr inst = {I_CGRA, {0,}, {0,}}; + TsmTranscendentalInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->Sin(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); diff --git a/third_party/tsingmicro/crt/lib/Tx81/softplus.c b/third_party/tsingmicro/crt/lib/Tx81/softplus.c index af1f1f0ee..d384cf701 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/softplus.c +++ b/third_party/tsingmicro/crt/lib/Tx81/softplus.c @@ -12,12 +12,20 @@ #include "tx81.h" -void __Softplus(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { +void __Softplus(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { // Create command buffer. TsmActivation *cmd = TsmNewActivation(); - TsmActivationInstr inst = {I_CGRA, {0,}, {0,}}; + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - cmd->Softplus(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); + cmd->Softplus(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tanh.c b/third_party/tsingmicro/crt/lib/Tx81/tanh.c index aecc93431..e91cf229d 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tanh.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tanh.c @@ -14,7 +14,13 @@ void __Tanh(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Create command buffer. TsmActivation *cmd = TsmNewActivation(); - TsmActivationInstr inst = {I_CGRA, {0,}, {0,}}; + TsmActivationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->Tanh(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c index a2f023620..d141faa44 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c @@ -12,16 +12,22 @@ #include "tx81.h" void __TensorNorm(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, - uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, - uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { + uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h, + uint16_t dst_w, uint16_t dst_c, uint16_t fmt) { // Create command buffer. 
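+  // Note: the runtime encoder method really is spelled "TensorNom" (sic)
+  // below; the wrapper keeps the __TensorNorm spelling for its callers.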
TsmDataMove *cmd = TsmNewDataMove(); - TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}}; + TsmDataMoveInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; - Data_Shape shape1 = { src_n, src_h, src_w, src_c }; - Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c }; - cmd->TensorNom(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2, - (Data_Format)fmt); + Data_Shape shape1 = {src_n, src_h, src_w, src_c}; + Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c}; + cmd->TensorNom(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2, + (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c index b01999f77..6b9e2b5d3 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c @@ -11,10 +11,16 @@ #include "tx81.h" void __TF32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->TF32_BF16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c index 7369e13d7..9e78a4bda 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c @@ -13,7 +13,13 @@ void __TF32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->TF32_FP16(&inst, (uint64_t)src, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c index 40f8036f3..92550fb2a 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c @@ -13,7 +13,13 @@ void __TF32_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->TF32_FP32(&inst, (uint64_t)src, (uint64_t)dst, elem_count); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c index 1c0b07546..b6e2951c2 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c @@ -11,10 +11,16 @@ #include "tx81.h" void __TF32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. TsmConvert *cmd = TsmNewConvert(); - TsmConvertInstr inst = {I_CGRA, {0,}, {0,}}; + TsmConvertInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; cmd->TF32_INT16(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round); diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c index 5ca965e47..de1ae6725 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c @@ -11,10 +11,16 @@ #include "tx81.h" void __TF32_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, - RND_MODE round) { + RND_MODE round) { // Create command buffer. 
  TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};

   cmd->TF32_INT32(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);

diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c
index fec628c19..4c60fbf98 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c
@@ -11,10 +11,16 @@
 #include "tx81.h"

 void __TF32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count,
-    RND_MODE round) {
+                 RND_MODE round) {
   // Create command buffer.
   TsmConvert *cmd = TsmNewConvert();
-  TsmConvertInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmConvertInstr inst = {I_CGRA,
+                          {
+                              0,
+                          },
+                          {
+                              0,
+                          }};

   cmd->TF32_INT8(&inst, (uint64_t)src, (uint64_t)dst, elem_count, round);

diff --git a/third_party/tsingmicro/crt/lib/Tx81/transpose.c b/third_party/tsingmicro/crt/lib/Tx81/transpose.c
index c9725c56c..54e2ee584 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/transpose.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/transpose.c
@@ -12,16 +12,22 @@
 #include "tx81.h"

 void __Transpose(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w,
-    uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
-    uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
+                 uint16_t src_c, uint64_t *dst, uint16_t dst_n, uint16_t dst_h,
+                 uint16_t dst_w, uint16_t dst_c, uint16_t fmt) {
   // Create command buffer.
   TsmDataMove *cmd = TsmNewDataMove();
-  TsmDataMoveInstr inst = {I_CGRA, {0,}, {0,}};
+  TsmDataMoveInstr inst = {I_CGRA,
+                           {
+                               0,
+                           },
+                           {
+                               0,
+                           }};

-  Data_Shape shape1 = { src_n, src_h, src_w, src_c };
-  Data_Shape shape2 = { dst_n, dst_h, dst_w, dst_c };
-  cmd->Transpose(&inst, (uint64_t) src, shape1, (uint64_t) dst, shape2,
-      (Data_Format)fmt);
+  Data_Shape shape1 = {src_n, src_h, src_w, src_c};
+  Data_Shape shape2 = {dst_n, dst_h, dst_w, dst_c};
+  cmd->Transpose(&inst, (uint64_t)src, shape1, (uint64_t)dst, shape2,
+                 (Data_Format)fmt);

   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp
index 48e2afbf5..87e47027f 100644
--- a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp
+++ b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.cpp
@@ -16,7 +16,7 @@
 #include "Msan.h"

 #ifndef _WIN32
-#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
+#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) ||    \
     defined(__DragonFly__)
 #include <cstdlib>
 #else
@@ -37,10 +37,7 @@
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

 namespace {
-template <typename V>
-void stdSort(uint64_t n, V *p) {
-  std::sort(p, p + n);
-}
+template <typename V> void stdSort(uint64_t n, V *p) { std::sort(p, p + n); }

 } // namespace

diff --git a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h
index 76b04145b..1e55ca923 100644
--- a/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h
+++ b/third_party/tsingmicro/include/ExecutionEngine/CRunnerUtils.h
@@ -50,11 +50,9 @@ constexpr unsigned nextPowerOf2(int n) {
   return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2)));
 }

-template <typename T, int Dim, bool IsPowerOf2>
-struct Vector1D;
+template <typename T, int Dim, bool IsPowerOf2> struct Vector1D;

-template <typename T, int Dim>
-struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
+template <typename T, int Dim> struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
   Vector1D() {
     static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]),
                   "size error");
@@ -68,8 +66,7 @@ struct Vector1D {

 // 1-D vector, padded to the next power of 2 allocation.
 // Specialization occurs to avoid zero size arrays (which fail in -Werror).
-template <typename T, int Dim>
-struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
+template <typename T, int Dim> struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
   Vector1D() {
     static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
     static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
@@ -86,8 +83,7 @@ struct Vector1D {
 } // namespace mlir

 // N-D vectors recurse down to 1-D.
-template <typename T, int Dim, int... Dims>
-struct Vector {
+template <typename T, int Dim, int... Dims> struct Vector {
   inline Vector<T, Dims...> &operator[](unsigned i) { return vector[i]; }
   inline const Vector<T, Dims...> &operator[](unsigned i) const {
     return vector[i];
@@ -105,17 +101,14 @@ struct Vector
     mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
 };

-template <typename T, int Dim>
-using Vector1D = Vector<T, Dim>;
-template <typename T, int Dim1, int Dim2>
-using Vector2D = Vector<T, Dim1, Dim2>;
+template <typename T, int Dim> using Vector1D = Vector<T, Dim>;
+template <typename T, int Dim1, int Dim2> using Vector2D = Vector<T, Dim1, Dim2>;
 template <typename T, int Dim1, int Dim2, int Dim3>
 using Vector3D = Vector<T, Dim1, Dim2, Dim3>;
 template <typename T, int Dim1, int Dim2, int Dim3, int Dim4>
 using Vector4D = Vector<T, Dim1, Dim2, Dim3, Dim4>;

-template <int N>
-void dropFront(int64_t arr[N], int64_t *res) {
+template <int N> void dropFront(int64_t arr[N], int64_t *res) {
   for (unsigned i = 1; i < N; ++i)
     *(res + i - 1) = arr[i];
 }
@@ -123,12 +116,10 @@ void dropFront(int64_t arr[N], int64_t *res) {
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for StridedMemRef type.
 //===----------------------------------------------------------------------===//
-template <typename T, int Rank>
-class StridedMemrefIterator;
+template <typename T, int Rank> class StridedMemrefIterator;

 /// StridedMemRef descriptor type with static rank.
-template <typename T, int N>
-struct StridedMemRefType {
+template <typename T, int N> struct StridedMemRefType {
   T *basePtr;
   T *data;
   int64_t offset;
@@ -165,8 +156,7 @@ struct StridedMemRefType {
 };

 /// StridedMemRef descriptor type specialized for rank 1.
-template <typename T>
-struct StridedMemRefType<T, 1> {
+template <typename T> struct StridedMemRefType<T, 1> {
   T *basePtr;
   T *data;
   int64_t offset;
@@ -188,8 +178,7 @@ struct StridedMemRefType {
 };

 /// StridedMemRef descriptor type specialized for rank 0.
-template <typename T>
-struct StridedMemRefType<T, 0> {
+template <typename T> struct StridedMemRefType<T, 0> {
   T *basePtr;
   T *data;
   int64_t offset;
@@ -207,8 +196,7 @@ struct StridedMemRefType {
 };

 /// Iterate over all elements in a strided memref.
-template <typename T, int Rank>
-class StridedMemrefIterator {
+template <typename T, int Rank> class StridedMemrefIterator {
 public:
   using iterator_category = std::forward_iterator_tag;
   using value_type = T;
@@ -261,8 +249,7 @@ class StridedMemrefIterator {
 };

 /// Iterate over all elements in a 0-ranked strided memref.
-template <typename T>
-class StridedMemrefIterator<T, 0> {
+template <typename T> class StridedMemrefIterator<T, 0> {
 public:
   using iterator_category = std::forward_iterator_tag;
   using value_type = T;
@@ -307,8 +294,7 @@ class StridedMemrefIterator {
 // Codegen-compatible structure for UnrankedMemRef type.
 //===----------------------------------------------------------------------===//
 // Unranked MemRef
-template <typename T>
-struct UnrankedMemRefType {
+template <typename T> struct UnrankedMemRefType {
   int64_t rank;
   void *descriptor;
 };
@@ -316,12 +302,10 @@ struct UnrankedMemRefType {
 //===----------------------------------------------------------------------===//
 // DynamicMemRefType type.
 //===----------------------------------------------------------------------===//
-template <typename T>
-class DynamicMemRefIterator;
+template <typename T> class DynamicMemRefIterator;

 // A reference to one of the StridedMemRef types.
-template <typename T>
-class DynamicMemRefType {
+template <typename T> class DynamicMemRefType {
 public:
   int64_t rank;
   T *basePtr;
@@ -388,8 +372,7 @@ class DynamicMemRefType {
 };

 /// Iterate over all elements in a dynamic memref.
-template <typename T>
-class DynamicMemRefIterator {
+template <typename T> class DynamicMemRefIterator {
 public:
   using iterator_category = std::forward_iterator_tag;
   using value_type = T;
diff --git a/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td b/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td
index a351a091b..e930ab73a 100644
--- a/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td
+++ b/third_party/tsingmicro/include/magic-kernel-func/Dialect/IR/MagicKernelFuncOps.td
@@ -16,4 +16,4 @@
 // op, it is lowered to 2 MKF(MagicKernelFunc) which are integer version and
 // floating point version.
 //
-//===----------------------------------------------------------------------===//
\ No newline at end of file
+//===----------------------------------------------------------------------===//
diff --git a/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td
index 1053afae5..2dbad73eb 100644
--- a/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td
+++ b/third_party/tsingmicro/include/magic-kernel-instr/Dialect/IR/MagicKernelInstrOps.td
@@ -10,4 +10,4 @@
 // The target glue layer that translates target independent kernel operations
 // into intrinsics which fits LLVM dialect lowering path.
 //
-//===----------------------------------------------------------------------===//
\ No newline at end of file
+//===----------------------------------------------------------------------===//
diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
index 1a2721d38..031593c26 100644
--- a/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
+++ b/third_party/tsingmicro/include/magic-kernel/Conversion/LinalgToMK/LinalgToMK.h
@@ -12,9 +12,9 @@
 #ifndef ZTC_CONVERSION_LINALG_TO_MK_H
 #define ZTC_CONVERSION_LINALG_TO_MK_H

+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"

 namespace mlir {
@@ -23,8 +23,7 @@ namespace triton {
 #define GEN_PASS_DECL
 #include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc"

-void populateLinalgToMKCanonicalizationPatterns(
-    RewritePatternSet &patterns);
+void populateLinalgToMKCanonicalizationPatterns(RewritePatternSet &patterns);

 void populateLinalgToMKConversionPatterns(RewritePatternSet &patterns);

@@ -33,4 +32,4 @@ std::unique_ptr> createLinalgToMKPass();
 } // namespace triton
 } // namespace mlir

-#endif // ZTC_CONVERSION_MEMREF_TO_MAGICKERNEL_H
\ No newline at end of file
+#endif // ZTC_CONVERSION_MEMREF_TO_MAGICKERNEL_H
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td
index 0f8018678..666a7f414 100644
--- a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelAttrDefs.td
@@ -12,4 +12,4 @@

 include "mlir/IR/EnumAttr.td"

-#endif // MAGIC_KERNEL_ATTR_DEFS
\ No newline at end of file
+#endif // MAGIC_KERNEL_ATTR_DEFS
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h
index c9bb47440..06bd269a3 100644
--- a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.h
@@ -18,7 +18,6 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
-#include "mlir/IR/Dialect.h"

 //===----------------------------------------------------------------------===//
 // MagicKernel Operations
@@ -30,5 +29,4 @@
 #define GET_OP_CLASSES
 #include "magic-kernel/Dialect/IR/MagicKernelOps.h.inc"

-
-#endif // MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_
\ No newline at end of file
+#endif // MLIR_DIALECT_MAGIC_KERNEL_IR_DIALECT_H_
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td
index ade7c11c3..4aee43bb8 100644
--- a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelDialect.td
@@ -41,4 +41,4 @@ def MagicKernelDialect : Dialect {

 include "magic-kernel/Dialect/IR/MagicKernelTypes.td"

-#endif // MAGIC_KERNEL_DIALECT
\ No newline at end of file
+#endif // MAGIC_KERNEL_DIALECT
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
index a9643ba9b..f49785fa5 100644
--- a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
@@ -281,4 +281,4 @@ def XorOp : MKBinElemWiseOp<"xor">;
 // def UmulhiOp : MKOp<"umulhi", [Pure]> {}

-#endif // MAGIC_KERNEL_OPS
\ No newline at end of file
+#endif // MAGIC_KERNEL_OPS
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td
index 7c87e5cf2..19fb9e1b6 100644
--- a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelTypes.td
@@ -99,4 +99,4 @@ def MKTensorPtr : MKPtrOf<[MKTensor]>;
 // Any Type in Magic Kernel IR
 def MKType : AnyTypeOf<[MKFloatLike, MKIntLike, MKPtrLike, MKTensorPtr]>;

-#endif // MAGIC_KERNEL_TYPES_TD
\ No newline at end of file
+#endif // MAGIC_KERNEL_TYPES_TD
diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h b/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h
index 2e32b7894..1a66d6ee2 100644
--- a/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h
+++ b/third_party/tsingmicro/include/triton-shared/Analysis/OpFoldResultUtils.h
@@ -8,9 +8,9 @@
 #ifndef TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H
 #define TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H

+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"

 #include

@@ -60,8 +60,10 @@ OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
                      const Location loc, OpBuilder &b);

 OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
-                         const arith::CmpIPredicate pred, const OpFoldResult trueVal,
-                         const OpFoldResult falseVal, const Location loc, OpBuilder &b);
+                         const arith::CmpIPredicate pred,
+                         const OpFoldResult trueVal,
+                         const OpFoldResult falseVal, const Location loc,
+                         OpBuilder &b);

 } // namespace mlir

 #endif
diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h
index 1d8d2696c..82214faf0 100644
--- a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h
+++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h
@@ -8,11 +8,11 @@
 //
 //===----------------------------------------------------------------------===//

+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "triton-shared/Analysis/MaskAnalysis.h"
 #include "triton-shared/Analysis/OpFoldResultUtils.h"
 #include "triton-shared/Analysis/PtrAnalysis.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
-#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"

 #include "triton/Dialect/Triton/IR/Dialect.h"

@@ -871,7 +871,8 @@ struct CallConverter : public OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> args = adaptor.getOperands();

-    // We need to pass extra arguments added by addProgramInfo which are num_programs and program_ids
+    // We need to pass extra arguments added by addProgramInfo which are
+    // num_programs and program_ids
     if (FuncOp parentFunc = op->getParentOfType<FuncOp>()) {
       SymbolRefAttr calleeAttr = op.getCalleeAttr();
       StringRef calleeName = calleeAttr.getRootReference();
@@ -893,12 +894,12 @@ struct CallConverter : public OpConversionPattern {
       }
     }

-    auto call = rewriter.create<func::CallOp>(
-        op.getLoc(), op.getCallee(), op.getResultTypes(), args);
+    auto call = rewriter.create<func::CallOp>(op.getLoc(), op.getCallee(),
+                                              op.getResultTypes(), args);

     if (!call) {
-        op.emitError("Failed to create func::CallOp");
-        return failure();
+      op.emitError("Failed to create func::CallOp");
+      return failure();
     }

     rewriter.replaceOp(op, call);
@@ -928,15 +929,17 @@ struct FpToFpConverter : public OpConversionPattern {
     auto resultWidth = getBitWidth(resultType);

     assert(operandWidth.has_value() && resultWidth.has_value() &&
-        "Not a float-like operand or result");
+           "Not a float-like operand or result");

     if (operandWidth.value() > resultWidth.value()) {
-      Value truncatedValue = rewriter.create<arith::TruncFOp>(op.getLoc(), resultType, op.getOperand());
+      Value truncatedValue = rewriter.create<arith::TruncFOp>(
+          op.getLoc(), resultType, op.getOperand());
       rewriter.replaceOp(op, truncatedValue);
       return success();
     }

-    Value extendedValue = rewriter.create<arith::ExtFOp>(op.getLoc(), resultType, op.getOperand());
+    Value extendedValue = rewriter.create<arith::ExtFOp>(
+        op.getLoc(), resultType, op.getOperand());
     rewriter.replaceOp(op, extendedValue);

     return success();
@@ -951,8 +954,7 @@ struct ClampConverter : public OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     bool propagateNan = op.getPropagateNan() == triton::PropagateNan::ALL;

-    assert(!propagateNan &&
-        "PropagateNan is not supported");
+    assert(!propagateNan && "PropagateNan is not supported");

     Location loc = op.getLoc();
     Value x = adaptor.getOperands()[0];
@@ -967,14 +969,15 @@ struct ClampConverter : public OpConversionPattern {
   }
 };

-struct PreciseSqrtConverter : public OpConversionPattern<triton::PreciseSqrtOp> {
+struct PreciseSqrtConverter
+    : public OpConversionPattern<triton::PreciseSqrtOp> {
   using OpConversionPattern::OpConversionPattern;

   LogicalResult
   matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto replacement = rewriter.create<math::SqrtOp>(
-        op.getLoc(), adaptor.getOperands());
+    auto replacement =
+        rewriter.create<math::SqrtOp>(op.getLoc(), adaptor.getOperands());

     rewriter.replaceOp(op, replacement);
     return success();
@@ -987,8 +990,8 @@ struct PreciseDivConverter : public OpConversionPattern {
   LogicalResult
   matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto replacement = rewriter.create<arith::DivFOp>(
-        op.getLoc(), adaptor.getOperands());
+    auto replacement =
+        rewriter.create<arith::DivFOp>(op.getLoc(), adaptor.getOperands());

     rewriter.replaceOp(op, replacement);
     return success();
@@ -1026,10 +1029,10 @@ struct SplitConverter : public OpConversionPattern {
     SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
-          return rewriter.getIndexAttr(dim);
-        }));
+    SmallVector<OpFoldResult> sizes = llvm::to_vector(
+        llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
+          return rewriter.getIndexAttr(dim);
+        }));

     SmallVector<Value> results;

@@ -1040,7 +1043,7 @@ struct SplitConverter : public OpConversionPattern {
       offsets.push_back(rewriter.getIndexAttr(i));
       sizes.push_back(rewriter.getIndexAttr(1));
       Value slice = rewriter.create<tensor::ExtractSliceOp>(
-        loc, resultTensor, input, offsets, sizes, strides);
+          loc, resultTensor, input, offsets, sizes, strides);
       results.push_back(slice);
     }

@@ -1060,16 +1063,17 @@ struct JoinConverter : public OpConversionPattern {
     auto resultType = cast(op.getResult().getType());
     auto loc = op.getLoc();

-    Value result = rewriter.create<tensor::EmptyOp>(loc, resultType.getShape(), resultType.getElementType());
+    Value result = rewriter.create<tensor::EmptyOp>(
+        loc, resultType.getShape(), resultType.getElementType());

     auto shape = resultType.getShape();
     SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
-          return rewriter.getIndexAttr(dim);
-        }));
+    SmallVector<OpFoldResult> sizes = llvm::to_vector(
+        llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
+          return rewriter.getIndexAttr(dim);
+        }));

     for (int i = 0; i < 2; ++i) {
       offsets.pop_back();
@@ -1077,7 +1081,8 @@ struct JoinConverter : public OpConversionPattern {
       offsets.push_back(rewriter.getIndexAttr(i));
       sizes.push_back(rewriter.getIndexAttr(1));

-      result = rewriter.create<tensor::InsertSliceOp>(loc, inputs[i], result, offsets, sizes, strides);
+      result = rewriter.create<tensor::InsertSliceOp>(loc, inputs[i], result,
+                                                      offsets, sizes, strides);
     }

     rewriter.replaceOp(op, result);
@@ -1094,7 +1099,8 @@ struct MulHiUIOpConverter : public OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();

-    auto mulResult = rewriter.create<arith::MulUIExtendedOp>(loc, adaptor.getOperands());
+    auto mulResult =
+        rewriter.create<arith::MulUIExtendedOp>(loc, adaptor.getOperands());
     rewriter.replaceOp(op, mulResult.getHigh());

     return success();
@@ -1110,29 +1116,29 @@ struct MatmulConverter : public OpConversionPattern {
   // true means tensor elements are zeros
   // false means not zero or it cannot be determined
   bool isZeroTensor(Value &v, bool integers) const {
-      if (auto splatOp = v.getDefiningOp<triton::SplatOp>()) {
-        if (auto constOp = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
-          if (auto val = dyn_cast<FloatAttr>(constOp.getValue())) {
-            return val.getValueAsDouble() == 0.;
-          }
-          if (auto val = dyn_cast<IntegerAttr>(constOp.getValue())) {
-            return val.getValue() == 0;
-          }
+    if (auto splatOp = v.getDefiningOp<triton::SplatOp>()) {
+      if (auto constOp = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
+        if (auto val = dyn_cast<FloatAttr>(constOp.getValue())) {
+          return val.getValueAsDouble() == 0.;
+        }
+        if (auto val = dyn_cast<IntegerAttr>(constOp.getValue())) {
+          return val.getValue() == 0;
         }
-        return false;
       }
+      return false;
+    }

-      if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
-        if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
-          if (denseAttr.isSplat()) {
-            if (integers)
-              return denseAttr.getSplatValue<APInt>().isZero();
-            return denseAttr.getSplatValue<APFloat>().isZero();
-          }
+    if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
+      if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
+        if (denseAttr.isSplat()) {
+          if (integers)
+            return denseAttr.getSplatValue<APInt>().isZero();
+          return denseAttr.getSplatValue<APFloat>().isZero();
         }
       }
+    }

-      return false;
+    return false;
   }

   LogicalResult
@@ -1149,9 +1155,10 @@ struct MatmulConverter : public OpConversionPattern {
     auto init =
         rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), elementType);

-    TypedAttr constantAttr = integers ?
-        static_cast<TypedAttr>(rewriter.getIntegerAttr(elementType, 0)) :
-        static_cast<TypedAttr>(rewriter.getFloatAttr(elementType, 0));
+    TypedAttr constantAttr =
+        integers
+            ? static_cast<TypedAttr>(rewriter.getIntegerAttr(elementType, 0))
+            : static_cast<TypedAttr>(rewriter.getFloatAttr(elementType, 0));

     auto zero = rewriter.create<arith::ConstantOp>(
         op.getLoc(), elementType, constantAttr);
@@ -1160,8 +1167,8 @@
         rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init})
             .result();

-    auto dotOp = rewriter.create(
-        loc, dstType, ValueRange{opa, opb, opc, zeroes});
+    auto dotOp = rewriter.create(loc, dstType,
+                                 ValueRange{opa, opb, opc, zeroes});

     rewriter.replaceOp(op, dotOp);

@@ -1235,8 +1242,8 @@ struct ReduceConverter : public OpConversionPattern {
   bool requiresF32Conversion(const Type elemType, Operation *redOp) const {
     return isa<FloatType>(elemType) &&
            elemType.getIntOrFloatBitWidth() <
-             llvm::cast<FloatType>(Float32Type::get(elemType.getContext()))
-                 .getWidth() &&
+               llvm::cast<FloatType>(Float32Type::get(elemType.getContext()))
+                   .getWidth() &&
            isa(redOp);
   }

diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h
index 279d671e7..29218895f 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h
@@ -12,10 +12,10 @@
 #ifndef ZTC_CONVERSION_MK_TO_TX81_H
 #define ZTC_CONVERSION_MK_TO_TX81_H

+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"

 namespace mlir {
@@ -24,8 +24,7 @@ namespace triton {
 #define GEN_PASS_DECL
 #include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h.inc"

-void populateMKToTx81CanonicalizationPatterns(
-    RewritePatternSet &patterns);
+void populateMKToTx81CanonicalizationPatterns(RewritePatternSet &patterns);

 void populateMKToTx81ConversionPatterns(RewritePatternSet &patterns);

@@ -34,4 +33,4 @@ std::unique_ptr> createMKToTx81Pass();
 } // namespace triton
 } // namespace mlir

-#endif // ZTC_CONVERSION_MK_TO_TX81_H
\ No newline at end of file
+#endif // ZTC_CONVERSION_MK_TO_TX81_H
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h
index 957c7e957..6e9b8147a 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h
@@ -37,4 +37,4 @@ std::unique_ptr> createTx81MemrefToLLVMPass();
 } // namespace triton
 } // namespace mlir

-#endif // ZTC_CONVERSION_MEMREF_TO_MAGICKERNEL_H
\ No newline at end of file
+#endif // ZTC_CONVERSION_MEMREF_TO_MAGICKERNEL_H
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h
index d877e0287..3ee5ebdef 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h
@@ -20,8 +20,8 @@ namespace mlir {
 class ModuleOp;
 class Pass;

-/// Creates a pass that transforms kernel functions by replacing multiple arguments
-/// with a single void* buffer argument.
+/// Creates a pass that transforms kernel functions by replacing multiple
+/// arguments with a single void* buffer argument.
 std::unique_ptr<Pass> createKernelArgBufferPass();

 #define GEN_PASS_DECL_KERNELARGBUFFERPASS
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td
index c69d3540d..07ca7549f 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81AttrDefs.td
@@ -21,4 +21,4 @@ def RoundModeAttr : I32EnumAttr<"RoundMode", "Round mode", [
   let cppNamespace = "::mlir::tx";
 }

-#endif // TSINGMICRO_TX81_ATTR_DEFS
\ No newline at end of file
+#endif // TSINGMICRO_TX81_ATTR_DEFS
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h
index 2fd0c9f34..955cbdb1e 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Dialect.h
@@ -30,4 +30,4 @@
 #include "tsingmicro-tx81/Dialect/IR/Tx81Enums.h.inc"
 #include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h.inc"

-#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H
\ No newline at end of file
+#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h
index ca9ba4bf9..dc27e2388 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.h
@@ -23,4 +23,4 @@
 #include "tsingmicro-tx81/Dialect/IR/Tx81Enums.h.inc"
 #include "tsingmicro-tx81/Dialect/IR/Tx81Ops.h.inc"

-#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H
\ No newline at end of file
+#endif // MLIR_DIALECT_TSINGMICRO_TX81_IR_DIALECT_H
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
index 513baf79f..476817a79 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td
@@ -772,4 +772,4 @@ def GatherScatter : Tx81Op<"gatherscatter", []> {
   let results = (outs UI64:$dst);
 }

-#endif // TSINGMICRO_TX81_OPS
\ No newline at end of file
+#endif // TSINGMICRO_TX81_OPS
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td
index e4c98254e..cce13dfe9 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Types.td
@@ -104,4 +104,4 @@ def MKTensorPtr : MKPtrOf<[MKTensor]>;
 // Any Type in Magic Kernel IR
 def MKType : AnyTypeOf<[MKFloatLike, MKIntLike, MKPtrLike, MKTensorPtr]>;

-#endif // TSINGMICRO_TX81_TYPES_TD
\ No newline at end of file
+#endif // TSINGMICRO_TX81_TYPES_TD
diff --git a/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp
index dc8a27c45..ede7a7fd5 100644
--- a/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp
+++ b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp
@@ -336,8 +336,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
   // We only support sge against 0 for lower bounds. Dims already has an
   // implicit assumption that the lower bound is 0, so if we see this, assume
   // the comparison evaluates to true.
-  if (cmpOp.getPredicate() == arith::CmpIPredicate::sge
-      && !(rhsState.scalar && hasConstZero(rhsState.scalar))) {
+  if (cmpOp.getPredicate() == arith::CmpIPredicate::sge &&
+      !(rhsState.scalar && hasConstZero(rhsState.scalar))) {
     InFlightDiagnostic diag = emitError(loc)
                               << "Unsupported cmpi with rhs not equal to 0";
     return failure();
@@ -370,10 +370,10 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
     // should be loaded/stored by inserting a comparison + select:
     // dim = lhs < rhs ? lhs.dim : 0
     newDim = compareOFRs(lhsState.scalar, rhsState.scalar, cmpOp.getPredicate(),
-                         lhsState.dims[cmpDim], builder.getIndexAttr(0),
-                         loc, builder);
+                         lhsState.dims[cmpDim], builder.getIndexAttr(0), loc,
+                         builder);
   } else if (cmpOp.getPredicate() == arith::CmpIPredicate::slt ||
-      cmpOp.getPredicate() == arith::CmpIPredicate::ult) {
+             cmpOp.getPredicate() == arith::CmpIPredicate::ult) {
     // Important:
     // In the case where the values we are loading are entirely masked off like
     // the following:
@@ -391,8 +391,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
       newEnd = maxOFRs(newEnd, lhsState.start, loc, builder);
       newDim = subOFRs(newEnd, lhsState.start, loc, builder);
     } else {
-      assert(cmpOp.getPredicate() == arith::CmpIPredicate::sge && rhsState.scalar
-          && hasConstZero(rhsState.scalar));
+      assert(cmpOp.getPredicate() == arith::CmpIPredicate::sge &&
+             rhsState.scalar && hasConstZero(rhsState.scalar));
       newDim = lhsState.dims[cmpDim];
     }

diff --git a/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
index 62aa57ff9..a5efbc096 100644
--- a/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
+++ b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
@@ -251,32 +251,34 @@ OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
 }

 OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
-                         const arith::CmpIPredicate pred, const OpFoldResult trueOFR,
-                         const OpFoldResult falseOFR, const Location loc, OpBuilder &b) {
+                         const arith::CmpIPredicate pred,
+                         const OpFoldResult trueOFR,
+                         const OpFoldResult falseOFR, const Location loc,
+                         OpBuilder &b) {
   auto lhsIntAttr = getIntAttr(lhs);
   auto rhsIntAttr = getIntAttr(rhs);

   // both lhs and rhs are constants, return the result directly
   if (lhsIntAttr && rhsIntAttr) {
     switch (pred) {
-      case arith::CmpIPredicate::eq:
-        return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::ne:
-        return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::slt:
-      case arith::CmpIPredicate::ult:
-        return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::sle:
-      case arith::CmpIPredicate::ule:
-        return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::sgt:
-      case arith::CmpIPredicate::ugt:
-        return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::sge:
-      case arith::CmpIPredicate::uge:
-        return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
-      default:
-        llvm_unreachable("Unsupported predicate");
+    case arith::CmpIPredicate::eq:
+      return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::ne:
+      return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::slt:
+    case arith::CmpIPredicate::ult:
+      return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sle:
+    case arith::CmpIPredicate::ule:
+      return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sgt:
+    case arith::CmpIPredicate::ugt:
+      return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sge:
+    case arith::CmpIPredicate::uge:
+      return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
+    default:
+      llvm_unreachable("Unsupported predicate");
     }
   }

diff --git a/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp
index d813ffdb2..ee98c6c56 100644
--- a/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp
+++ b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp
@@ -303,8 +303,8 @@ LogicalResult PtrState::mulState(const PtrState &lhsState,
   }

   if (lhsState.scalar && rhsState.scalar) {
-    scalar = builder.create<arith::MulIOp>(
-        loc, lhsState.scalar, rhsState.scalar);
+    scalar =
+        builder.create<arith::MulIOp>(loc, lhsState.scalar, rhsState.scalar);
   }

   for (uint64_t i = 0; i < lhs->sizes.size(); i++) {
diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp
index b088a56ea..a3970cf81 100644
--- a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp
+++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp
@@ -5,8 +5,8 @@
 //
 //===----------------------------------------------------------------------===//

-#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h"
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"

 #define DEBUG_TYPE "linalg-to-mk"

@@ -16,7 +16,6 @@ using namespace mk;
 #define GEN_PASS_CLASSES
 #include "magic-kernel/Conversion/LinalgToMK/Passes.h.inc"

-
 namespace {

 // Convert tensor.empty + linalg.fill + linalg.matmul to mk.matmul
@@ -49,10 +48,9 @@ struct MatmulConverter : public OpConversionPattern {
 } // namespace

 void mlir::triton::populateLinalgToMKCanonicalizationPatterns(
-    RewritePatternSet &patterns) {
-}
+    RewritePatternSet &patterns) {}

 void mlir::triton::populateLinalgToMKConversionPatterns(
     RewritePatternSet &patterns) {
   // patterns.add(patterns.getContext());
-}
\ No newline at end of file
+}
diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp
index 094a51b31..eaba1f34f 100644
--- a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp
@@ -5,16 +5,16 @@
 //
 //===----------------------------------------------------------------------===//

-#include "llvm/Support/Debug.h"
+#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h"
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "magic-kernel/Conversion/LinalgToMK/LinalgToMK.h"
-#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" #include #include #include @@ -37,8 +37,8 @@ class LinalgToMKPass : public triton::impl::LinalgToMKBase { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { @@ -47,13 +47,12 @@ class LinalgToMKPass : public triton::impl::LinalgToMKBase { ConversionTarget target(getContext()); // TODO: Enable this when all conversion pattern has been implemented. - //target.addIllegalDialect(); + // target.addIllegalDialect(); - target.addLegalDialect< - func::FuncDialect, arith::ArithDialect, math::MathDialect, - affine::AffineDialect, scf::SCFDialect, - cf::ControlFlowDialect, tensor::TensorDialect, - mk::MagicKernelDialect>(); + target.addLegalDialect(); target.addLegalOp(); @@ -66,7 +65,6 @@ class LinalgToMKPass : public triton::impl::LinalgToMKBase { } // namespace -std::unique_ptr> -triton::createLinalgToMKPass() { +std::unique_ptr> triton::createLinalgToMKPass() { return std::make_unique(); } diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp index 611ba7be7..aa0c51160 100644 --- a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp @@ -49,8 +49,9 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base { ConversionTarget target(getContext()); // Register illegal ops for Dialect Conversion - target.addIllegalDialect< linalg::LinalgDialect, - bufferization::BufferizationDialect, mk::MagicKernelDialect>(); + target.addIllegalDialect(); target.addLegalDialect back to tt.ptr type, these ops will then be // handled when we convert addptr op later. addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> Value { + ValueRange inputs, Location loc) -> Value { return builder.create(loc, resultType, inputs) .getResult(0); }); addArgumentMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> Value { + ValueRange inputs, Location loc) -> Value { return builder.create(loc, resultType, inputs) .getResult(0); }); @@ -117,8 +115,7 @@ class LoopTypeConverter : public TypeConverter { // Canonicalization will simplify this sequence by removing the inital // reinterpret_cast. 
     addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
-                                 ValueRange inputs,
-                                 Location loc) -> Value {
+                                 ValueRange inputs, Location loc) -> Value {
       auto reinterpretCast = inputs[0].getDefiningOp<memref::ReinterpretCastOp>();

       return builder.create<memref::ReinterpretCastOp>(
@@ -201,8 +198,8 @@ buildCastAndOffsetOps(OpBuilder &builder, TypeRange resultTypes, Value input,
   return SmallVector<Value>{cast, zero};
 }

-static Value buildCastOp(OpBuilder &builder, Type resultType,
-                         ValueRange inputs, Location loc) {
+static Value buildCastOp(OpBuilder &builder, Type resultType, ValueRange inputs,
+                         Location loc) {
   assert(isa(resultType));
   assert(inputs.size() && isa(inputs[0].getType()) &&
          isa(inputs[1].getType()));
@@ -238,7 +235,8 @@ class StructuredToMemrefPass
         patterns, typeConverter);

     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
-      return typeConverter.isLegal(op.getResultTypes()) && typeConverter.isLegal(op.getOperandTypes());
+      return typeConverter.isLegal(op.getResultTypes()) &&
+             typeConverter.isLegal(op.getOperandTypes());
     });

     populateFunctionOpInterfaceTypeConversionPattern(
@@ -336,7 +334,7 @@ class StructuredToMemrefPass
     // Compute the target materialization, given a value with the pointer type,
     // convert that value to a pair of {memref, index} type.
-#if 0 // FIXME: Incompatible MILR interface
+#if 0  // FIXME: Incompatible MILR interface
     converter.addTargetMaterialization(buildCastAndOffsetOps);
 #endif

diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
index 12e34c602..530c3242b 100644
--- a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
@@ -5,12 +5,12 @@
 //
 //===----------------------------------------------------------------------===//

+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
 #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
-#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"

 #include "mlir/Dialect/Affine/IR/AffineOps.h"
diff --git a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp
index 8085836bf..0f4b8dfc6 100644
--- a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp
@@ -5,13 +5,13 @@
 //
 //===----------------------------------------------------------------------===//

+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
 #include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
 #include "triton-shared/Conversion/TritonToCoreDialects/TritonToCoreDialects.h"
 #include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h"
 #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
-#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" diff --git a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index c479a06ef..71d694290 100644 --- a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -144,7 +144,7 @@ class TritonToStructuredPass // Compute the target materialization, given a value with the pointer type, // convert that value to a tuple type. -#if 0 // FIXME: Incompatible MILR interface +#if 0 // FIXME: Incompatible MILR interface converter.addTargetMaterialization( [](OpBuilder &builder, TypeRange resultTypes, Value input, Location loc) -> std::optional> { @@ -209,7 +209,7 @@ class TritonToStructuredPass // The return values for this op will be used as the init-args for scf.for. // At the end of pointer analysis, we will use the PtrState to create the // correct offsets, strides, and remove these ops. - #if 0 // FIXME: Incompatible MILR interface +#if 0 // FIXME: Incompatible MILR interface converter.addTargetMaterialization([](OpBuilder &builder, TypeRange resultTypes, Value input, Location loc) { diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp index 8a144b8a9..7f45b79c4 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp +++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp @@ -332,4 +332,4 @@ void mlir::triton::populateTx81MemrefToLLVMConversionPatterns( MemrefLoadOrStoreOpLowering>( converter); // clang-format on -} \ No newline at end of file +} diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp index 339197ffc..a024e67c8 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp @@ -41,7 +41,8 @@ class KernelArgBufferPass bool isKernelFunction(func::FuncOp func); // Creates a new function with a single void* argument - func::FuncOp createBufferizedFunction(OpBuilder &builder, func::FuncOp originalFunc); + func::FuncOp createBufferizedFunction(OpBuilder &builder, + func::FuncOp originalFunc); // Rewrites the function body to use the argument buffer void rewriteFunctionBody(func::FuncOp originalFunc, func::FuncOp newFunc); @@ -54,17 +55,18 @@ bool KernelArgBufferPass::isKernelFunction(func::FuncOp func) { return func.getName().contains("_kernel"); } -func::FuncOp KernelArgBufferPass::createBufferizedFunction(OpBuilder &builder, - func::FuncOp originalFunc) { +func::FuncOp +KernelArgBufferPass::createBufferizedFunction(OpBuilder &builder, + func::FuncOp originalFunc) { // Create a new function type with a single void* argument auto voidPtrType = LLVM::LLVMPointerType::get(builder.getContext()); - auto newFuncType = FunctionType::get(originalFunc.getContext(), - {voidPtrType}, - originalFunc.getFunctionType().getResults()); + auto newFuncType = + FunctionType::get(originalFunc.getContext(), {voidPtrType}, + originalFunc.getFunctionType().getResults()); // Create the new function with the same name but new type - auto newFunc = func::FuncOp::create(originalFunc.getLoc(), 
-                                      originalFunc.getName(),
-                                      newFuncType);
+  auto newFunc = func::FuncOp::create(originalFunc.getLoc(),
+                                      originalFunc.getName(), newFuncType);

   // Copy over all attributes except those related to the function type
   for (const auto &attr : originalFunc->getAttrs()) {
@@ -78,7 +80,7 @@ func::FuncOp KernelArgBufferPass::createBufferizedFunction(OpBuilder &builder,
 }

 void KernelArgBufferPass::rewriteFunctionBody(func::FuncOp originalFunc,
-    func::FuncOp newFunc) {
+                                              func::FuncOp newFunc) {
   if (originalFunc.empty())
     return;

@@ -205,4 +207,4 @@ std::unique_ptr createKernelArgBufferPass() {
 namespace {
 #define GEN_PASS_REGISTRATION
 #include "KernelArgBufferPass.h.inc"
-} // namespace
\ No newline at end of file
+} // namespace
diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp
index 66e198f54..667b0a8fa 100644
--- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp
+++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp
@@ -963,7 +963,8 @@ struct LinalgFillOpConversion : public OpConversionPattern {
     Type elemType = tensorType.getElementType();

     // Convert the tensor type to the LLVM pointer type
-    auto llvmPtrType = dyn_cast<LLVM::LLVMPointerType>(typeConverter->convertType(tensorType));
+    auto llvmPtrType =
+        dyn_cast<LLVM::LLVMPointerType>(typeConverter->convertType(tensorType));
     if (!llvmPtrType) {
       return rewriter.notifyMatchFailure(
           op, "failed to convert tensor type to LLVM pointer type");
diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp
index 0507f6ab6..e0efab8d0 100644
--- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVMPass.cpp
@@ -5,20 +5,20 @@
 //
 //===----------------------------------------------------------------------===//

-#include "llvm/Support/Debug.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Tx81ToLLVM.h"
 #include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h"
+#include "llvm/Support/Debug.h"
 #include
 #include
 #include
@@ -33,13 +33,12 @@ using namespace triton;

 namespace {

-
 class Tx81ToLLVMPass : public Tx81ToLLVMBase<Tx81ToLLVMPass> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert();
+    registry
+        .insert();
   }

   void runOnOperation() override {
@@ -74,7 +73,6 @@ class Tx81ToLLVMPass : public Tx81ToLLVMBase {

 } // namespace

-std::unique_ptr<OperationPass<ModuleOp>>
-triton::createTx81ToLLVMPass() {
+std::unique_ptr<OperationPass<ModuleOp>> triton::createTx81ToLLVMPass() {
   return std::make_unique<Tx81ToLLVMPass>();
 }
diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp
index cd7cc2688..f0c256956 100644
--- a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -119,4 +119,4 @@ void mlir::mk::registerBufferizableOpInterfaceExternalModels(
     // TODO: Register all mk ops.
     MKOpInterfaceHelper::registerOpInterface(ctx);
   });
-}
\ No newline at end of file
+}
diff --git a/third_party/tsingmicro/name.conf b/third_party/tsingmicro/name.conf
index 0271c2eb3..8a593a129 100644
--- a/third_party/tsingmicro/name.conf
+++ b/third_party/tsingmicro/name.conf
@@ -1 +1 @@
-ztc
\ No newline at end of file
+ztc

From a651061979b6be569ca1bb3789e284ccac323c40 Mon Sep 17 00:00:00 2001
From: zhzhcookie
Date: Tue, 13 May 2025 16:44:57 +0800
Subject: [PATCH 03/12] [BACKEND] [BUILD] build tsingmicro backend with Triton
 3.3.x (#2)

pass
---
 CMakeLists.txt         | 6 +++---
 python/setup.py        | 2 +-
 python/setup_helper.py | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ea64b2752..c328eea90 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,7 +26,7 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
   set(CMAKE_C_COMPILER clang)
   set(CMAKE_CXX_COMPILER clang++)
   set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND})
-elseif(FLAGTREE_BACKEND STREQUAL "aipu")
+elseif(FLAGTREE_BACKEND MATCHES "^(aipu|tsingmicro)$")
   add_definitions(-D__NVIDIA__)
   add_definitions(-D__AMD__)
 endif()
@@ -204,7 +204,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party)
 include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files
 # link_directories(${LLVM_LIBRARY_DIR})
-if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu)$")
+if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu|tsingmicro)$")
   include_directories(${PROJECT_SOURCE_DIR}/include)
   include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
   add_subdirectory(include)
@@ -446,7 +446,7 @@ find_package(Threads REQUIRED)

 add_subdirectory(third_party/f2reduce)

-if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND STREQUAL "aipu")
+if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND MATCHES "^(aipu|tsingmicro)$")
   add_subdirectory(bin)
   add_subdirectory(test)
 endif()
diff --git a/python/setup.py b/python/setup.py
index c9e623f9b..ed596d0a3 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -597,7 +597,7 @@ def build_extension(self, ext):
         )

     if helper.flagtree_backend:
-        if helper.flagtree_backend == "aipu":
+        if helper.flagtree_backend in ("aipu", "tsingmicro"):
             backends = [
                 *BackendInstaller.copy(helper.default_backends + helper.extend_backends),
                 *BackendInstaller.copy_externals(),
diff --git a/python/setup_helper.py b/python/setup_helper.py
index fc99295fb..454c43b70 100644
--- a/python/setup_helper.py
+++ b/python/setup_helper.py
@@ -239,7 +239,7 @@ def skip_package_dir(package):
     @staticmethod
     def get_package_dir(packages):
         package_dict = {}
-        if flagtree_backend and flagtree_backend not in ("cambricon", "aipu"):
+        if flagtree_backend and flagtree_backend not in ("cambricon", "aipu", "tsingmicro"):
             connection = []
             backend_triton_path = f"../third_party/{flagtree_backend}/python/"
             for package in packages:
@@ -305,7 +305,7 @@ def handle_flagtree_backend():
     if flagtree_backend:
         print(f"flagtree_backend is {flagtree_backend}")
         extend_backends.append(flagtree_backend)
-        if "editable_wheel" in sys.argv and flagtree_backend != "aipu":
("aipu", "tsingmicro"): ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" default_backends.append("flir") if use_triton_shared: From 1e21d55d991fcc6c9e2dcd7af14039eeb808e09b Mon Sep 17 00:00:00 2001 From: zhzhcookie Date: Wed, 14 May 2025 12:47:18 +0800 Subject: [PATCH 04/12] [BACKEND] Fix tsingmicro backend code format (#3) --- .../tsingmicro/crt/include/Tx81/instr_def.h | 24 +++++++++---------- .../crt/include/Tx81/runtime/hrt_common.h | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_def.h b/third_party/tsingmicro/crt/include/Tx81/instr_def.h index c999997c0..2d73bf8c0 100644 --- a/third_party/tsingmicro/crt/include/Tx81/instr_def.h +++ b/third_party/tsingmicro/crt/include/Tx81/instr_def.h @@ -312,7 +312,7 @@ typedef struct Ncc_CT_GR_Ctl_Regs { // :stochastic round uint8_t src0_format; // 当CGRATensor_PeriOp_V_V_bit2fp指令,此字段用作dst_format - uint8_t opcode; // 详见CGRATensor指令OPcode.v + uint8_t opcode; // 详见CGRATensor指令OPcode.v } Ncc_CT_GR_Ctl_Regs; typedef struct Ncc_CT_GR_Param_Regs { @@ -320,18 +320,18 @@ typedef struct Ncc_CT_GR_Param_Regs { uint32_t src1; uint32_t dst0; uint32_t dst1; - uint32_t dst2; // spm地址 - uint64_t src0_tfr; // nhwc - uint64_t dst_tfr; // nhwc - uint64_t pdr; // TOP BOTTOM,LEFT,RIGHT(分别是上下左右pad的行/列数) - uint64_t swr; // kernel的 Kx(x方向的大小),Ky,Sx(x方向的步进),Sy - uint64_t elem_count; // vector运算的元素个数 + uint32_t dst2; // spm地址 + uint64_t src0_tfr; // nhwc + uint64_t dst_tfr; // nhwc + uint64_t pdr; // TOP BOTTOM,LEFT,RIGHT(分别是上下左右pad的行/列数) + uint64_t swr; // kernel的 Kx(x方向的大小),Ky,Sx(x方向的步进),Sy + uint64_t elem_count; // vector运算的元素个数 uint64_t unit_elem_count; // vector运算中的短向量的元素个数(最大为64) uint64_t int8_scale_val0; // 双线性插值x方向缩放系数(input_w/output_w) uint64_t int8_scale_val1; // 双线性插值y方向缩放系数(input_h/output_h) uint64_t int8_quant; // abandon uint32_t int8_bn_bias; // abandon - uint32_t full_elem_count; // 若干个src_elem_num之和 + uint32_t full_elem_count; // 若干个src_elem_num之和 uint32_t full_unit_elem_count; // 若干个src_uint_elem_num之和 uint64_t wb_data0; // The pointer of Return value. 
[32] DATA_VALID, [31:0] // data, 函数只有一个返回值时,返回数据写在此寄存器 @@ -374,7 +374,7 @@ typedef struct Ncc_NE_GR_Ctl_Regs { uint8_t input_format; uint8_t inpsum_en; uint8_t lrelu_en; // either relu or lrelu - uint8_t relu_en; // relu_en/lrelu_en/bias_en/scale_en 同时为0时,输出是psum + uint8_t relu_en; // relu_en/lrelu_en/bias_en/scale_en 同时为0时,输出是psum uint8_t scale_en; uint8_t bias_en; uint8_t dilation_conv; // valid as conv backwardconv @@ -424,7 +424,7 @@ typedef struct Ncc_NE_GR_Param_Regs { uint8_t quant_q0; // q2, (范围:0-31),[7:0] uint32_t sparse_index; // spm地址(稀疏化索引) - uint32_t srca_end; // xxx_end = src/dst + 对应操作数在spm中存储范围 + uint32_t srca_end; // xxx_end = src/dst + 对应操作数在spm中存储范围 uint32_t srcw_end; uint32_t psum_end; uint32_t bias_end; @@ -600,7 +600,7 @@ typedef struct Ncc_CSR_GR_PRIORITY_RW { } Ncc_CSR_GR_PRIORITY_RW; typedef struct Ncc_CSR_GR_EXCEPTION_MASK { - uint8_t exception_clear; // [49] 清中断寄存器(self-clear,无需清零) + uint8_t exception_clear; // [49] 清中断寄存器(self-clear,无需清零) uint8_t exception_update_enable; // [48] 1:保存最后一条异常 // 0:保存第一条异常,后续异常忽略 uint64_t exception_mask; // [47:0] 48'hffff_ffff_ffff 异常使能 1:中断源被屏蔽 @@ -933,7 +933,7 @@ typedef struct NCC_CSR { exception; //[7:0]SCALAR_EXCEPTION, [15:8]CT_EXCEPTION, //[23:16]NE_EXCEPTION, [31:24]RDMA_EXCEPTION, //[39:32]WDMA_EXCEPTION, [47:40]TDMA_EXCEPTION, [63:48]Reserved - uint64_t priority; //[7:0]PRIORITY,当前worker的优先级, [63:8]Reserved + uint64_t priority; //[7:0]PRIORITY,当前worker的优先级, [63:8]Reserved uint64_t exception_mask; //[47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE, //[49]EXCEPTION_CLEAR, [63:49]Reserved uint64_t diff --git a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h index 9aed4ed61..4ee0d0e30 100644 --- a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h +++ b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h @@ -436,7 +436,7 @@ typedef struct TsmDevice { class TsmTensorData { public: TsmTensorData() : host_addr(0), device_addr(0), length(0) {} - ~TsmTensorData(){}; + ~TsmTensorData() {}; TsmHostPtr host_addr; TsmDevicePtr device_addr; From fcc926e7abbdb09e322f67cc4240907fb1bbdde8 Mon Sep 17 00:00:00 2001 From: tsingmicro-public Date: Wed, 4 Jun 2025 11:40:14 +0800 Subject: [PATCH 05/12] Sync code 6 4 (#8) * [BACKEND] tsingmicro run succ * [BACKEND] TEMP * [BACKEND] add tsingmicro backend depends * [BUILD] add cmake install call --- .gitignore | 2 + .pre-commit-config.yaml | 19 - CMakeLists.txt | 6 + README.md | 11 + README_cn.md | 11 + python/setup.py | 3 + third_party/tsingmicro/CMakeLists.txt | 29 + third_party/tsingmicro/backend/compiler.py | 145 ++- third_party/tsingmicro/backend/cpu_driver.py | 387 ------- third_party/tsingmicro/backend/driver.cpp | 570 +---------- third_party/tsingmicro/backend/driver.py | 846 ++++++---------- third_party/tsingmicro/backend/name.conf | 1 + .../tsingmicro/backend/txda_device.cpp | 180 ++++ third_party/tsingmicro/bin/CMakeLists.txt | 88 ++ .../tsingmicro/bin/RegisterTritonDialects.h | 181 ++++ .../tsingmicro/bin/tsingmicro-llvm-opt.cpp | 121 +++ third_party/tsingmicro/bin/tsingmicro-lsp.cpp | 10 + third_party/tsingmicro/bin/tsingmicro-opt.cpp | 11 + .../tsingmicro/bin/tsingmicro-reduce.cpp | 11 + .../bin/tsingmicro-tensor-layout.cpp | 232 +++++ third_party/tsingmicro/crt/CMakeLists.txt | 38 +- .../crt/include/Tx81/instr_adapter.h | 61 -- .../crt/include/Tx81/instr_adapter_plat.h | 873 ---------------- .../tsingmicro/crt/include/Tx81/instr_def.h | 949 ------------------ 
.../crt/include/Tx81/runtime/hrt_common.h | 488 --------- .../crt/include/Tx81/runtime/hrt_interface.h | 122 --- .../tsingmicro/crt/include/Tx81/tx81.h | 11 + third_party/tsingmicro/crt/lib/Tx81/abs.c | 32 + third_party/tsingmicro/crt/lib/Tx81/arith.c | 44 + third_party/tsingmicro/crt/lib/Tx81/count.c | 6 +- third_party/tsingmicro/crt/lib/Tx81/gemm.c | 22 +- third_party/tsingmicro/crt/lib/Tx81/logic.c | 78 ++ third_party/tsingmicro/crt/lib/Tx81/memset.c | 14 +- third_party/tsingmicro/crt/lib/Tx81/rdma.c | 8 +- .../tsingmicro/crt/lib/Tx81/relation.c | 144 +++ third_party/tsingmicro/crt/lib/Tx81/rsqrt.c | 34 + third_party/tsingmicro/crt/lib/Tx81/sqrt.c | 33 + third_party/tsingmicro/crt/lib/Tx81/wdma.c | 8 +- .../tsingmicro/examples/bare_matmul.py | 52 + third_party/tsingmicro/examples/benchmark.py | 65 ++ .../tsingmicro/examples/test_vec_add.py | 90 ++ .../Tx81MemrefToLLVM/Tx81MemrefToLLVM.h | 3 + .../Conversion/Tx81ToLLVM/CMakeLists.txt | 4 + .../Tx81ToLLVM/KernelArgBufferPass.h | 8 +- .../Tx81ToLLVM/KernelArgBufferPass.td | 2 +- .../tsingmicro-tx81/Dialect/IR/Tx81Ops.td | 151 ++- .../lib/Analysis/OpFoldResultUtils.cpp | 9 +- .../lib/Conversion/MKToTx81/MKToTx81.cpp | 515 ++++++++-- .../lib/Conversion/MKToTx81/MKToTx81Pass.cpp | 63 ++ .../TritonArithToLinalg.cpp | 3 +- .../TritonArithToLinalgPass.cpp | 3 +- .../Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp | 56 +- .../Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp | 5 + .../lib/Conversion/Tx81ToLLVM/CMakeLists.txt | 2 + .../Tx81ToLLVM/KernelArgBufferPass.cpp | 233 ++--- .../lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp | 454 +++++++-- third_party/tsingmicro/name.conf | 2 +- .../tsingmicro/python/triton_tsingmicro.cc | 42 +- third_party/tsingmicro/scripts/build_llvm.sh | 29 + .../tsingmicro/scripts/build_tsingmicro.sh | 51 + third_party/tsingmicro/scripts/install.sh | 9 + .../tsingmicro/scripts/run_tsingmicro.sh | 42 + 62 files changed, 3219 insertions(+), 4503 deletions(-) delete mode 100644 third_party/tsingmicro/backend/cpu_driver.py create mode 100644 third_party/tsingmicro/backend/name.conf create mode 100644 third_party/tsingmicro/backend/txda_device.cpp create mode 100644 third_party/tsingmicro/bin/CMakeLists.txt create mode 100644 third_party/tsingmicro/bin/RegisterTritonDialects.h create mode 100644 third_party/tsingmicro/bin/tsingmicro-llvm-opt.cpp create mode 100644 third_party/tsingmicro/bin/tsingmicro-lsp.cpp create mode 100644 third_party/tsingmicro/bin/tsingmicro-opt.cpp create mode 100644 third_party/tsingmicro/bin/tsingmicro-reduce.cpp create mode 100644 third_party/tsingmicro/bin/tsingmicro-tensor-layout.cpp delete mode 100644 third_party/tsingmicro/crt/include/Tx81/instr_adapter.h delete mode 100644 third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h delete mode 100644 third_party/tsingmicro/crt/include/Tx81/instr_def.h delete mode 100644 third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h delete mode 100644 third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h create mode 100644 third_party/tsingmicro/crt/lib/Tx81/abs.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/logic.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/relation.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/rsqrt.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/sqrt.c create mode 100644 third_party/tsingmicro/examples/bare_matmul.py create mode 100644 third_party/tsingmicro/examples/benchmark.py create mode 100644 third_party/tsingmicro/examples/test_vec_add.py create mode 100755 
third_party/tsingmicro/scripts/build_llvm.sh create mode 100755 third_party/tsingmicro/scripts/build_tsingmicro.sh create mode 100755 third_party/tsingmicro/scripts/install.sh create mode 100755 third_party/tsingmicro/scripts/run_tsingmicro.sh diff --git a/.gitignore b/.gitignore index dd917eb2f..e15c39218 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,8 @@ third_party/cambricon/ third_party/iluvatar/iluvatarTritonPlugin.so third_party/triton_shared/ third_party/xpu/backend/xpu3 +third_party/tsingmicro/backend/lib +third_party/tsingmicro/backend/bin # Proton python/triton/profiler diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e6838c340..9fec66392 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,25 +39,6 @@ repos: hooks: - id: clang-format - # Expand YAML anchors in files used by github workflows, because github can't - # do this itself. This lets us use anchors, which avoids code duplication. - - repo: local - hooks: - - id: expand-yaml-anchors - name: Expand YAML anchors - language: golang - additional_dependencies: [github.com/mikefarah/yq/v4@latest] - entry: > - bash -c ' - OUT=".github/workflows/integration-tests.yml" - IN="$OUT.in" - echo "# AUTOGENERATED by pre-commit, modify the .in file instead." > "$OUT" && - echo >> "$OUT" - yq "explode(.)" "$IN" >> "$OUT" - ' - files: ^.github/workflows/integration-tests.yml.* - pass_filenames: false - exclude: | (?x)( ^include/triton/external/| diff --git a/CMakeLists.txt b/CMakeLists.txt index c328eea90..f6380ecd0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -364,6 +364,12 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMXPUCodeGen LLVMXPUAsmParser ) + elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro") + list(APPEND TRITON_LIBRARIES + # riscv + LLVMRISCVCodeGen + LLVMRISCVAsmParser + ) endif() if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 diff --git a/README.md b/README.md index a5f354c6c..88054b36b 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,17 @@ export FLAGTREE_BACKEND=xpu python3 -m pip install . --no-build-isolation -v ``` ```shell +# tsingmicro +# Recommended: Use the Docker image (xxGB) https://xxxx +mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro +wget https://github.com/FlagTree/flagtree/releases/download/xxxx +wget https://github.com/FlagTree/flagtree/releases/download/xxxx +cd ${YOUR_CODE_DIR}/flagtree/ +./third_party/tsingmicro/scripts/install.sh +./third_party/tsingmicro/scripts/build_tsingmicro.sh +./third_party/tsingmicro/scripts/run_tsingmicro.sh third_party/tsingmicro/examples/test_vec_add.py +``` +```shell # mthreads # Recommended: Use the Dockerfile flagtree/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads mkdir -p ~/.flagtree/mthreads; cd ~/.flagtree/mthreads diff --git a/README_cn.md b/README_cn.md index e205ad49b..aa5c4575a 100644 --- a/README_cn.md +++ b/README_cn.md @@ -43,6 +43,17 @@ export FLAGTREE_BACKEND=xpu python3 -m pip install . 
--no-build-isolation -v
```
```shell
+# tsingmicro
+# Recommended: Use the Docker image (xxGB) https://xxxx
+mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro
+wget https://github.com/FlagTree/flagtree/releases/download/xxxx
+wget https://github.com/FlagTree/flagtree/releases/download/xxxx
+cd ${YOUR_CODE_DIR}/flagtree/
+./third_party/tsingmicro/scripts/install.sh
+./third_party/tsingmicro/scripts/build_tsingmicro.sh
+./third_party/tsingmicro/scripts/run_tsingmicro.sh third_party/tsingmicro/examples/test_vec_add.py
+```
+```shell
 # mthreads
 # 推荐使用镜像 flagtree/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads
 mkdir -p ~/.flagtree/mthreads; cd ~/.flagtree/mthreads
diff --git a/python/setup.py b/python/setup.py
index ed596d0a3..a84b896fd 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -432,6 +432,7 @@ def build_extension(self, ext):
         thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
         thirdparty_cmake_args += self.get_pybind11_cmake_args()
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
+        ext_base_dir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
         # create build directories
         if not os.path.exists(self.build_temp):
             os.makedirs(self.build_temp)
@@ -471,6 +472,7 @@ def build_extension(self, ext):
             "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
             "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
             "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
+            f"-DCMAKE_INSTALL_PREFIX={ext_base_dir}",
         ]

         # Note that asan doesn't work with binaries that use the GPU, so this is
@@ -512,6 +514,7 @@ def build_extension(self, ext):
         subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
         subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
         subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
+        subprocess.check_call(["cmake", "--install", "."], cwd=cmake_dir)

         nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json")
diff --git a/third_party/tsingmicro/CMakeLists.txt b/third_party/tsingmicro/CMakeLists.txt
index 6643564da..8d0ee5de1 100644
--- a/third_party/tsingmicro/CMakeLists.txt
+++ b/third_party/tsingmicro/CMakeLists.txt
@@ -1,9 +1,38 @@
+if(NOT DEFINED TX8_HOME)
+  if(DEFINED ENV{TX8_HOME})
+    set(TX8_HOME $ENV{TX8_HOME})
+  else()
+    message(FATAL_ERROR "TX8_HOME environment variable is not defined")
+  endif()
+endif()
+
+set(XUANTIE_NAME Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2)
+
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
 include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/crt/include)
+include_directories(${TX8_HOME}/include)
 add_subdirectory(include)
 add_subdirectory(lib)
+add_subdirectory(bin)
+add_subdirectory(crt)

 if(TRITON_BUILD_PYTHON_MODULE)
+  # find_package(Python3 REQUIRED COMPONENTS Development Interpreter) # Find Python3
+  # add_library(backendxxxTritonPlugin SHARED
+  #   ${CMAKE_CURRENT_SOURCE_DIR}/triton_backendxxx.cc
+  # )
+  # set_target_properties(backendxxxTritonPlugin PROPERTIES
+  #   PREFIX ""
+  #   LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib
+  #   POSITION_INDEPENDENT_CODE ON
+  # )
+  # target_link_libraries(backendxxxTritonPlugin PRIVATE # Link the closed-source modules; two examples shown here
+  #   BackendTritonGPUToLLVM
+  #   BackendTritonTransforms
+  #   # Py
+  #   ${Python3_LIBRARIES} # Link against Python3
+  #   ${PYTHON_LDFLAGS}
+  # )
   # FIXME: Unify the libraries for TsingMicro into fewer ones
   add_triton_plugin(TritonTsingMicro ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_tsingmicro.cc
                     LINK_LIBS
                      ZTCAnalysis
                      ZTCAnalysisStructured
                      MagicKernelIR
diff --git a/third_party/tsingmicro/backend/compiler.py b/third_party/tsingmicro/backend/compiler.py
index 16dbd8c9f..5a60dd12d 100644
--- a/third_party/tsingmicro/backend/compiler.py
+++ b/third_party/tsingmicro/backend/compiler.py
@@ -1,5 +1,6 @@
 from triton.backends.compiler import BaseBackend, GPUTarget
 from triton._C.libtriton import ir, passes
+from triton.runtime.cache import get_cache_manager
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
 from types import ModuleType
@@ -13,37 +14,26 @@
 from pathlib import Path


-def _get_ztc_opt_path() -> str:
-    path = os.getenv("ZTC_OPT_PATH", "")
-    if path == "":
-        raise Exception("ZTC_OPT_PATH is not set.")
-    return path
-
-
-def _get_vendor_runtime_path() -> str:
-    path = os.getenv("LIB_VENDOR_RUNTIME_PATH", "")
-    if path == "":
-        raise Exception("LIB_VENDOR_RUNTIME_PATH is not set.")
-    return path
+def _get_tsm_opt_path() -> str:
+    return os.path.join(os.path.dirname(__file__), "bin", "tsingmicro-opt")


 def _get_llvm_bin_path(bin_name: str) -> str:
-    path = os.getenv("LLVM_BINARY_DIR", "")
+    path = os.getenv("LLVM_SYSPATH", "")
     if path == "":
-        raise Exception("LLVM_BINARY_DIR is not set.")
-    return os.path.join(path, bin_name)
+        raise Exception("LLVM_SYSPATH is not set.")
+    return os.path.join(path, "bin", bin_name)


-# The riscv c header files and libraries path.
-def _get_libc_root() -> str:
-    path = os.getenv("LIB_C_ROOT", "")
+def _get_tx8_path(sub_name: str) -> str:
+    path = os.getenv("TX8_HOME", "")
     if path == "":
-        raise Exception("LIB_C_ROOT is not set.")
-    return path
+        raise Exception("TX8_HOME is not set.")
+    return os.path.join(path, sub_name)


 def _dump_ir_if_needed(files):
-    path = os.getenv("ZTC_DUMP_PATH", "")
+    path = os.getenv("TRITON_DUMP_PATH", "")
     if not path:
         return
@@ -52,6 +42,42 @@ def _dump_ir_if_needed(files):
         shutil.copy(f, os.path.join(path, os.path.basename(f)))


+# Build an accelerator controller ELF
+def compile_accelerator():
+    # TODO: cache mechanism
+    # name = "npu_" + name
+    # key = hashlib.sha256(src.encode("utf-8")).hexdigest()
+    # cache = get_cache_manager(key)
+    # cache_path = cache.get_file(f"{name}.so")

+    # if cache_path is None:
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # FIXME: Hardcoded path
+        #dst_path = os.path.join(tmpdir, f"{name}.so")
+        dst_path = "/tmp/kernel.so"
+        libc_lib = os.path.join(_get_tx8_path("Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2"), "riscv64-unknown-elf",
+                                "lib", "rv64imafdc", "lp64d")
+        libvr_path = os.path.join(os.path.dirname(__file__), "lib")
+        clang_path = _get_llvm_bin_path("clang")
+        lld_path = _get_llvm_bin_path("ld.lld")
+        tx8_lib = _get_tx8_path("lib")
+        subprocess.check_call([
+            clang_path, "-shared", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-O2",
+            f"-fuse-ld={lld_path}", "-nostdlib", "-nostartfiles", "-Wl,--allow-shlib-undefined", "-mabi=lp64d",
+            "-Wl,--no-dynamic-linker",
+            # FIXME: Hardcoded path
+            "/tmp/kernel.o", f"-L{libvr_path}", f"-L{libc_lib}", f"-L{tx8_lib}", "-Wl,--whole-archive",
+            "-linstr_tx81",  # Tx81 intrinsic API
+            "-lvr",  # Wrapper API of Tx81 intrinsics
+            "-Wl,--no-whole-archive", "-lm", "-o", dst_path
+        ])
+
+        _dump_ir_if_needed([dst_path])
+        with open(dst_path, 'rb') as f:
+            so = f.read()
+        return so
+
+
 def _ttir_to_coreir(mod):
     # Get Triton-MLIR as string
     ttir_code = str(mod)
@@ -59,10 +85,10 @@ def _ttir_to_coreir(mod):
         src_path = os.path.join(tmpdir, "tt.mlir")
         dst_path = os.path.join(tmpdir, "core.mlir")
         Path(src_path).write_text(ttir_code)
-        ztc_opt_path = _get_ztc_opt_path()
+        tsm_opt_path = _get_tsm_opt_path()
         _dump_ir_if_needed([src_path])
         subprocess.check_call([
-            ztc_opt_path, src_path, "--triton-to-core-dialects", "--one-shot-bufferize",
+            tsm_opt_path, src_path, "--triton-to-core-dialects", "--one-shot-bufferize=allow-return-allocs-from-loops",
             #"--mlir-print-debuginfo",
             "-o", dst_path
         ])
@@ -81,10 +107,10 @@ def _coreir_to_mkir(mod):
         src_path = os.path.join(tmpdir, "core.mlir")
         dst_path = os.path.join(tmpdir, "mk.mlir")
         Path(src_path).write_text(coreir_code)
-        ztc_opt_path = _get_ztc_opt_path()
+        tsm_opt_path = _get_tsm_opt_path()
         _dump_ir_if_needed([src_path])
         subprocess.check_call([
-            ztc_opt_path, src_path, "--core-dialects-to-mk",
+            tsm_opt_path, src_path, "--core-dialects-to-mk",
             #"--mlir-print-debuginfo",
             "-o", dst_path
         ])
@@ -103,12 +129,12 @@ def _coreir_to_txir(mod):
         src_path = os.path.join(tmpdir, "core.mlir")
         dst_path = os.path.join(tmpdir, "tx.mlir")
         Path(src_path).write_text(coreir_code)
-        ztc_opt_path = _get_ztc_opt_path()
+        tsm_opt_path = _get_tsm_opt_path()
         _dump_ir_if_needed([src_path])
         subprocess.check_call([
-            ztc_opt_path, src_path, "--expand-strided-metadata", "--mk-to-tx81",
+            tsm_opt_path, src_path, "--expand-strided-metadata", "--lower-affine",  # converts affine.load to memref.load; must run before tx81-to-llvm, since SPM offsets will be supported on memref.load
-            "--cse", # unused memref.subview/memref.reinterpret
+            "--mk-to-tx81", "--cse",  # drops unused memref.subview/memref.reinterpret ops
             #"--mlir-print-debuginfo",
             "-o", dst_path
         ])
@@ -120,29 +146,50 @@ def _optimize_txir(txir: str):
     return txir


-def _txir_to_llir(mod):
+def _txir_to_llir(mod, metadata):
     txir_code = str(mod)
     with tempfile.TemporaryDirectory() as tmpdir:
         src_path = os.path.join(tmpdir, "tx.mlir")
         llvmir_path = os.path.join(tmpdir, "ll.mlir")
         llir_path = os.path.join(tmpdir, "ll.ir")
         Path(src_path).write_text(txir_code)
-        ztc_opt_path = _get_ztc_opt_path()
+        tsm_opt_path = _get_tsm_opt_path()
         _dump_ir_if_needed([src_path])
         # Tx81 and core dialects to LLVM-MLIR
-        subprocess.check_call([
-            ztc_opt_path, src_path, "--tx81-memref-to-llvm", "--tx81-to-llvm", "--convert-scf-to-cf",
-            "--convert-math-to-llvm", "--convert-func-to-llvm", "--convert-cf-to-llvm",
-            # Use tx81-memref-to-llvm custom pass for now.
-            # "--finalize-memref-to-llvm",
-            "--convert-arith-to-llvm", # need exec last since arith.const conversion
+        args = [
+            tsm_opt_path, src_path,
+            # Use tx81-memref-to-llvm to replace "--finalize-memref-to-llvm".
+            "--tx81-memref-to-llvm", "--convert-scf-to-cf", "--convert-math-to-llvm",
+            "--convert-cf-to-llvm",  # must run before "convert-func-to-llvm"
+            "--convert-func-to-llvm",  # must run before "kernel-arg-buffer"; otherwise an unranked memref is lowered to int(rank) + ptr
+        ]
+
+        args.append(
+            "--kernel-arg-buffer"
+        )  # must run before "tx81-to-llvm", which declares other functions; we only want to rewrite the Triton kernel
+
+        # remaining passes
+        args += [
+            "--tx81-to-llvm", "--convert-arith-to-llvm",  # must run last because of the arith.const conversion
             # Remove all unrealized casts created
             "--reconcile-unrealized-casts", "--canonicalize",
             #"--mlir-print-debuginfo",
             "-o", llvmir_path
-        ])
+        ]
+
+        subprocess.check_call(args)
+        _dump_ir_if_needed([llvmir_path])

+        llvm_file = os.getenv("CUSTOMIZED_IR", "")
+        if (llvm_file != ""):
+            llvmir_path = os.getenv("TRITON_DUMP_PATH", "")
+
+            if not llvmir_path:
+                return
+
+            llvmir_path = os.path.join(llvmir_path, llvm_file)
+
         # LLVM-MLIR to LLVM-IR
         mlir_translate_path = _get_llvm_bin_path("mlir-translate")
         subprocess.check_call([mlir_translate_path, llvmir_path, "--mlir-to-llvmir", "-o", llir_path])
@@ -205,18 +252,18 @@ def _llir_to_bin(llir: str, metadata):
         Path(src_path).write_text(llir)
         clang_path = _get_llvm_bin_path("clang++")
         subprocess.check_call([
-            clang_path, src_path, "-O2", "-c", "-fPIC", "--target=riscv64-unknown-elf", "-march=rv64imafdc", "-o",
+            clang_path, src_path, "-O2", "-c", "-fPIC", "--target=riscv64-unknown-linux-gnu", "-march=rv64imafdc", "-o",
             dst_path
         ])
         _dump_ir_if_needed([dst_path])
-        with open(dst_path, 'rb') as f:
-            so = f.read()
-        return so
+
+        # Compile the kernel and the intrinsic wrapper into a shared library
+        return compile_accelerator()


 @dataclass(frozen=True)
-class CPUOptions:
+class TXDAOptions:
     debug: bool = False
     arch: str = None
     num_warps: int = 0
@@ -229,6 +276,7 @@ class CPUOptions:
     shared: bool = False
     allow_fp8e4nv: bool = False
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
+    sanitize_overflow: bool = True

     def __post_init__(self):
         pass
@@ -238,11 +286,11 @@ def hash(self):
         return hashlib.md5(key.encode("utf-8")).hexdigest()


-class CPUBackend(BaseBackend):
+class TXDABackend(BaseBackend):

     @staticmethod
     def supports_target(target: GPUTarget):
-        return target.backend == 'cpu'
+        return target.backend == 'txda'

     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
@@ -250,10 +298,10 @@ def __init__(self, target: GPUTarget) -> None:

     def parse_options(self, opts) -> Any:
         args = {'arch': self.target.arch}
-        args.update({k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts})
-        return CPUOptions(**args)
+        args.update({k: opts[k] for k in TXDAOptions.__dataclass_fields__.keys() if k in opts})
+        return TXDAOptions(**args)

-    def get_codegen_implementation(self):
+    def get_codegen_implementation(self, options):
         codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)}
         return codegen_fns
@@ -265,7 +313,6 @@ def pack_metadata(self, metadata):
                 metadata.cluster_dims[1], metadata.cluster_dims[2], metadata.name)

     # Our compilation pipeline isn't in python like nvidia or amd, no need to load
-    # dialects.
See `ztc.cc` def load_dialects(self, ctx): return @@ -288,7 +335,7 @@ def add_stages(self, stages, options): stages["coreir"] = lambda src, metadata: _optimize_coreir(_ttir_to_coreir(src)) # stages["mkir"] = lambda src, metadata: _optimize_mkir(_coreir_to_mkir(src)) stages["txir"] = lambda src, metadata: _optimize_txir(_coreir_to_txir(src)) - stages["llir"] = lambda src, metadata: _optimize_llir(_txir_to_llir(src)) + stages["llir"] = lambda src, metadata: _optimize_llir(_txir_to_llir(src, metadata)) stages["so"] = lambda src, metadata: _llir_to_bin(src, metadata) @functools.lru_cache() diff --git a/third_party/tsingmicro/backend/cpu_driver.py b/third_party/tsingmicro/backend/cpu_driver.py deleted file mode 100644 index 15631354b..000000000 --- a/third_party/tsingmicro/backend/cpu_driver.py +++ /dev/null @@ -1,387 +0,0 @@ -import hashlib -import tempfile -import sysconfig - -import os, subprocess, tempfile -import importlib.util -import sysconfig - -from pathlib import Path - -from triton.runtime.cache import get_cache_manager -from triton.backends.driver import DriverBase -from triton.backends.compiler import GPUTarget - - -# The riscv compiler -def _get_llvm_bin_path() -> str: - path = os.getenv("LLVM_BINARY_DIR", "") - if path == "": - raise Exception("LLVM_BINARY_DIR is not set.") - return path - - -# The riscv c header files and libraries path. -def _get_libc_root() -> str: - path = os.getenv("LIB_C_ROOT", "") - if path == "": - raise Exception("LIB_C_ROOT is not set.") - return path - - -# -------------------- Launcher ---------------------------- -def _ty_to_cpp(ty): - if ty[0] == '*': - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u1": "uint32_t", - "u8": "uint8_t", - "u16": "uint16_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - -def _extracted_type(ty): - if ty[0] == '*': - return "PyObject*" - return _ty_to_cpp(ty) - - -def _format_of(ty): - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "l", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty] - - -def _generate_launcher(constants, signature, kernel_name): - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) - format = "iiiOOOO" + args_format - args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - - kernel_arg_decls = ', '.join( - _ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants) - kernel_arg_decls += ', ' if kernel_arg_decls else '' - - kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" - for i, ty in signature.items() - if i not in constants) - kernel_parameters += ', ' if kernel_parameters else '' - - return f""" -#include -#include -#include -#include "ExecutionEngine/CRunnerUtils.h" -#include "ExecutionEngine/CRunnerUtils.cpp" - -extern "C" {{ - // Pointer type (=Memref) becomes int64_t + MemRef struct - // FIXME: understand what this int64_t is used for. 
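// (Editor's annotation, an assumption rather than something stated in this
// patch: given that kernel_arg_decls above emits an "int64_t, void*" pair for
// every pointer argument, and compiler.py notes that an unranked memref
// lowers to "int(rank) + ptr", this int64_t is most likely the rank that
// accompanies the pointer to the memref descriptor.)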
- void {kernel_name}({kernel_arg_decls} - int, int, int, int, int, int); -}} - -static void _launch(int gridX, int gridY, int gridZ, {arg_decls}) {{ - if (gridX*gridY*gridZ > 0) {{ - // Cast "function" to the real function type. - for(int x = 0; x < gridX; x++) {{ - for(int y = 0; y < gridY; y++) {{ - for(int z = 0; z < gridZ; z++) {{ - // Use some random type "char" here. - {' '.join(f'StridedMemRefType ptr_arg{i} = {{static_cast(arg{i}), static_cast(arg{i}), 0}};' for i, ty in signature.items() if i not in constants and ty[0] == "*")} - {kernel_name}({kernel_parameters} - gridX, gridY, gridZ, x, y, z); - }} - }} - }} - }} -}} - -typedef struct _DevicePtrInfo {{ - void *dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); - if(!ptr_info.dev_ptr) - return ptr_info; - Py_DECREF(ret); // Thanks ChatGPT! - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - return ptr_info; -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - int gridX, gridY, gridZ; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook {args_list})) {{ - return NULL; - }} - - // [CPULauncher-specific]: We don't need the metadata below but just put them - // here anyway to be consistent with others. - // This will make updating the driver easier in the future. 
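// (Editor's illustration, not part of the original source: for a signature
// such as {0: '*fp32', 1: 'i32'}, args_format is "Oi", so the format string
// built above is "iiiOOOO" + "Oi" = "iiiOOOOOi".)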
- - // int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - // if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ - // PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); - // return NULL; - // }} - - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); - - if (PyErr_Occurred()) {{ - return NULL; - }} - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__ztc_ref_cpu_kernel_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___ztc_ref_cpu_kernel_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" - - -def compile_module(launcher_src, kernel_placeholder_name): - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() - # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install - # path changes to include 'local'. This change is required to use triton with system-wide python. - if scheme == 'posix_local': - scheme = 'posix_prefix' - py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - py_lib_dir = sysconfig.get_config_var("LIBDIR") - py_version = sysconfig.get_config_var("LDVERSION") - py_lib = '{name}{py_version}'.format(name="python", py_version=py_version) - cpu_backend_path = Path(__file__).resolve().parent - clang = os.path.join(_get_llvm_bin_path(), "clang++") - libc_inc = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "include") - libc_lib = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "lib") - include_dir = os.path.join(cpu_backend_path, "include") - - def launch(gridX, gridY, gridZ, stream, cu_function, kernel_metadata, launch_metadata, launch_enter_hook, - launch_exit_hook, *args): - # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries. - # Let's compile a kernel every time. - # The cu_function parameter actually contains our assembly source code. - # See CPUUtils.load_binary method. 
- asm_src = cu_function - kernel_name = kernel_metadata[6] # see pack_metadata in compiler.py - src = launcher_src.replace(kernel_placeholder_name, kernel_name) - - key = hashlib.md5(src.encode("utf-8") + asm_src).hexdigest() - cache = get_cache_manager(key) - name = "__ztc_ref_cpu_kernel_launcher" - filename = f"{name}.so" - cache_path = cache.get_file(filename) - - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - asm_src_path = os.path.join(tmpdir, "kernel.s") - launcher_src_path = os.path.join(tmpdir, "main.cxx") - so_path = os.path.join(tmpdir, "kernel.so") - Path(asm_src_path).write_bytes(asm_src) - Path(launcher_src_path).write_text(src) - # Compile it together. - subprocess.check_call([ - clang, "-std=c++17", "--target=riscv64-unknown-elf", launcher_src_path, asm_src_path, - f"-I{libc_inc}", f"-I{py_include_dir}", f"-I{include_dir}", f"-I{libc_lib}", f"-L{py_lib_dir}", - "-shared", f"-l{py_lib}", "-fPIC", "-o", so_path - ]) - - with open(so_path, "rb") as f: - cache_path = cache.put(f.read(), filename, binary=True) - - # Load and launch the compiled kernel. - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod.launch(gridX, gridY, gridZ, kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, - *args) - - return launch - - -class CPULauncher(object): - - def __init__(self, src, metadata): - kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER" - - constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} - launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name) - # Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name - # in the following launch function. - self.launch = compile_module(launcher_src, kernel_placeholder_name) - - def __call__(self, *args, **kwargs): - self.launch(*args, **kwargs) - - -class CPUUtils(object): - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(CPUUtils, cls).__new__(cls) - return cls.instance - - # Note: - # nvidia and amd backends have their corresponding driver.c file that exposes - # get_device_properties and load_binary using python bindings. - # (see third_party/nvidia/backend/driver.c) - # These methods are then used in compiler.py to initialize handles before running - # the triton kernels. - # Since we recompile the kernel every time (see compile_module above), - # and the metadata generated by these functions aren't applicable to the cpu - # backend, just define the same functions with dummy implementation. - @staticmethod - def get_device_properties(device): - return { - "max_shared_mem": 2**20, "multiprocessor_count": None, "sm_clock_rate": None, "mem_clock_rate": None, - "mem_bus_width": None - } - - # Important note: - # Since we cannot easy pass function pointers around, we pass along the - # assembly source code so that compile_module above can recompile the - # module every time. 
- @staticmethod - def load_binary(name, kernel_asm, shared, device): - return (None, # module - kernel_asm, # function - None, # n_regs - None # n_spills - ) - - -class CPUDriver(DriverBase): - - def __init__(self): - super().__init__() - self.utils = CPUUtils() - self.launcher_cls = CPULauncher - self.binary_ext = "cpuasm" - - # CPU driver won't be automatically chosen unless explicitly set through - # triton.runtime.driver.set_active(CPUDriver()) - @staticmethod - def is_active(): - return False - - def get_device_capability(self): - return ("cpu", 0) - - def get_current_stream(self, device): - return None - - def get_current_device(self): - # CPU doesn't have a device to return. Return something. - return "cpu" - - def set_current_device(self, device): - # CPU doesn't have a device to set - assert device == "cpu" - return - - def get_current_target(self): - return GPUTarget("cpu", 0, 0) - - def assemble_tensormap_to_arg(self, tensormaps_info, args): - return args diff --git a/third_party/tsingmicro/backend/driver.cpp b/third_party/tsingmicro/backend/driver.cpp index 242a4433b..a9af39cfc 100644 --- a/third_party/tsingmicro/backend/driver.cpp +++ b/third_party/tsingmicro/backend/driver.cpp @@ -18,311 +18,16 @@ #include #include -struct Kernel_Param // Triton kernel arguments -{ - uint32_t gridX; - uint32_t gridY; - uint32_t gridZ; - // TODO... -}; - -struct Kernel_Head { - uint32_t param_type; - uint32_t param_num; - uint32_t param_addr; - uint32_t xxxxx; -}; - -// Raises a Python exception and returns false if code is not RET_SUCCESS. -static bool tsmAssert(TSM_RETCODE code, const char *file, int line) { - if (code == RET_SUCCESS) - return true; - - const char *prefix = "Triton Error [TX81]: "; - const char *str; - - // Map error codes to strings - switch (code) { - case RET_ERROR: - str = "General error"; - break; - case RET_PARAM1_ERROR: - case RET_PARAM2_ERROR: - case RET_PARAM3_ERROR: - str = "Parameter error"; - break; - case RET_DEVICE_OFFLINE: - str = "Device offline"; - break; - case RET_DEVICE_NOMEM: - str = "Device out of memory"; - break; - case RET_DEVICE_IN_IDLE: - str = "Device in idle state"; - break; - case RET_DEVICE_IN_ATTACH: - str = "Device already attached"; - break; - case RET_DEVICE_ATTACH_SUCCESS: - str = "Device attach success"; - break; - case RET_DEVICE_ATTACH_READY: - str = "Device attach ready"; - break; - case RET_DEVICE_LOSE_CONNECT: - str = "Device connection lost"; - break; - case RET_ENV_CLEAN_UP: - str = "Environment cleanup required"; - break; - default: - str = "Unknown error"; - } - - char err[1024] = {0}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - return false; -} - -static void prepare_input(std::vector devices, uint32_t dev_index, - std::shared_ptr chip_info) { - for (uint32_t i = 0; i < chip_info->input_num; ++i) { - chip_info->input_dev_addr.push_back(0); - if (TsmDeviceMalloc(devices[dev_index], chip_info->input_dev_addr[i], - chip_info->input_size[i]) != RET_SUCCESS) { - printf("[Chip id %u] Input%d, DeviceMalloc failed!\n", - devices[dev_index]->chip_id, i); - TsmResetDevice(devices[dev_index]); - return; - } - - if (TsmMemcpyH2D((TsmDevicePtr)chip_info->input_dev_addr[i], - (void *)chip_info->input_host_addr[i], - chip_info->input_size[i]) != RET_SUCCESS) { - printf("[Chip id %u] Input%d, MemcpyH2D failed!\n", - devices[dev_index]->chip_id, i); - TsmResetDevice(devices[dev_index]); - return; - } - } 
-}
-
-static void prepare_output(std::vector devices, uint32_t dev_index,
-                           std::shared_ptr chip_info) {
-  for (size_t i = 0; i < chip_info->output_num; ++i) {
-    chip_info->output_dev_addr.push_back(0);
-    printf("[Chip id %u] output[%lu] data(size: %lu)\n",
-           devices[dev_index]->chip_id, i, chip_info->output_size[i]);
-
-    if (TsmDeviceMalloc(devices[dev_index], chip_info->output_dev_addr[i],
-                        chip_info->output_size[i]) != RET_SUCCESS) {
-      printf("[Chip id %u] output[%lu], DeviceMalloc failed!\n",
-             devices[dev_index]->chip_id, i);
-      TsmResetDevice(devices[dev_index]);
-      return;
-    }
-  }
-}
-
-TSM_RETCODE kernel_result_process(std::vector devices,
-                                  uint32_t dev_index,
-                                  std::shared_ptr hostboot,
-                                  std::shared_ptr chip_info,
-                                  TsmDevicePtr bootpm_dev,
-                                  std::string case_dir) {
-  for (size_t i = 0; i < chip_info->output_num; ++i) {
-    // Dynamic shape: the real output size must be handled here
-    if (TsmMemcpyD2H(hostboot->get_bootpmbuffer(), bootpm_dev,
-                     hostboot->get_maxlen()) != RET_SUCCESS) {
-      return RET_ERROR;
-    }
-
-    auto out_tensor = hostboot->get_dev_output_tensor_after_run(i);
-    chip_info->output[i]->dim = out_tensor->dim;
-    std::memcpy(chip_info->output[i]->shape, out_tensor->shape,
-                sizeof(out_tensor->shape));
-    chip_info->output_size[i] =
-        hrt_get_dtype_size((DTYPE)chip_info->output[i]->dtype);
-    for (uint32_t j = 0; j < out_tensor->dim; ++j) {
-      if (out_tensor->shape[j] > 0) {
-        chip_info->output_size[i] *= out_tensor->shape[j];
-      }
-    }
-
-    TsmHostPtr output_host_addr = (TsmHostPtr)malloc(chip_info->output_size[i]);
-    if (chip_info->output_size[i] > 0) {
-      if (TsmMemcpyD2H((void *)output_host_addr, chip_info->output_dev_addr[i],
-                       chip_info->output_size[i]) != RET_SUCCESS) {
-        return RET_ERROR;
-      }
-    }
-
-    printf("[Chip id %u] output_dev_addr=%ld\n", devices[dev_index]->chip_id,
-           chip_info->output_dev_addr[i]);
-
-    // TODO: Processing output
-#if 0
-    std::string file_path = case_dir + "/chip" + std::to_string(dev_index) +
-                            "/agent/data/out" + std::to_string(i) + "_riscv.bin";
-    saveDataToFile(file_path, output_host_addr, chip_info->output_size[i]);
-#endif
-
-    if (output_host_addr != 0) {
-      free((void *)output_host_addr);
-    }
-  }
-  return RET_SUCCESS;
-}
-
-TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) {
-  if (bootpm_dev != 0) {
-    printf("[Chip id %u] bootpm dev addr: 0x%lx \n", chip_id, bootpm_dev);
-    if (TsmDeviceFree(bootpm_dev) != RET_SUCCESS) {
-      return RET_ERROR;
-    }
-    bootpm_dev = 0;
-  }
-
-  return RET_SUCCESS;
-}
-
-static void setHostBoot(std::shared_ptr &chip_info,
-                        std::shared_ptr &hostboot) {
-  if (chip_info == nullptr) {
-    printf("chip_info is null.\n");
-    return;
-  }
-
-  if (hostboot == nullptr) {
-    printf("hostboot is null.\n");
-    return;
-  }
-
-  for (size_t i = 0; i < chip_info->input_dev_addr.size(); ++i) {
-    hostboot->set_dev_input(i, chip_info->input_dev_addr[i],
-                            chip_info->input_size[i]);
-    hostboot->set_dev_input_tensor(i, chip_info->input[i]);
-  }
-
-  for (size_t i = 0; i < chip_info->output_dev_addr.size(); ++i) {
-    hostboot->set_dev_output(i, chip_info->output_dev_addr[i],
-                             chip_info->output_size[i]);
-  }
-
-  for (size_t i = 0; i < chip_info->param_num; ++i) {
-    hostboot->set_dev_param(i, chip_info->param_dev_addr[i],
-                            chip_info->param_size[i]);
-  }
-
-  return;
-}
-
-// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
-#define TSM_CHECK_AND_RETURN_NULL(ans) \
-  do { \
-    if (!tsmAssert((ans), __FILE__, __LINE__)) \
-      return NULL; \
-  } while (0)
-
-// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
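// (Clarifying note added by the editor: Py_BEGIN_ALLOW_THREADS opens a block
// that declares `PyThreadState *_save`, which is why the macro below can
// reference `_save` when restoring the thread state before returning.)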
-#define TSM_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ - do { \ - if (!tsmAssert((ans), __FILE__, __LINE__)) { \ - PyEval_RestoreThread(_save); \ - return NULL; \ - } \ - } while (0) - -// Global state for Tx81 devices -static std::vector g_tx81_devices; -static bool g_runtime_initialized = false; - -// Initialize the Tx81 runtime if not already initialized -static bool init_tx81_runtime_if_needed() { - if (g_runtime_initialized) { - return true; - } - - // Initialize the Tx81 runtime - if (TsmInitRuntime() != RET_SUCCESS) { - PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 runtime"); - return false; - } - - // Get device count - uint32_t device_num = 0; - if (TsmGetDeviceNum(device_num) != RET_SUCCESS || device_num == 0) { - PyErr_SetString(PyExc_RuntimeError, - "Failed to get Tx81 device count or no devices found"); - TsmDeInitRuntime(); - return false; - } - - // Set up devices - for simplicity, we're using a 1x1 configuration - uint32_t first_phy_id = 0; - uint32_t card_x = 1; - uint32_t card_y = 1; - - if (TsmSetDevice(first_phy_id, card_x, card_y, g_tx81_devices) != - RET_SUCCESS) { - PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices"); - TsmDeInitRuntime(); - return false; - } - - g_runtime_initialized = true; - return true; -} - static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { -#if 0 - // FIXME: Extracting device_id - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - - - // Initialize the runtime if needed - if (!init_tx81_runtime_if_needed()) { - return NULL; - } - - // Check device ID is valid - if (device_id < 0 || (size_t)device_id >= g_tx81_devices.size()) { - PyErr_SetString(PyExc_ValueError, "Invalid device ID"); - return NULL; - } - - // Get device handle - TsmDevice* device = g_tx81_devices[device_id]; - - // Get device information - TsmDeviceInfo info; - memset(&info, 0, sizeof(TsmDeviceInfo)); - TSM_CHECK_AND_RETURN_NULL(TsmGetDeviceInfo(&info)); -#endif // Extract device properties // Note: We're mapping Tx81 properties to fields expected by Triton - int max_shared_mem = 1024 * 1024 * 4; // Default 4MB + int max_shared_mem = 1024 * 1024 * 3; // Default 3MB // int multiprocessor_count = device->tile_num; int multiprocessor_count = 1; int sm_clock_rate = 1000; // Placeholder int mem_clock_rate = 2000; // Placeholder int mem_bus_width = 256; // Placeholder -#if 0 - // For the specified device, get more detailed info - if (device_id < (int)info.card_num) { - CardComputeInfo& card_info = info.card_compute_info[device_id]; - multiprocessor_count = card_info.all_tile_num; - } -#endif - return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem, "multiprocessor_count", multiprocessor_count, "sm_clock_rate", sm_clock_rate, @@ -336,80 +41,6 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { Py_ssize_t data_size; int shared; int device; -#if 0 - if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, - &device)) { - return NULL; - } - - // Initialize the runtime if needed - if (!init_tx81_runtime_if_needed()) { - return NULL; - } - - // Check device ID is valid - if (device < 0 || (size_t)device >= g_tx81_devices.size()) { - PyErr_SetString(PyExc_ValueError, "Invalid device ID"); - return NULL; - } - - TsmDevice* tx81_device = g_tx81_devices[device]; - - // First, we need to write binary data to a temporary file - char temp_path[256]; - sprintf(temp_path, "/tmp/triton_tx81_kernel_XXXXXX"); - int fd = mkstemp(temp_path); - if (fd == -1) { - 
PyErr_SetString(PyExc_RuntimeError, "Failed to create temporary file");
-    return NULL;
-  }
-
-  // Write the kernel data to the temporary file
-  if (write(fd, data, data_size) != data_size) {
-    close(fd);
-    unlink(temp_path);
-    PyErr_SetString(PyExc_RuntimeError,
-                    "Failed to write kernel data to temporary file");
-    return NULL;
-  }
-  close(fd);
-
-  // Create a model structure, the compiled kernel.so is specified via case_dir
-  // and the name of the entry function is specified via case_name.
-  TsmModel *model = new TsmModel();
-  model->case_dir = std::string(temp_path);
-  model->case_name = std::string(name);
-
-  // Set compile options
-  CompileOption compl_option = {};
-  compl_option.comp_enable = 0; // Use precompiled kernel.so instead
-  compl_option.chip_x = 1;
-  compl_option.chip_y = 1;
-  compl_option.check_enable = true;
-  compl_option.enable_kcore_bin = 1;
-  compl_option.enable_kcore_so = 1;
-
-  std::vector devices = {tx81_device};
-
-  // Not really compile the kernel, as kernel is already compiled, so this
-  // runtime API only configs the data structure of device firmware and the
-  // information of the program and data that runs on it.
-  Py_BEGIN_ALLOW_THREADS;
-  TSM_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
-      TsmCompileMultiGraph(devices, *model, "", compl_option));
-  Py_END_ALLOW_THREADS;
-
-  // For Tx81, we'll use a simpler model than CUDA
-  // We return a pointer to the TsmModel, which is analogous to CUmodule
-  // For the function pointer, we'll use model_id+0, which will be interpreted
-  // in the launcher code
-  // n_regs and n_spills are placeholders for now
-  int32_t n_regs = 256; // Default/placeholder value
-  int32_t n_spills = 0; // Default/placeholder value
-
-  // Clean up the temporary file
-  unlink(temp_path);
-#endif

  int32_t n_regs = 256;
  int32_t n_spills = 0;
@@ -418,210 +49,11 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
                        n_spills);
}

-static PyObject *launch(PyObject *self, PyObject *args) {
-  std::vector devices;
-  // TODO: obtain the device info via the arguments
-
-  // Required inputs: devices, case_dir (the kernel .so stored under a fixed path),
-  // input_host_addr/input_size/input_num,
-  // output_host_addr/output_size/output_num, param info (if there are weights)
-
-  TsmModel *new_model = new TsmModel(); // device-related parameters are already in dev
-  std::string option = "-O2";
-  CompileOption compl_option = {};
-  compl_option.comp_enable = 0;
-  compl_option.chip_x = 1; // single card
-  compl_option.chip_y = 1;
-  compl_option.check_enable = true;
-  compl_option.enable_kcore_bin = 1;
-  compl_option.enable_kcore_so = 1;
-  new_model->case_dir =
-      "/tmp/todo"; // passed in as a parameter: the kernel .so path, same as the streambin/kcorebin folder path
-
-  if (TsmCompileMultiGraph(devices, *new_model, option, compl_option) !=
-      RET_SUCCESS) {
-    for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {
-      if (TsmResetDevice(devices[dev_index]) != RET_SUCCESS) {
-        printf("[Chip id %u] tx_engine: tx_reset, failed!\n", dev_index);
-      } else {
-        printf("[Chip id %u] tx_engine: tx_reset, success!\n", dev_index);
-      }
-    }
-    printf("TsmCompile failed.\n");
-    return NULL;
-  }
-
-  std::vector kmodel_vec = {new_model};
-
-  uint32_t input_num = 2;  // TODO: fill in according to the kernel arguments
-  uint32_t output_num = 1; // TODO: fill in according to the kernel arguments
-  uint32_t param_num = 0;  // number of weights
-  std::shared_ptr hostboot =
-      std::make_shared(input_num, output_num, param_num);
-
-  std::shared_ptr chip_info;
-  // Fill in the chip_info fields
-  chip_info->input_num = input_num;
-  chip_info->output_num = output_num;
-  chip_info->param_num = param_num;
-  chip_info->imm_size = 0; // scratch size set to 0 for now; it depends on the actual operator
-  // chip_info->tile_num = 16; // unused
-  // chip_info->tile_x = 4;    // unused
-  // chip_info->tile_y = 4;    // unused
-  for (uint32_t i = 0; i < chip_info->input_num; ++i) {
-    chip_info->input_size[i] = 6; // TODO: fill in the real input size
-    chip_info->input_host_addr = std::vector{
-        0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; // TODO: fill in the real input addresses
-  }
-
-  for (uint32_t i = 0; i < chip_info->output_num; ++i) {
-    chip_info->output_size[i] = 1; // TODO: fill in the real output size
-    chip_info->output_host_addr =
-        std::vector{0x0}; // TODO: fill in the real output addresses
-  }
-
-  // for(uint32_t i = 0; i < chip_info->param_num; ++i) {
-  //   chip_info->param_size[i] = 0; // TODO: fill in the real weight size
-  //   chip_info->param_host_addr = 0x0;
-  // }
-
-  // prepare data/ load kernel/run/unload kernel/get out data/release memory
-  for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {
-    // input prepare
-    prepare_input(devices, dev_index, chip_info);
-    // output prepare
-    prepare_output(devices, dev_index, chip_info);
-
-    uint32_t chip_id = devices[dev_index]->chip_id;
-    TsmSetMonitorInfo(devices[dev_index]);
-
-    // load kernel
-    char module_symbol[] = "main_kernel";
-    TsmLoadKernel(devices[dev_index], kmodel_vec, module_symbol);
-    printf("TsmLoadKernel finish!...\n");
-
-    printf("[Chip id %u] Set boot-params...\n", chip_id);
-    size_t dyn_mod_size = sizeof(DynMods) + sizeof(DynModule);
-    TsmDevicePtr dev_dyn_mods_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_dyn_mods_ptr, dyn_mod_size) !=
-        RET_SUCCESS) {
-      return NULL;
-    }
-    TsmDevicePtr dev_tlv_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_tlv_ptr,
-                        sizeof(DynTLV_DynMods)) != RET_SUCCESS) {
-      return NULL;
-    }
-
-    TsmDevicePtr dev_kernel_head_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_head_ptr,
-                        sizeof(Kernel_Head)) != RET_SUCCESS) {
-      return NULL;
-    }
-    TsmDevicePtr dev_kernel_param_ptr;
-    if (TsmDeviceMalloc(devices[dev_index], dev_kernel_param_ptr,
-                        sizeof(Kernel_Param)) != RET_SUCCESS) {
-      return NULL;
-    }
-
-    Kernel_Head *host_kernel_head_ptr =
-        (Kernel_Head *)malloc(sizeof(Kernel_Head));
-    Kernel_Param *host_kernel_param_ptr =
-        (Kernel_Param *)malloc(sizeof(Kernel_Param));
-
-    host_kernel_head_ptr->param_type = 1;
-    host_kernel_head_ptr->param_num = 1; // Number of kernel arguments
-    host_kernel_head_ptr->param_addr =
-        dev_kernel_param_ptr; // assign the address of the parameters used by the kernel
-
-    // TODO: Setup the triton kernel arguments
-    host_kernel_param_ptr->gridX = 512;
-    host_kernel_param_ptr->gridY = 512;
-    host_kernel_param_ptr->gridZ = 512;
-
-    TsmMemcpyH2D(dev_kernel_head_ptr, host_kernel_head_ptr,
-                 sizeof(Kernel_Head));
-    TsmMemcpyH2D(dev_kernel_param_ptr, host_kernel_param_ptr,
-                 sizeof(Kernel_Param));
-
-    free(host_kernel_head_ptr);
-    free(host_kernel_param_ptr);
-
-    // TODO: No such API
-    setHostBoot(chip_info, hostboot);
-    set_multi_graph(kmodel_vec[0], hostboot, dev_dyn_mods_ptr, dev_tlv_ptr,
-                    dev_kernel_head_ptr);
-
-    TsmDevicePtr bootpm_dev;
-    if (TsmDeviceMalloc(devices[dev_index], bootpm_dev,
-                        hostboot->get_maxlen()) != RET_SUCCESS) {
-      return NULL;
-    }
-    if (TsmMemcpyH2D(bootpm_dev, hostboot->get_bootpmbuffer(),
-                     hostboot->get_maxlen()) != RET_SUCCESS) {
-      return NULL;
-    }
-
-    if (TsmRun(devices[dev_index], bootpm_dev) != RET_SUCCESS) {
-      printf("TsmRun bootpm_dev failed.\n");
-      return NULL;
-    }
-
-    // Unload the kernel
-    TsmUnloadKernel(devices[dev_index], kmodel_vec);
-
-    // Fetch the output data and process it
-    printf("[Chip id %u] Copy output from device...\n", chip_id);
-    if (kernel_result_process(devices, dev_index, hostboot, chip_info,
-                              bootpm_dev, new_model->case_dir) != RET_SUCCESS) {
-      printf("free dev memory failed.\n");
-      return NULL;
-    }
-    if (freeMemPerStep(chip_id, bootpm_dev) !=
RET_SUCCESS) {
-      printf("free dev memory failed.\n");
-      return NULL;
-    }
-    // Free the multi-graph related TLVs
-    if (TsmDeviceFree(dev_kernel_head_ptr) != RET_SUCCESS) {
-      printf("free dev_kernel_head_ptr failed.\n");
-      return NULL;
-    }
-    if (TsmDeviceFree(dev_kernel_param_ptr) != RET_SUCCESS) {
-      printf("free dev_kernel_param_ptr failed.\n");
-      return NULL;
-    }
-
-    if (TsmDeviceFree(dev_dyn_mods_ptr) != RET_SUCCESS) {
-      printf("free dev_dyn_mods_ptr failed.\n");
-      return NULL;
-    }
-    if (TsmDeviceFree(dev_tlv_ptr) != RET_SUCCESS) {
-      printf("free dev_tlv_ptr failed.\n");
-      return NULL;
-    }
-
-    printf("[dev_index %u] Set Terminal Info...\n", dev_index);
-    if (TsmSetTerminate(devices[dev_index]) != RET_SUCCESS) {
-      printf("TsmSetTerminate failed.\n");
-      return NULL;
-    }
-#if 0
-    if (freeTensorData(chip_id, chip_info) != RET_SUCCESS) {
-      printf("free tensor data dev memory failed.\n");
-    }
-#endif
-  }
-
-  Py_RETURN_NONE;
-}
-
 static PyMethodDef ModuleMethods[] = {
     {"load_binary", loadBinary, METH_VARARGS,
      "Load provided binary into Tx81 driver"},
-    {"launch", launch, METH_VARARGS, "tx8 launch kernel!"},
     {"get_device_properties", getDeviceProperties, METH_VARARGS,
      "Get the properties for a given Tx81 device"},
-
     {NULL, NULL, 0, NULL} // sentinel
};
diff --git a/third_party/tsingmicro/backend/driver.py b/third_party/tsingmicro/backend/driver.py
index 071b4a1b0..05e637c04 100644
--- a/third_party/tsingmicro/backend/driver.py
+++ b/third_party/tsingmicro/backend/driver.py
@@ -11,48 +11,59 @@
 import importlib.util
 import shutil
 import sysconfig
+import atexit

 from pathlib import Path

-from triton.runtime.build import _build
 from triton.runtime.cache import get_cache_manager
 from triton.backends.driver import GPUDriver
 from triton.backends.compiler import GPUTarget

+import torch
+from torch.utils import cpp_extension, rename_privateuse1_backend, generate_methods_for_privateuse1_backend
+
+module = cpp_extension.load(
+    name="txda",
+    sources=[os.path.dirname(__file__) + "/txda_device.cpp"],
+    #runtime include path
+    extra_include_paths=[""],
+    #runtime *.so path
+    extra_ldflags=[""],
+    extra_cflags=["-g"],
+    verbose=True,
+)
+
+torch.utils.rename_privateuse1_backend("txda")
+
+torch._register_device_module("txda", module)
+
+generate_methods_for_privateuse1_backend(for_storage=True)
+
+
+def _get_tx8_path(bin_name: str) -> str:
+    path = os.getenv("TX8_HOME", "")
+    if path == "":
+        raise Exception("TX8_HOME is not set.")
+    return os.path.join(path, bin_name)
+
+
 dirname = os.path.dirname(os.path.realpath(__file__))
 include_dirs = [
     os.path.join(dirname, "include"),
+    os.path.realpath(_get_tx8_path("include")),
     os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"),
     os.path.join(sysconfig.get_path('platlib'), "torch", "include"),
     os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"),
     os.path.join(sysconfig.get_path('platlib'), "numpy", "_core", "include")
 ]
-library_dirs = [os.path.join(dirname, "lib"), os.path.join(sysconfig.get_path('platlib'), "torch", "lib")]
+library_dirs = [
+    os.path.join(dirname, "lib"),
+    os.path.realpath(_get_tx8_path("lib")),
+    os.path.join(sysconfig.get_path('platlib'), "torch", "lib")
+]
 libraries = ['tx8_runtime', 'torch', 'torch_cpu', 'torch_python', 'c10']

-# Path configuration for cross compilation
-def _get_llvm_bin_path(bin_name: str) -> str:
-    path = os.getenv("LLVM_BINARY_DIR", "")
-    if path == "":
-        raise Exception("LLVM_BINARY_DIR is not set.")
-    return os.path.join(path, bin_name)
-
-
-def _get_libc_root() -> str:
-    path =
os.getenv("LIB_C_ROOT", "") - if path == "": - raise Exception("LIB_C_ROOT is not set.") - return path - - -def _get_vendor_runtime_path() -> str: - path = os.getenv("LIB_VENDOR_RUNTIME_PATH", "") - if path == "": - raise Exception("LIB_VENDOR_RUNTIME_PATH is not set.") - return path - - def _dump_ir_if_needed(files): - path = os.getenv("ZTC_DUMP_PATH", "") + path = os.getenv("TRITON_DUMP_PATH", "") if not path: return @@ -61,6 +72,39 @@ def _dump_ir_if_needed(files): shutil.copy(f, os.path.join(path, os.path.basename(f))) +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + cc = clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) + include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-std=c++17", "-Wno-psabi", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + cc_cmd += [f"-Wl,-rpath,{dir}" for dir in library_dirs] + subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + return so + + # Build a native ELF on the platform running this python script def compile_native(src, name): fname = "native_" + name @@ -84,58 +128,6 @@ def compile_native(src, name): return mod -# Build a accelerator controller ELF -def compile_accelerator(src, name, ext): - name = "npu_" + name - key = hashlib.sha256(src.encode("utf-8")).hexdigest() - libc_inc = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "include") - cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, f"{name}.{ext}") - # FIXME: Hardcoded path - #dst_path = os.path.join(tmpdir, "wrapper.so") - dst_path = "/tmp/wrapper.o" - with open(src_path, "w") as f: - f.write(src) - _dump_ir_if_needed([src_path]) - clang_path = _get_llvm_bin_path("clang") - # Compile - subprocess.check_call([ - clang_path, src_path, "-O2", "-c", "-fPIC", f"-I{libc_inc}", "--target=riscv64-unknown-elf", - "-march=rv64imafdc", "-o", dst_path - ]) - - with tempfile.TemporaryDirectory() as tmpdir: - # FIXME: Hardcoded path - #dst_path = os.path.join(tmpdir, f"{name}.so") - dst_path = "/tmp/kernel.so" - libc_lib = os.path.join(_get_libc_root(), "riscv64-unknown-elf", "lib", "rv64imafdc", "lp64d") - libcrt_lib = os.path.join(_get_libc_root(), "lib", "gcc", "riscv64-unknown-elf", "15.0.0", "rv64imafdc", - 
"lp64d") - libvr_path = _get_vendor_runtime_path() - clang_path = _get_llvm_bin_path("clang") - # Link wrapper, kernel with Tx81 crt and intrinsics(libkcorert.a) - subprocess.check_call([ - clang_path, "-nostdlib", - # FIXME: Hardcoded path - "/tmp/wrapper.o", "/tmp/kernel.o", "-O2", "--target=riscv64-unknown-elf", "-march=rv64imafdc", "-fPIC", - # "-shared", # ELF toolchain doesn't support -shared - f"-L{libvr_path}", f"-L{libc_lib}", f"-L{libcrt_lib}", - # Allow libkcorert symbol overwrite libc symbols, libkcorert - # should be specified before libc - "-Wl,--allow-multiple-definition", "-lvr", # Wrapper API of Tx81 intrinsic - "-lkcorert", # Tx81 intrinsic API - "-lc", "-lm", "-lgcc", "-T", f"{libvr_path}/gcc_tx8_smarth.ld", "-o", dst_path - ]) - - _dump_ir_if_needed([dst_path]) - with open(dst_path, 'rb') as f: - so = f.read() - return so - - # -------------------- Launcher ---------------------------- def _ty_to_cpp(ty): if ty[0] == '*': @@ -160,173 +152,43 @@ def _ty_to_cpp(ty): def _extracted_type(ty): + if isinstance(ty, tuple): + val = ','.join(map(_extracted_type, ty)) + return f"[{val}]" if ty[0] == '*': return "PyObject*" + if ty == "constexpr": + return "PyObject*" return _ty_to_cpp(ty) def _format_of(ty): + if isinstance(ty, tuple): + val = ''.join(map(format_of, ty)) + return f"({val})" + if ty[0] == '*': + return "O" + if ty in ("constexpr", "nvTmaDesc"): + return "O" return { - "PyObject*": "O", "float": "f", "double": "d", "long": "l", "int8_t": "b", "int16_t": "h", "int32_t": "i", - "int64_t": "l", + "int64_t": "L", "uint8_t": "B", "uint16_t": "H", "uint32_t": "I", "uint64_t": "K", - }[ty] - - -# This function makes a single kernel invoker which wraps all the input args into -# a single input buffer. -def make_kernel_wrapper_v2(constants, signature, kernel_name): - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - - return f""" -#include -#include - -// Triton kernel forward declaration, the last 6 arguments are: gridXYZ and xyz -// Using -convert-func-to-llvm=use-bare-ptr-memref-call-conv=true. -void {kernel_name}({arg_decls}, int, int, int, int, int, int); - -// Kernel entry point -// NOTE: Assuming the triton kernel can only take 2 kind of arguments: -// 1. 8 bytes scalar -// 2. Tensor buffer (8 bytes memory address) -// -// The input buffer has the following format: -// +--------------------------------------------------------------------------+ -// | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 4 bytes | 8 bytes | -// | gridX | gridY | gridZ | x | y | z | karg1 | -// +--------------------------------------------------------------------------+ -// | 8 bytes | ... | 8 bytes | -// | karg2 | ... | kargn | -// +-------------------------------+ -void __{kernel_name}(void *args) {{ - void* basePtr = args; - - // Extract the kernel arguments from kernel buffer - int gridX = *((int*)basePtr); - int gridY = *((int*)basePtr+1); - int gridZ = *((int*)basePtr+2); - int x = *((int*)basePtr+3); - int y = *((int*)basePtr+4); - int z = *((int*)basePtr+5); - void* krnArgOffsets = (void*) ((int*)basePtr + 6); - - if (gridX*gridY*gridZ <= 0) - return; - - // Invoke the actual kernel. 
- {kernel_name}({', '.join([f"(void*) (((uint64_t*)krnArgOffsets)[{i}])" - if ty[0] == "*" else - f"*({_ty_to_cpp(ty)}*)(((uint64_t*)krnArgOffsets)[{i}])" - for i, ty in signature.items()])}, - gridX, gridY, gridZ, x, y, z); -}} -""" - - -def make_kernel_wrapper(constants, signature, kernel_name): - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - - return f""" -#include -#include -#include - -// Tx81 target framwork related definition -typedef struct BootParamHead -{{ - uint32_t MaxLen; - uint32_t LdmemLen; - uint32_t InputNum; - uint32_t OutputNum; - uint32_t ParamNum; - uint32_t reserved; - uint64_t CacheMemLen; - uint64_t CacheMemAttr; - uint32_t Datalen; - uint32_t reserved1; - uint64_t DataAddr; -}} D_BootParamHead; - -// Tx81 target framwork related definition -typedef struct BootParamDyninfo -{{ - uint64_t addr; // device - uint64_t size; - uint32_t dtype; - uint32_t dim; - uint32_t shape[6]; -}} D_BootParamDyninfo; - -// Triton kernel forward declaration, the last 6 arguments are: gridXYZ and xyz -void {kernel_name}({arg_decls}, int, int, int, int, int, int); - -// Get the entry point of kernel arg buffer -void* getKernelArgBuffer(void *args) {{ - // Always use the first BootParam to carry the address points to kernel - // arguments buffer - D_BootParamHead *head = (D_BootParamHead *)args; - assert(head->InputNum == 1); - // Decode the first parameter from BootParam as the kernel buffer info. - D_BootParamDyninfo* kernelBuffer = (D_BootParamDyninfo *)((char *)args + - sizeof(D_BootParamHead)); - // Kernel buffer address on device DDR - return (void*) kernelBuffer->addr; -}} - -// Kernel wrapper -void task(void *krnArgBuf, void *krnArgOffsets, - int gridX, int gridY, int gridZ, int x, int y, int z) {{ - - // Invoke the actual kernel by passing in the triton kernel arguments stored - // on device DDR and the other arguments which generated by compiler. - {kernel_name}({', '.join([f"(void*) (krnArgBuf + ((uint64_t*)krnArgOffsets)[{i}])" - if ty[0] == "*" else - f"*({_ty_to_cpp(ty)}*)(krnArgBuf + ((uint64_t*)krnArgOffsets)[{i}])" - for i, ty in signature.items()])}, - gridX, gridY, gridZ, x, y, z); -}} - -// Kernel entry point, name is aligned that specified to TsmLoadKernel -void __kernel_entry(void *args) {{ - void* basePtr = getKernelArgBuffer(args); - - // Extract the kernel arguments from kernel buffer - int krnArgCount = *(int*)basePtr; - int gridX = *((int*)basePtr+1); - int gridY = *((int*)basePtr+2); - int gridZ = *((int*)basePtr+3); - void* krnArgOffsets = (void*) ((int*)basePtr + 4); - void* krnArgBuf = krnArgOffsets + krnArgCount * sizeof(uint64_t*); - - if (gridX*gridY*gridZ <= 0) - return; - - // Cast "function" to the real function type. 
- for(int x = 0; x < gridX; x++) {{ - for(int y = 0; y < gridY; y++) {{ - for(int z = 0; z < gridZ; z++) {{ - task (krnArgBuf, krnArgOffsets, gridX, gridY, gridZ, x, y, z); - }} - }} - }} -}} -""" + }[_ty_to_cpp(ty)] def make_launcher(constants, signature, kernel_name): # Basic declarations - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") + args_format = ''.join([_format_of(ty) for ty in signature.values()]) format = "iiiOOOOOO" + args_format args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' @@ -334,7 +196,7 @@ def make_launcher(constants, signature, kernel_name): kernel_parameters = ', '.join( f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"tx81_ptr{i}, &ptr_arg{i}" for i, ty in signature.items() - if i not in constants) + if ty != "constexpr") kernel_parameters += ', ' if kernel_parameters else '' return f""" @@ -346,11 +208,12 @@ def make_launcher(constants, signature, kernel_name): #include #include #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include +//#include #include #include #include #include +#include #include "hrt_interface.h" #include "hrt_common.h" @@ -370,6 +233,10 @@ def make_launcher(constants, signature, kernel_name): // karg1 offset karg2 offset kargn offset // ... Metadata buffer... | ............ kernel arg buffer .................. +enum DATA_TYPE {{ + SCALAR, + POINT, +}}; // A kernel argument struct KernelArg {{ @@ -379,16 +246,19 @@ def make_launcher(constants, signature, kernel_name): uint64_t scalar; // Scalar data }} data; size_t size; // The size of the kernel argument + int data_type; KernelArg(void *ptr, size_t s) : size(s) {{ data.ptr = ptr; + data_type = POINT; }} - KernelArg(uint64_t v, size_t s) : size(s) {{ + KernelArg(uint64_t v, size_t s) : size(0) {{ data.scalar = v; + data_type = SCALAR; }} -}}; +}}; extern "C" {{ // The kernel arguments includes: @@ -402,6 +272,23 @@ def make_launcher(constants, signature, kernel_name): static std::vector g_tx81_devices; static bool g_runtime_initialized = false; +// FIXME: Hardcoded path +std::string chip_out = "/tmp/chip_out/node0/"; +std::string kernel_file = "/tmp/kernel.so"; +std::string kernel_fun_name = "{kernel_name}"; +uint32_t sharedMemBytes = 0; + +typedef void* Stream_t; + +static uint64_t get_phy_addr(uint64_t logic_addr) {{ + uint32_t card_id; + uint64_t addr; + uint64_t size; + TsmMemGetInfo(logic_addr, card_id, addr, size); + return addr; +}} + + // Initialize Tx81 runtime bool init_tx81_runtime() {{ if (g_runtime_initialized) {{ @@ -422,358 +309,210 @@ def make_launcher(constants, signature, kernel_name): return false; }} + // FIXME: Hardcoded // Set up devices - for simplicity, we're using a 1x1 configuration uint32_t first_phy_id = 0; uint32_t card_x = 1; uint32_t card_y = 1; - if (TsmSetDevice(first_phy_id, card_x, card_y, g_tx81_devices) != RET_SUCCESS) {{ + TsmDevice *dev = new TsmDevice(); + if (TsmSetDevice(&dev, 0, first_phy_id) != RET_SUCCESS) {{ PyErr_SetString(PyExc_RuntimeError, "Failed to set Tx81 devices"); TsmDeInitRuntime(); return false; }} + g_tx81_devices.push_back(dev); + + // FIXME: Hardcoded + TsmModel *new_model = new TsmModel(); + // Create a vector of models + std::vector kmodel_vec = {{new_model}}; + std::string option = "-O2"; + CompileOption compl_option = {{}}; + 
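// NOTE: the settings below assume a minimal single-card (1x1) bring-up + // topology; comp_enable = 0 loads the prebuilt kernel binary instead of + // compiling on the fly. +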
compl_option.comp_enable = 0; // Use prebuilt binary + compl_option.chip_x = 1; // single card + compl_option.chip_y = 1; + compl_option.check_enable = true; + compl_option.enable_kcore_bin = 1; + compl_option.enable_kcore_so = 1; + new_model->case_dir = chip_out; + + for (TsmDevice * dev : g_tx81_devices) {{ + if (TsmCompileMultiGraph(dev, *new_model, option, compl_option) != RET_SUCCESS) {{ + for (uint32_t dev_index = 0; dev_index < g_tx81_devices.size(); ++dev_index) {{ + if (TsmResetDevice(g_tx81_devices[dev_index]) != RET_SUCCESS) {{ + return false; + }} + }} + return false; + }} + }} // Initialize all devices for (auto* dev : g_tx81_devices) {{ - if (TsmInitDevice(dev) != RET_SUCCESS) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to initialize Tx81 device"); - TsmDeInitRuntime(); + if (TsmLaunch(dev, *new_model) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "[Chip id] TsmLaunch failed."); + TsmReleaseDevice(dev); + TsmResetDevice(dev); + return false; + }} + + if (TsmSetMonitorInfo(dev) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "[Chip id] TsmSetMonitorInfo failed."); + TsmReleaseDevice(dev); + TsmResetDevice(dev); return false; }} }} + delete new_model; g_runtime_initialized = true; + return true; }} // Clean up Tx81 runtime resources -void cleanup_tx81_runtime() {{ +static PyObject* cleanup_tx81_runtime(PyObject* self, PyObject* args) {{ if (!g_runtime_initialized) {{ - return; + Py_RETURN_NONE; }} for (auto* dev : g_tx81_devices) {{ + if (TsmSetTerminate(dev) != RET_SUCCESS) {{ + Py_RETURN_NONE; + }} // Reset and release each device - TsmResetDevice(dev); TsmReleaseDevice(dev); + TsmResetDevice(dev); + delete dev; }} - g_tx81_devices.clear(); TsmDeInitRuntime(); g_runtime_initialized = false; + Py_RETURN_NONE; }} +TSM_RETCODE argsToDevMemArray(TsmDevice *dev, std::vector &kargs, + std::vector &rtKargs, std::vector &devAddrs) {{ + int count = 0; + for (KernelArg& karg : kargs) {{ + if (karg.data_type == POINT) {{ + TsmDevicePtr dev_buffer; + if (TsmDeviceMalloc(dev, dev_buffer, karg.size) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmDeviceMalloc"); + return RET_ERROR; + }} -static void prepare_input(std::vector devices, uint32_t dev_index, - std::shared_ptr chip_info) {{ - for (uint32_t i = 0; i < chip_info->input_num; ++i) {{ - chip_info->input_dev_addr.push_back(0); - if (TsmDeviceMalloc(devices[dev_index], chip_info->input_dev_addr[i], - chip_info->input_size[i]) != RET_SUCCESS) {{ - printf("[Chip id %u] Input%d, DeviceMalloc failed!\\n", devices[dev_index]->chip_id, i); - TsmResetDevice(devices[dev_index]); - return; - }} - - if (TsmMemcpyH2D((TsmDevicePtr)chip_info->input_dev_addr[i], - (void*) chip_info->input_host_addr[i], - chip_info->input_size[i]) != RET_SUCCESS) {{ - printf("[Chip id %u] Input%d, MemcpyH2D failed!\\n", devices[dev_index]->chip_id, i); - TsmResetDevice(devices[dev_index]); - return; - }} - }} -}} + if (TsmMemcpyH2D(dev_buffer, karg.data.ptr, karg.size) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyH2D"); + return RET_ERROR; + }} + devAddrs.push_back(dev_buffer); + // FIXME: rank + rtKargs.push_back(1); + rtKargs.push_back(get_phy_addr(dev_buffer)); -static void prepare_output(std::vector devices, uint32_t dev_index, - std::shared_ptr chip_info) {{ - for (size_t i = 0; i < chip_info->output_num; ++i) {{ - chip_info->output_dev_addr.push_back(0); - printf("[Chip id %u] output[%lu] data(size: %lu)\\n", - devices[dev_index]->chip_id, i, chip_info->output_size[i]); - - if
(TsmDeviceMalloc(devices[dev_index], chip_info->output_dev_addr[i], - chip_info->output_size[i]) != RET_SUCCESS) {{ - printf("[Chip id %u] output[%lu], DeviceMalloc failed!\\n", - devices[dev_index]->chip_id, i); - TsmResetDevice(devices[dev_index]); - return; + count++; + }} + else {{ + rtKargs.push_back(karg.data.scalar); + }} }} - }} + return RET_SUCCESS; }} -TSM_RETCODE kernel_result_process(std::vector devices, uint32_t dev_index, - std::shared_ptr hostboot, - std::shared_ptr chip_info, - TsmDevicePtr bootpm_dev, std::string case_dir) {{ - for (size_t i = 0; i < chip_info->output_num; ++i) {{ - // Dynamic shapes: the real output size must be recomputed here - if (TsmMemcpyD2H(hostboot->get_bootpmbuffer(), bootpm_dev, - hostboot->get_maxlen()) != RET_SUCCESS) {{ - return RET_ERROR; - }} - - auto out_tensor = hostboot->get_dev_output_tensor_after_run(i); - chip_info->output[i]->dim = out_tensor->dim; - std::memcpy(chip_info->output[i]->shape, out_tensor->shape, sizeof(out_tensor->shape)); - chip_info->output_size[i] = hrt_get_dtype_size((DTYPE)chip_info->output[i]->dtype); - for (uint32_t j = 0; j < out_tensor->dim; ++j) {{ - if (out_tensor->shape[j] > 0) {{ - chip_info->output_size[i] *= out_tensor->shape[j]; - }} - }} - - TsmHostPtr output_host_addr = (TsmHostPtr)malloc(chip_info->output_size[i]); - if (chip_info->output_size[i] > 0) {{ - if (TsmMemcpyD2H((void*)output_host_addr, chip_info->output_dev_addr[i], - chip_info->output_size[i]) != RET_SUCCESS) {{ - return RET_ERROR; - }} - }} - - printf("[Chip id %u] output_dev_addr=%ld\\n", devices[dev_index]->chip_id, - chip_info->output_dev_addr[i]); +TSM_RETCODE devMemArrayToArgs(TsmDevice *dev, std::vector &kargs, + std::vector &devAddrs) {{ - // TODO: Processing output -#if 0 - std::string file_path = case_dir + "/chip" + std::to_string(dev_index) + - "/agent/data/out" + std::to_string(i) + "_riscv.bin"; - saveDataToFile(file_path, output_host_addr, chip_info->output_size[i]); -#endif - - if (output_host_addr != 0) {{ - free((void *)output_host_addr); + int count = 0; + for (KernelArg& karg : kargs) {{ + if (karg.data_type == POINT) {{ + uint64_t dev_buffer = devAddrs[count++]; + if (TsmMemcpyD2H(karg.data.ptr, dev_buffer, karg.size) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmMemcpyD2H"); + return RET_ERROR; + }} + }} }} - }} - return RET_SUCCESS; + return RET_SUCCESS; }} -TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) {{ - if (bootpm_dev != 0) {{ - printf("[Chip id %u] bootpm dev addr: 0x%lx \\n", chip_id, bootpm_dev); - if (TsmDeviceFree(bootpm_dev) != RET_SUCCESS) {{ - return RET_ERROR; +TSM_RETCODE devMemFree(TsmDevice *dev, std::vector &devAddrs) {{ + for (uint64_t dev_buffer : devAddrs) {{ + if (TsmDeviceFree(dev_buffer) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to TsmDeviceFree"); + return RET_ERROR; + }} }} - bootpm_dev = 0; - }} - - return RET_SUCCESS; + return RET_SUCCESS; }} -static void setHostBoot(std::shared_ptr &chip_info, - std::shared_ptr &hostboot) {{ - if (chip_info == nullptr) {{ - printf("chip_info is null.\\n"); - return; - }} - - if (hostboot == nullptr) {{ - printf("hostboot is null.\\n"); - return; - }} - - for (size_t i = 0; i < chip_info->input_dev_addr.size(); ++i) {{ - hostboot->set_dev_input(i, chip_info->input_dev_addr[i], chip_info->input_size[i]); - hostboot->set_dev_input_tensor(i, chip_info->input[i]); - }} - - for (size_t i = 0; i < chip_info->output_dev_addr.size(); ++i) {{ - hostboot->set_dev_output(i, chip_info->output_dev_addr[i],
chip_info->output_size[i]); - }} - - for (size_t i = 0; i < chip_info->param_num; ++i) {{ - hostboot->set_dev_param(i, chip_info->param_dev_addr[i], chip_info->param_size[i]); - }} - - return; +TSM_RETCODE freeMemPerStep(uint32_t chip_id, TsmDevicePtr &bootpm_dev) {{ + if (bootpm_dev != 0) {{ + if (TsmDeviceFree(bootpm_dev) != RET_SUCCESS) {{ + return RET_ERROR; + }} + bootpm_dev = 0; + }} + return RET_SUCCESS; }} -static void _launch(int gridX, int gridY, int gridZ, std::vector &kargs) {{ - std::vector devices; +static void _launch(int gridX, int gridY, int gridZ, std::vector kargs) {{ + std::vector &devices = g_tx81_devices; if (gridX*gridY*gridZ <= 0) {{ return; // No work to do }} - TsmModel *new_model = new TsmModel(); - - // Create a vector of models - std::vector kmodel_vec = {{new_model}}; - std::string option = "-O2"; - CompileOption compl_option = {{}}; - compl_option.comp_enable = 0; // Use prebuilt binary - compl_option.chip_x = 1; // single card - compl_option.chip_y = 1; - compl_option.check_enable = true; - compl_option.enable_kcore_bin = 1; - compl_option.enable_kcore_so = 1; - // FIXME: Hardcoded path - new_model->case_dir = "/tmp/kernel.so"; - - printf("====> Calling TsmCompileMultiGraph\\n"); -#if 0 - if (TsmCompileMultiGraph(devices, *new_model, option, compl_option) != RET_SUCCESS) {{ - for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {{ - if (TsmResetDevice(devices[dev_index]) != RET_SUCCESS) {{ - printf("[Chip id %u] tx_engine: tx_reset, failed!\\n", dev_index); - }} else {{ - printf("[Chip id %u] tx_engine: tx_reset, success!\\n", dev_index); - }} - }} - printf("TsmCompile failed.\\n"); + // TODO: move this file read out of the launch path + uint64_t kernel_len = 0; + uint8_t* kernel_ptr = read_file_data(kernel_file, kernel_len); + if (kernel_ptr == nullptr) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to read kernel so"); + TsmDeInitRuntime(); return; }} -#endif - // Calculate the total size of kernel arguments buffer - uint64_t kernel_buffer_size = 0; - for (auto karg : kargs) - kernel_buffer_size += karg.size; - - // Calculate the kernel argument buffer header size - // 4 bytes header + n * kernel argument metadata + 3 * sizeof(gridXYZ) - uint64_t kernel_meta_buf_size = sizeof(uint64_t*) * kargs.size() + 4 + 12; - kernel_buffer_size += kernel_meta_buf_size; - - // We use input_num = 1 to set the whole kernel arguments buffer as a single - // input - uint32_t input_num = 1; - uint32_t output_num = 0; - uint32_t param_num = 0; - - // Create boot parameter - std::shared_ptr hostboot = std::make_shared(input_num, output_num, param_num); - - // Create chip common info - std::shared_ptr chip_info = std::make_shared(); - chip_info->input_num = input_num; - chip_info->output_num = output_num; - chip_info->param_num = param_num; - chip_info->imm_size = 0; // Cache size - - // Prepare input/output sizes and addresses - chip_info->input_size.resize(input_num); - chip_info->input_host_addr.resize(input_num); - chip_info->input_dev_addr.resize(input_num); - chip_info->output_size.resize(output_num); - chip_info->output_host_addr.resize(output_num); - chip_info->output_dev_addr.resize(output_num); - - // Prepare whole kernel buffer info - chip_info->input.push_back(std::make_shared()); - chip_info->input[0]->dim = 1; - chip_info->input[0]->dtype = FMT_FP32; // Default to float - chip_info->input[0]->shape[0] = 1; // Default shape - chip_info->input_size[0] = kernel_buffer_size; - chip_info->input_host_addr = std::vector{{(uint64_t) 0x0}}; // prepare data/ load kernel/run/unload kernel/get out
data/release memory for (uint32_t dev_index = 0; dev_index < devices.size(); ++dev_index) {{ - // input prepare - prepare_input(devices, dev_index, chip_info); - // output prepare - prepare_output(devices, dev_index, chip_info); - - uint32_t chip_id = devices[dev_index]->chip_id; - TsmSetMonitorInfo(devices[dev_index]); - - // load kernel - char module_symbol[] = "__kernel_entry"; - TsmLoadKernel(devices[dev_index], kmodel_vec, module_symbol); - printf("TsmLoadKernel finish!...\\n"); - - printf("[Chip id %u] Set boot-params...\\n", chip_id); - size_t dyn_mod_size = sizeof(DynMods) + sizeof(DynModule); - TsmDevicePtr dev_dyn_mods_ptr; - if (TsmDeviceMalloc(devices[dev_index], dev_dyn_mods_ptr, dyn_mod_size) != RET_SUCCESS) - return; - // Allocate the device memory for all kernel arguments - TsmDevicePtr dev_kernel_buffer; - if (TsmDeviceMalloc(devices[dev_index], dev_kernel_buffer, kernel_buffer_size) != RET_SUCCESS) - return; + std::vector devAddrs; + std::vector rtKargs; - // Kernel meta data and argument buffer - int dev_karg_ptr = dev_kernel_buffer + kernel_meta_buf_size; - - // Kernel arguments address - uint64_t arg_metadata[kargs.size()]; - - // Copy kernel arguments to device DDR (immediately after the metadata) - int i = 0; - uint64_t offset = 0; - for (auto karg : kargs) {{ - if (TsmMemcpyH2D(dev_karg_ptr, karg.data.ptr, karg.size) != RET_SUCCESS) - return; - - // Calculate the offset of each kernel arg's buffer - arg_metadata[i++] = offset; - - // Shift the offset and pointer for next kernel argument. - offset += karg.size; - dev_karg_ptr += karg.size; - }} - - // Create the metadata buffer - uint32_t* metadata = (uint32_t*) malloc(kernel_meta_buf_size); - metadata[0] = (int) kargs.size(); - metadata[1] = gridX; - metadata[2] = gridY; - metadata[3] = gridZ; - memcpy(metadata+20, arg_metadata, kernel_meta_buf_size - 16); - - // Copy kernel metadata to device DDR - if (TsmMemcpyH2D(dev_kernel_buffer, metadata, kernel_meta_buf_size) != RET_SUCCESS) - return; - - setHostBoot(chip_info, hostboot); - set_multi_graph(kmodel_vec[0], hostboot, dev_dyn_mods_ptr, 0, dev_kernel_buffer); - - TsmDevicePtr bootpm_dev; - if (TsmDeviceMalloc(devices[dev_index], bootpm_dev, hostboot->get_maxlen()) != RET_SUCCESS) - return; - - if (TsmMemcpyH2D(bootpm_dev, hostboot->get_bootpmbuffer(), hostboot->get_maxlen()) != RET_SUCCESS) - return; - - if (TsmRun(devices[dev_index], bootpm_dev) != RET_SUCCESS) {{ - printf("TsmRun bootpm_dev failed.\\n"); + if (argsToDevMemArray(devices[dev_index], kargs, rtKargs, devAddrs) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to argsToDevMemArray"); + TsmDeInitRuntime(); return; }} - TsmUnloadKernel(devices[dev_index], kmodel_vec); - - // Process kernel output data - printf("[Chip id %u] Copy output from device...\\n", chip_id); - if (kernel_result_process(devices, dev_index, hostboot, chip_info, bootpm_dev, new_model->case_dir) != RET_SUCCESS) {{ - printf("free dev memory failed.\\n"); - return; + rtKargs.push_back(gridX); + rtKargs.push_back(gridY); + rtKargs.push_back(gridZ); + rtKargs.push_back(0); + rtKargs.push_back(0); + rtKargs.push_back(0); + + // TSM_RETCODE TsmKernelLaunch(TsmDevice *dev, const char *func_name, void *kernel_ptr, uint32_t kernel_len, + // uint32_t grid_dim, uint32_t block_dim, void *args, uint32_t args_len); + if (TsmKernelLaunch(devices[dev_index], kernel_fun_name.c_str(), (void*)kernel_ptr, kernel_len, + gridX, 1, (void*)(&rtKargs[0]), rtKargs.size()*sizeof(uint64_t)) != RET_SUCCESS){{ + PyErr_SetString(PyExc_RuntimeError, 
"Failed to TsmKernelLaunch"); + TsmDeInitRuntime(); }} - - if (freeMemPerStep(chip_id, bootpm_dev) != RET_SUCCESS) {{ - printf("free dev memory failed.\\n"); + if (devMemArrayToArgs(devices[dev_index], kargs, devAddrs) != RET_SUCCESS) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to devMemArrayToArgs"); + TsmDeInitRuntime(); return; }} - if (TsmDeviceFree(dev_kernel_buffer) != RET_SUCCESS) {{ - printf("free dev_kernel_param_ptr failed.\\n"); - return; - }} + // getchar(); - if (TsmDeviceFree(dev_dyn_mods_ptr) != RET_SUCCESS) {{ - printf("free dev_dyn_mods_ptr failed.\\n"); - return; - }} + // TsmUnloadKernel(devices[dev_index], kmodel_vec); - printf("[dev_index %u] Set Terminal Info...\\n", dev_index); - if (TsmSetTerminate(devices[dev_index]) != RET_SUCCESS) {{ - printf("TsmSetTerminate failed.\\n"); + if (devMemFree(devices[dev_index], devAddrs) != RET_SUCCESS) {{ return; }} }} - - // Clean up the model - delete new_model; }} // Structure to represent a device pointer @@ -794,6 +533,21 @@ def make_launcher(constants, signature, kernel_name): return contiguous_tensor.data_ptr(); }} +static PyObject* init_runtime(PyObject* self, PyObject* args) {{ + const char* _chip_out; + if (!PyArg_ParseTuple(args, "s", &_chip_out)) {{ + return NULL; + }} + chip_out = _chip_out; + + // Initialize Tx81 runtime during module import + if (!init_tx81_runtime()) {{ + return NULL; + }} + + return Py_None; +}} + // Python module launch function static PyObject* launch(PyObject* self, PyObject* args) {{ int gridX, gridY, gridZ; @@ -836,9 +590,10 @@ def make_launcher(constants, signature, kernel_name): // Construct a data kernel arguments list data structure std::vector kargs; + //{' '.join([f"kargs.emplace_back(_arg{i}, PyObject_Size(_arg{i})*4);" if ty[0]=="*" else f"kargs.emplace_back(_arg{i}, sizeof(_arg{i}));" for i, ty in signature.items() if ty != "constexpr"])} {' '.join([f"kargs.emplace_back(extractTensor(_arg{i}), getTensorStorageSize(_arg{i}));" if ty[0]=="*" else f"kargs.emplace_back(_arg{i}, sizeof(_arg{i}));" - for i, ty in signature.items()])} + for i, ty in signature.items() if ty != "constexpr"])} // Launch the kernel _launch(gridX, gridY, gridZ, kargs); @@ -863,6 +618,7 @@ def make_launcher(constants, signature, kernel_name): // Python module method definitions static PyMethodDef ModuleMethods[] = {{ {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{"init_runtime", init_runtime, METH_VARARGS, "Init runtime with chip_out dir"}}, {{NULL, NULL, 0, NULL}} // sentinel }}; @@ -891,13 +647,6 @@ def make_launcher(constants, signature, kernel_name): PyModule_AddFunctions(m, ModuleMethods); -#if 0 - // Initialize Tx81 runtime during module import - if (!init_tx81_runtime()) {{ - Py_DECREF(m); - return NULL; - }} - // Register an atexit handler to cleanup Tx81 runtime PyObject* atexit_module = PyImport_ImportModule("atexit"); if (atexit_module) {{ @@ -909,30 +658,29 @@ def make_launcher(constants, signature, kernel_name): }} Py_DECREF(atexit_module); }} -#endif return m; }} """ -class CrossUtils(object): +class TXDAUtils(object): def __new__(cls): if not hasattr(cls, "instance"): - cls.instance = super(CrossUtils, cls).__new__(cls) + cls.instance = super(TXDAUtils, cls).__new__(cls) return cls.instance def __init__(self): src = Path(os.path.join(dirname, "driver.cpp")).read_text() mod = compile_native(src, "tx81_utils") - # NOTE: The triton compiler.py framework requires these 2 interface. + # # NOTE: The triton compiler.py framework requires these 2 interface. 
self.load_binary = mod.load_binary self.get_device_properties = mod.get_device_properties # Launch cross compiled runtime program on controller -class CrossLauncher(object): +class TXDALauncher(object): def __init__(self, src, metadata): constants = src.constants if hasattr(src, "constants") else dict() @@ -940,17 +688,13 @@ def __init__(self, src, metadata): constants = {cst_key(key): value for key, value in constants.items()} signature = {cst_key(key): value for key, value in src.signature.items()} - # Compiler kernel wrapper source code - # NOTE: Replace this make_kernel_wrapper to v2 version by if you want - # to call the triton kernel with single input buffer and with a '__' - # prefixed name. - wrapper_src = make_kernel_wrapper(constants, signature, src.fn.__name__) - krn = compile_accelerator(wrapper_src, src.fn.__name__, "c") - # Compiler runtime kernel launcher source code launcher_src = make_launcher(constants, signature, src.fn.__name__) mod = compile_native(launcher_src, "__triton_launcher") self.launch = mod.launch + chip_out = os.path.join(_get_tx8_path("chip_out"), "node0") + chip_out = chip_out + os.sep + mod.init_runtime(chip_out) def __call__(self, *args, **kwargs): # args: 0: gridX, 1: gridY, 2: gridZ, @@ -961,33 +705,41 @@ def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) -class CrossDriver(GPUDriver): +class TXDADriver(GPUDriver): def __init__(self): super().__init__() - self.utils = CrossUtils() - self.launcher_cls = CrossLauncher + self.utils = TXDAUtils() + self.launcher_cls = TXDALauncher # Needs to overwrite GPUDriver base methods - self.get_current_device = self.get_npu_device - self.set_current_device = self.set_npu_device - self.get_current_stream = self.get_npu_stream + self.get_current_stream = torch.txda.current_stream + self.get_current_device = torch.txda.current_device + self.set_current_device = torch.txda.set_device + atexit.register(torch.txda.cleanup_device) @staticmethod def is_active(): - return True - - def get_npu_device(self): - return "cpu" - - def set_npu_device(self, device): - # CPU doesn't have a device to set - assert device == "cpu" - return - - def get_npu_stream(self, device): - return None + try: + #import torch + #return torch.txda.is_available() + return True + except ImportError: + return False def get_current_target(self): capability = 1 warp_size = 16 - return GPUTarget("cpu", capability, warp_size) + return GPUTarget("txda", capability, warp_size) + + def get_active_torch_device(self): + # import torch + # torch.txda.init_device() + return torch.device("txda", self.get_current_device()) + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_device_interface(self): + import torch + return torch.txda diff --git a/third_party/tsingmicro/backend/name.conf b/third_party/tsingmicro/backend/name.conf new file mode 100644 index 000000000..1340763be --- /dev/null +++ b/third_party/tsingmicro/backend/name.conf @@ -0,0 +1 @@ +tsingmicro diff --git a/third_party/tsingmicro/backend/txda_device.cpp b/third_party/tsingmicro/backend/txda_device.cpp new file mode 100644 index 000000000..ac46d67ac --- /dev/null +++ b/third_party/tsingmicro/backend/txda_device.cpp @@ -0,0 +1,180 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace at { +namespace detail { + +C10_REGISTER_GUARD_IMPL( + PrivateUse1, c10::impl::NoOpDeviceGuardImpl); + +} +} // namespace at + +struct TXDADeviceAllocator final : 
at::Allocator { + TXDADeviceAllocator() {} + + at::DataPtr allocate(size_t nbytes) override { + void *data = c10::alloc_cpu(nbytes); + return {data, nullptr, &ReportAndDelete, + at::Device(at::DeviceType::PrivateUse1, 0)}; + } + + static void ReportAndDelete(void *ptr) { + if (!ptr) { + return; + } + // TsmDeviceFree((uint64_t)ptr) + c10::free_cpu(ptr); + } + + at::DeleterFnPtr raw_deleter() const override { return &ReportAndDelete; } + void copy_data(void *dest, const void *src, std::size_t count) const final { + default_copy_data(dest, src, count); + } +}; + +// register device allocator +static TXDADeviceAllocator global_txda_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_txda_alloc); + +// to.Device +at::Tensor txda_to_device(const at::Tensor &self, at::Device device, + at::ScalarType dtype, bool non_blocking, bool copy, + c10::optional memory_format) { + // TsmMemcpyH2D(); + + TORCH_CHECK(self.is_cpu() || + self.device().type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + TORCH_CHECK(device.is_cpu() || device.type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + // Some dummy asserts for the basic use case: inputs are the same size / + // dtype, all contiguous. + TORCH_CHECK(self.scalar_type() == dtype); + TORCH_CHECK(self.is_contiguous()); + + if (device != at::DeviceType::CPU) { + return at::empty(self.sizes(), self.options()); + } + + auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, + false, memory_format); + memcpy(out.mutable_data_ptr(), self.mutable_data_ptr(), self.nbytes()); + return out; +} + +// _copy_from +at::Tensor txda__copy_from(const at::Tensor &self, const at::Tensor &dst, + bool non_blocking) { + // TsmMemcpyD2H(); + + TORCH_CHECK(self.is_cpu() || + self.device().type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + TORCH_CHECK(dst.is_cpu() || + dst.device().type() == c10::DeviceType::PrivateUse1, + "only support data transfer between cpu and txda"); + + // Some dummy asserts for the basic use case: inputs are the same size / + // dtype, all contiguous. 
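+ // NOTE: the bring-up allocator above hands out host memory (c10::alloc_cpu), + // so the plain memcpy below stands in for the real TsmMemcpyH2D/TsmMemcpyD2H + // transfers hinted at by the commented-out calls.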
+ TORCH_CHECK(self.sizes() == dst.sizes()); + TORCH_CHECK(self.scalar_type() == dst.scalar_type()); + TORCH_CHECK(self.is_contiguous() && dst.is_contiguous()); + + std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), + self.storage().nbytes()); + return dst; +} + +at::Tensor txda_empty_memory_format( + at::IntArrayRef size, std::optional dtype, + std::optional layout, std::optional device, + std::optional pin_memory, + std::optional memory_format) { + constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic(size, &global_txda_alloc, private_use_ks, + c10::dtype_or_default(dtype), memory_format); +} + +at::Tensor txda_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + + constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); + auto dtype = c10::dtype_or_default(dtype_opt); + return at::detail::empty_strided_generic(size, stride, &global_txda_alloc, + private_use_ks, dtype); +} + +at::Tensor txda_as_strided(const at::Tensor &input, at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { + return at::cpu::as_strided(input, size, stride, storage_offset); +} + +at::Tensor &txda_fill__scalar(at::Tensor &self, const at::Scalar &value) { + TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, + "only support txda"); + TORCH_CHECK(self.is_contiguous()); + TORCH_CHECK(self.scalar_type() == c10::ScalarType::Float); + + auto _data = static_cast(self.mutable_data_ptr()); + for (size_t idx = 0; idx < self.numel(); idx++) { + _data[idx] = value.toFloat(); + } + + return self; +} + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("to.Device", &txda_to_device); + m.impl("fill_.Scalar", &txda_fill__scalar); + m.impl("_copy_from", &txda__copy_from); + m.impl("empty.memory_format", &txda_empty_memory_format); + m.impl("empty_strided", &txda_empty_strided); + m.impl("as_strided", &txda_as_strided); +} + +bool init_device() { + // return init_txda_runtime(); + return true; +} + +bool cleanup_device() { + // cleanup_txda_runtime(); + return true; +} + +int current_device() { return 0; } + +int current_stream(int id) { return 0; } + +void set_device(int id) {} + +c10::Device get_txda_device() { + return c10::Device(c10::DeviceType::PrivateUse1, 0); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("current_device", &current_device, "get current tx device"); + m.def("current_stream", &current_stream, "get current tx stream"); + m.def("set_device", &set_device, "set tx device"); + m.def("get_txda_device", &get_txda_device, "get tx device"); + m.def("init_device", &init_device, "initialize tx device"); + m.def("cleanup_device", &cleanup_device, "cleanup tx device"); +} diff --git a/third_party/tsingmicro/bin/CMakeLists.txt b/third_party/tsingmicro/bin/CMakeLists.txt new file mode 100644 index 000000000..38c4ea7b6 --- /dev/null +++ b/third_party/tsingmicro/bin/CMakeLists.txt @@ -0,0 +1,88 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_llvm_executable(tsingmicro-opt tsingmicro-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this?
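+# (llvm_update_compile_flags applies LLVM's global compile options to the +# target, e.g. the -fno-exceptions/-fno-rtti settings used by the LLVM build.)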
+llvm_update_compile_flags(tsingmicro-opt) +target_link_libraries(tsingmicro-opt PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ${triton_libs} + ZTCAnalysis + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIROptLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(tsingmicro-opt) + +add_llvm_executable(tsingmicro-reduce tsingmicro-reduce.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(tsingmicro-reduce) + +llvm_update_compile_flags(tsingmicro-reduce) +target_link_libraries(tsingmicro-reduce PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIRReduceLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(tsingmicro-reduce) + +add_llvm_executable(tsingmicro-lsp tsingmicro-lsp.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(tsingmicro-lsp) + +llvm_update_compile_flags(tsingmicro-lsp) +target_link_libraries(tsingmicro-lsp PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(tsingmicro-lsp) + + +add_llvm_executable(tsingmicro-llvm-opt + tsingmicro-llvm-opt.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(tsingmicro-llvm-opt PRIVATE + TritonLLVMIR + + LLVMAnalysis + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(tsingmicro-llvm-opt) diff --git a/third_party/tsingmicro/bin/RegisterTritonDialects.h b/third_party/tsingmicro/bin/RegisterTritonDialects.h new file mode 100644 index 000000000..95cb7fbd5 --- /dev/null +++ b/third_party/tsingmicro/bin/RegisterTritonDialects.h @@ -0,0 +1,181 @@ +#pragma once +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "amd/include/TritonAMDGPUTransforms/Passes.h" +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +// Below headers will allow registration to ROCm passes +#include "TritonAMDGPUToLLVM/Passes.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "TritonAMDGPUTransforms/TritonGPUConversion.h" + +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "nvidia/include/NVGPUToLLVM/Passes.h" +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" +#include 
"triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h" +#include "triton-shared/Conversion/TritonToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToStructured/Passes.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" + +#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h" +#include "magic-kernel/Conversion/LinalgToMK/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h" + +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" + +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/InitAllPasses.h" + +namespace mlir { +namespace test { +void registerTestAliasPass(); +void registerTestAlignmentPass(); +void registerTestAllocationPass(); +void registerTestMembarPass(); +void registerTestTritonAMDGPURangeAnalysis(); +} // namespace test +} // namespace mlir + +inline void registerTritonDialects(mlir::DialectRegistry ®istry) { + mlir::registerAllPasses(); + mlir::registerTritonPasses(); + mlir::triton::gpu::registerTritonGPUPasses(); + mlir::registerLinalgPasses(); + mlir::registerTritonNvidiaGPUPasses(); + mlir::test::registerTestAliasPass(); + mlir::test::registerTestAlignmentPass(); + mlir::test::registerTestAllocationPass(); + mlir::test::registerTestMembarPass(); + mlir::test::registerTestTritonAMDGPURangeAnalysis(); + mlir::triton::registerTritonToLinalgPass(); + mlir::triton::registerTritonToStructuredPass(); + mlir::triton::registerTritonArithToLinalgPasses(); + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::registerStructuredToMemrefPasses(); + mlir::triton::registerTritonToCoreDialectsPass(); + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::gpu::registerAllocateSharedMemoryPass(); + mlir::triton::gpu::registerTritonGPUAllocateWarpGroups(); + mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass(); + mlir::triton::registerConvertWarpSpecializeToLLVM(); + mlir::triton::registerConvertTritonGPUToLLVMPass(); + mlir::triton::registerConvertNVGPUToLLVMPass(); + mlir::registerLLVMDIScope(); + + // Core dialects to MK layer conversion passes + mlir::triton::registerTx81MemrefToLLVMPass(); + mlir::triton::registerLinalgToMKPass(); + mlir::triton::registerCoreDialectsToMKPass(); + + // TsingMicro specific conversion passes + mlir::triton::registerMKToTx81Pass(); + mlir::triton::registerTx81ToLLVMPass(); + mlir::triton::registerKernelArgBufferPass(); + + // TritonAMDGPUToLLVM passes + mlir::triton::registerConvertTritonAMDGPUToLLVM(); + mlir::triton::registerConvertBuiltinFuncToLLVM(); + mlir::triton::registerDecomposeUnsupportedAMDConversions(); + mlir::triton::registerOptimizeAMDLDSUsage(); + + // TritonAMDGPUTransforms passes + mlir::registerTritonAMDGPUAccelerateMatmul(); + mlir::registerTritonAMDGPUOptimizeEpilogue(); + mlir::registerTritonAMDGPUHoistLayoutConversions(); + mlir::registerTritonAMDGPUReorderInstructions(); + mlir::registerTritonAMDGPUBlockPingpong(); + 
mlir::registerTritonAMDGPUStreamPipeline(); + mlir::registerTritonAMDGPUCanonicalizePointers(); + mlir::registerTritonAMDGPUConvertToBufferOps(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + + // FIXME: May not need all of these + // mlir::registerAllDialects(registry); + // Register all external models. + // affine::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); + mlir::arith::registerBufferViewFlowOpInterfaceExternalModels(registry); + mlir::arith::registerShardingInterfaceExternalModels(registry); + mlir::arith::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); + mlir::builtin::registerCastOpInterfaceExternalModels(registry); + mlir::cf::registerBufferizableOpInterfaceExternalModels(registry); + mlir::cf::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::gpu::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::LLVM::registerInlinerInterface(registry); + mlir::NVVM::registerInlinerInterface(registry); + mlir::linalg::registerAllDialectInterfaceImplementations(registry); + mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + mlir::memref::registerAllocationOpInterfaceExternalModels(registry); + mlir::memref::registerBufferViewFlowOpInterfaceExternalModels(registry); + mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + mlir::memref::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::memref::registerMemorySlotExternalModels(registry); + mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::scf::registerBufferizableOpInterfaceExternalModels(registry); + mlir::scf::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::shape::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerFindPayloadReplacementOpInterfaceExternalModels( registry); + mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry); + mlir::tensor::registerSubsetOpInterfaceExternalModels(registry); + mlir::tensor::registerTilingInterfaceExternalModels(registry); + mlir::tensor::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::vector::registerBufferizableOpInterfaceExternalModels(registry); + mlir::vector::registerSubsetOpInterfaceExternalModels(registry); + mlir::vector::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); + // This is needed for the Bufferization pass (one-shot bufferization) + mlir::registerAllExtensions(registry); + mlir::mk::registerBufferizableOpInterfaceExternalModels(registry); + + registry.insert(); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-llvm-opt.cpp b/third_party/tsingmicro/bin/tsingmicro-llvm-opt.cpp new file mode 100644 index 000000000..1ec804cb5 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-llvm-opt.cpp @@ -0,0 +1,121 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir /// passes.
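+/// +/// Illustrative invocation, using only the flags defined in this file: +/// tsingmicro-llvm-opt -break-struct-phi-nodes input.ll -o output.ll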
+#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt OutputFilename("o", + cl::desc("Override output filename"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr Out; + // Default to standard output. 
+ if (OutputFilename.empty()) + OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + Out->keep(); + return 0; +} diff --git a/third_party/tsingmicro/bin/tsingmicro-lsp.cpp b/third_party/tsingmicro/bin/tsingmicro-lsp.cpp new file mode 100644 index 000000000..f95036dc6 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-lsp.cpp @@ -0,0 +1,10 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-opt.cpp b/third_party/tsingmicro/bin/tsingmicro-opt.cpp new file mode 100644 index 000000000..2d2570771 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-opt.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-reduce.cpp b/third_party/tsingmicro/bin/tsingmicro-reduce.cpp new file mode 100644 index 000000000..8235f8fc8 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-reduce.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-reduce/MlirReduceMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::mlirReduceMain(argc, argv, context)); +} diff --git a/third_party/tsingmicro/bin/tsingmicro-tensor-layout.cpp b/third_party/tsingmicro/bin/tsingmicro-tensor-layout.cpp new file mode 100644 index 000000000..cc121b3e1 --- /dev/null +++ b/third_party/tsingmicro/bin/tsingmicro-tensor-layout.cpp @@ -0,0 +1,232 @@ +#include "RegisterTritonDialects.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/AsmParser/AsmParserState.h" +#include "mlir/IR/MLIRContext.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace mlir; + +// A CLI tool to print the layout of a tensor. 
+// +// clang-format off +// Example usage: +// +// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view +// +// An input file usually looks like: +// ''' +// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +// ''' +// clang-format on + +//===--------------------------------------------------------------------===// +// CLI options +//===--------------------------------------------------------------------===// + +cl::OptionCategory PrinterCategory("Available Print Options", "Options for the tensor layout printing."); + +static cl::opt InputFile( "i", cl::desc("File that contains the tensor data layout attributes"), cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory)); + +static cl::opt OutputFile("o", cl::desc("Output file to write the layout into"), cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory)); + +static cl::opt DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"), cl::value_desc("layout-string"), cl::init(""), cl::cat(PrinterCategory)); + +static cl::list AliasName("alias-names", cl::desc("A list of alias names (separated by comma) of the " "layout attributes in the input file"), cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated, cl::ZeroOrMore, cl::cat(PrinterCategory)); + +static cl::opt UseHWPointOfView( "use-hw-view", llvm::cl::desc( "Print the layout in hardware point of view. This means the output is " "from the warp's perspective. Otherwise, the output is from the " "tensor's perspective (e.g., each element maps to xxx thread)."), cl::init(false), cl::cat(PrinterCategory)); + +static cl::opt TensorStr( "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"), cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory)); + +//===--------------------------------------------------------------------===// +// Helper functions +//===--------------------------------------------------------------------===// + +LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { + // DistributedEncodingTrait and SharedEncodingTrait implement the + // toLinearLayout interface.
+ mlir::Attribute layout = tensorType.getEncoding(); + if (isa(layout)) { + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); + } + + llvm::errs() << "Unsupported tensor layout attribute: " + << tensorType.getEncoding() << "\n"; + return failure(); +} + +LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, ArrayRef names, TensorType tensorTy, raw_string_ostream &ss) { + if (filename.empty()) + return success(); + + llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + ParserConfig config(context); + auto asmState = AsmParserState(); + + Block parsedIR; + if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) { + llvm::errs() << "Failed to parse the input file: " << filename << "\n"; + return failure(); + } + + auto printLambda = [&](StringRef name, mlir::Attribute attr) { + ss << "Print layout attribute: #" << name << " = " << attr << "\n"; + + auto rankedTensorTy = RankedTensorType::get( tensorTy.getShape(), tensorTy.getElementType(), attr); + + return layoutPrint(rankedTensorTy, ss); + }; + + if (names.empty()) + // If no alias name is given, we print all layout attributes in the file. + for (const auto &def : asmState.getAttributeAliasDefs()) { + if (failed(printLambda(def.name, def.value))) + return failure(); + } + else { + // Print the layout attributes with the given alias names. + for (const auto &alias : names) { + auto def = asmState.getAttributeAliasDef(alias); + if (!def) { + llvm::errs() << "Can't find the layout attribute: " << alias << "\n"; + return failure(); + } + + if (failed(printLambda(alias, def->value))) + return failure(); + + ss << "\n"; + } + } + + return success(); +} + +LogicalResult printLayoutFromString(MLIRContext *context, StringRef layoutAttrStr, TensorType tensorTy, raw_string_ostream &ss) { + if (layoutAttrStr.empty()) + return success(); + + mlir::Attribute layout = parseAttribute(layoutAttrStr, context); + if (!layout) { + llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n"; + return failure(); + } + + auto rankedTensorTy = RankedTensorType::get( tensorTy.getShape(), tensorTy.getElementType(), layout); + + ss << "Print layout attribute: " << layout << "\n"; + + return layoutPrint(rankedTensorTy, ss); +} + +//===--------------------------------------------------------------------===// +// Main entry point +//===--------------------------------------------------------------------===// + +int main(int argc, char **argv) { + cl::HideUnrelatedOptions(PrinterCategory); + cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n"); + + DialectRegistry registry; + registerTritonDialects(registry); + + MLIRContext ctx(registry); + ctx.loadAllAvailableDialects(); + + if (TensorStr.empty()) { + llvm::errs() << "Must specify the tensor type argument\n"; + return 1; + } + + mlir::Type parsedTy = parseType(TensorStr, &ctx); + if (!parsedTy) { + llvm::errs() << "Failed to parse the tensor type argument: " << TensorStr + << "\n"; + return 1; + } + + TensorType tensorType = dyn_cast(parsedTy); + if (!tensorType) { + llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n"; + return 1; + } + + std::string storage; + raw_string_ostream ss(storage); + + if
+//===--------------------------------------------------------------------===//
+// Main entry point
+//===--------------------------------------------------------------------===//
+
+int main(int argc, char **argv) {
+  cl::HideUnrelatedOptions(PrinterCategory);
+  cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");
+
+  DialectRegistry registry;
+  registerTritonDialects(registry);
+
+  MLIRContext ctx(registry);
+  ctx.loadAllAvailableDialects();
+
+  if (TensorStr.empty()) {
+    llvm::errs() << "Must specify the tensor type argument\n";
+    return 1;
+  }
+
+  mlir::Type parsedTy = parseType(TensorStr, &ctx);
+  if (!parsedTy) {
+    llvm::errs() << "Failed to parse the tensor type argument: " << TensorStr
+                 << "\n";
+    return 1;
+  }
+
+  TensorType tensorType = dyn_cast<TensorType>(parsedTy);
+  if (!tensorType) {
+    llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n";
+    return 1;
+  }
+
+  std::string storage;
+  raw_string_ostream ss(storage);
+
+  if (failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss)))
+    return 1;
+
+  if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss)))
+    return 1;
+
+  if (OutputFile.empty()) {
+    llvm::outs() << ss.str();
+  } else {
+    std::error_code ec;
+    llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text);
+    if (ec) {
+      llvm::errs() << "Error: " << ec.message() << " : unable to open "
+                   << OutputFile << " for output\n";
+      return 1;
+    }
+    outFs << ss.str();
+    outFs.close();
+  }
+
+  return 0;
+}
diff --git a/third_party/tsingmicro/crt/CMakeLists.txt b/third_party/tsingmicro/crt/CMakeLists.txt
index fa4c0bbb8..f294fab8c 100644
--- a/third_party/tsingmicro/crt/CMakeLists.txt
+++ b/third_party/tsingmicro/crt/CMakeLists.txt
@@ -1,5 +1,6 @@
 cmake_minimum_required(VERSION 3.18)
 
+set(TARGET Tx81)
 # Set TARGET from environment variable
 if(NOT DEFINED TARGET)
   if(DEFINED ENV{CRT_TARGET})
@@ -9,11 +10,11 @@ if(NOT DEFINED TARGET)
   endif()
 endif()
 
-if(NOT DEFINED LIB_C_ROOT)
-  if(DEFINED ENV{LIB_C_ROOT})
-    set(LIB_C_ROOT $ENV{LIB_C_ROOT})
+if(NOT DEFINED XUANTIE_NAME)
+  if(DEFINED ENV{XUANTIE_NAME})
+    set(XUANTIE_NAME $ENV{XUANTIE_NAME})
   else()
-    message(FATAL_ERROR "LIB_C_ROOT environment variable is not defined")
+    message(FATAL_ERROR "XUANTIE_NAME environment variable is not defined")
   endif()
 endif()
 
@@ -26,6 +27,14 @@ if(NOT DEFINED LLVM_SYSPATH)
   endif()
 endif()
 
+if(NOT DEFINED TX8_HOME)
+  if(DEFINED ENV{TX8_HOME})
+    set(TX8_HOME $ENV{TX8_HOME})
+  else()
+    message(FATAL_ERROR "TX8_HOME environment variable is not defined")
+  endif()
+endif()
+
 # Project name and version
 project(VendorRuntime LANGUAGES CXX C)
 
@@ -39,7 +48,8 @@ set(CMAKE_CXX_COMPILER ${LLVM_SYSPATH}/bin/clang++)
 # Define standard include directories
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include/${TARGET})
-include_directories(${LIB_C_ROOT}/include)
+include_directories(${TX8_HOME}/include)
+include_directories(${TX8_HOME}/${XUANTIE_NAME}/riscv64-unknown-elf/include)
 include_directories(${CMAKE_CURRENT_BINARY_DIR})
 
 # Set build type default
@@ -68,12 +78,6 @@ add_library(${VENDOR_RUNTIME_LIB} STATIC ${VENDOR_SOURCES})
 target_compile_options(${VENDOR_RUNTIME_LIB} PRIVATE ${RISCV_COMPILE_OPTIONS})
 target_link_options(${VENDOR_RUNTIME_LIB} PRIVATE --target=${RISCV_TRIPLE})
 
-# Set properties for the library
-set_target_properties(${VENDOR_RUNTIME_LIB} PROPERTIES
-  POSITION_INDEPENDENT_CODE ON
-  LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib
-)
-
 # Setup compiler and environment for RISC-V compilation
 if(CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
   # Use the existing Clang installation with target triple
@@ -99,15 +103,3 @@ else()
     COMMAND ${CMAKE_COMMAND} -E echo "Building ${VENDOR_RUNTIME_LIB} for RISC-V target"
   )
 endif()
-
-# Install targets
-install(TARGETS ${VENDOR_RUNTIME_LIB}
-  LIBRARY DESTINATION lib
-  ARCHIVE DESTINATION lib
-  RUNTIME DESTINATION bin
-)
-
-# Install headers (optional)
-file(GLOB_RECURSE VENDOR_HEADERS Target/lib/${TARGET}/*.h)
-install(FILES ${VENDOR_HEADERS} DESTINATION include/${TARGET})
-install(FILES Target/lib/${TARGET}/libkcorert.a DESTINATION lib/${TARGET})
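# Editor's note (assumed usage, not part of the patch): TX8_HOME and
# XUANTIE_NAME may be passed on the command line or via the environment,
# e.g. (paths are placeholders):
#
#   TX8_HOME=/path/to/tx8 XUANTIE_NAME=riscv-toolchain cmake -S . -B build
#
# Note that with the hardcoded set(TARGET Tx81) above, the CRT_TARGET
# fallback below it is effectively bypassed.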
diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h b/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h
deleted file mode 100644
index 306624447..000000000
--- a/third_party/tsingmicro/crt/include/Tx81/instr_adapter.h
+++ /dev/null
@@ -1,61 +0,0 @@
-#ifndef INSTR_ADAPTER_H
-#define INSTR_ADAPTER_H
-#include
-#include
-#include
-#include
-
-// #include "common_base.h"
-#include "instr_adapter_plat.h"
-#include "instr_def.h"
-
-#ifndef USING_RISCV
-#define __CHECK_INSTR__
-#endif
-// #define __PLAT_FREERTOS__
-// #define RECORD_INSTR_INVALID
-#define SPM_LOWER_BOUND 0
-#define SPM_UPPER_BOUND 0x2EFFFF
-#define DDR_LOWER_BOUND 0x280000000
-#define IS_WITHIN_SPM_BOUND(value)                                             \
-  (((value) >= SPM_LOWER_BOUND) && ((value) <= SPM_UPPER_BOUND))
-#define IS_WITHIN_DDR_BOUND(value) ((value) >= DDR_LOWER_BOUND)
-// times field (bits 0-7)
-#define TIMES_INVALID_OFFET 0
-// last_invalid_barrier_id field (bits 8-35)
-#define LAST_INVALID_BARRIER 8
-// first_invalid_barrier_id field (bits 36-63)
-#define FIRST_INVALID_BARRIER 36
-
-typedef struct InstrInvalidInfo {
-  volatile uint64_t ne_error_info;
-  volatile uint64_t ct_error_info;
-  volatile uint64_t td_error_info;
-  volatile uint64_t rdma_error_info;
-  volatile uint64_t wdma_error_info;
-} InstrInvalidInfo;
-
-/*
-  # 0-shape(nhwc)  # 1-wshape(Kx,Ky,f,c)  # 2-bias  # 3-stride(Kx,Ky,Sx,Sy)
-  # 4-pad(top,bottom,left,right)  # 5-dilation(0,0,dilation[0],dilation[1])
-*/
-/*=================================TDMA=================================*/
-
-/*=================================RDMA WDMA=================================*/
-
-/*=================================Scale=================================*/
-
-/*=================================run=================================*/
-uint32_t __execute_ne(TsmNeInstr *instr);
-uint32_t __execute_ct(TsmArithInstr *instr);
-uint32_t __execute_td(TsmDataMoveInstr *instr);
-uint32_t __execute_rdma(TsmRdmaInstr *instr);
-uint32_t __execute_wdma(TsmWdmaInstr *instr);
-void __execute_sc(SC_Param *instr);
-uint64_t TsmExecute(void *instr);
-
-/*=================================debug=================================*/
-void set_device_ddr_base(uint64_t base);
-uint64_t get_device_ddr_base();
-
-#endif /*INSTR_ADAPTER_H*/
diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h b/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h
deleted file mode 100644
index f01828512..000000000
--- a/third_party/tsingmicro/crt/include/Tx81/instr_adapter_plat.h
+++ /dev/null
@@ -1,873 +0,0 @@
-#ifndef INSTR_ADAPTER_PLAT_H
-#define INSTR_ADAPTER_PLAT_H
-
-// Define the following according to your device type.
-
-// ==================== if you run in Tx8-simulator
-// =====================================================
-
-#include
-
-// #include "oplib_depend_api.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-typedef struct Data_Shape {
-  uint16_t n;
-  uint16_t h;
-  uint16_t w;
-  uint16_t c;
-} Data_Shape;
-
-typedef struct St_Elem_Shape {
-  uint32_t elem_count;
-  uint32_t unit_elem_count;
-  uint32_t full_elem_count;
-  uint32_t full_unit_elem_count;
-} St_Elem_Shape;
-
-typedef struct St_StrideIteration {
-  uint32_t stride0;
-  uint32_t iteration0;
-  uint32_t stride1;
-  uint32_t iteration1;
-  uint32_t stride2;
-  uint32_t iteration2;
-} St_StrideIteration;
-
-/*=================================C class=================================*/
-typedef struct TsmConv {
-  void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape,
-                   Data_Format fmt);
-  void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape,
-                    Data_Format fmt);
-  void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr);
-  void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape,
-                    Data_Format fmt);
-  void (*SetOpType)(TsmNeInstr *instr, uint8_t type);
-  void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en,
- uint64_t scale_addr); //- negative axis - void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, - uint64_t scale_addr); //+ positive axis - void (*SetSparse)(TsmNeInstr *instr, uint8_t sparse_en, uint64_t sparse_addr); - void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, - Data_Format fmt); - void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, - uint32_t left, uint32_t right); - void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, - uint32_t left, uint32_t right); - void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, - uint32_t Sx, uint32_t Sy); - void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); - void (*EnableRelu)(TsmNeInstr *instr); - void (*EnableLeakyRelu)(TsmNeInstr *instr); - void (*DisableRelu)(TsmNeInstr *instr); - void (*DisableLeakyRelu)(TsmNeInstr *instr); - void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, - uint8_t zp_cur); - /* data */ -} TsmConv; - -typedef struct TsmDepthwiseConv { - void (*AddInput)(TsmNeInstr *instr, uint64_t X_addr, Data_Shape shape, - Data_Format fmt); - void (*AddWeight)(TsmNeInstr *instr, uint64_t W_addr, Data_Shape shape, - Data_Format fmt); - void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t bias_addr); - void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Shape shape, - Data_Format fmt); - void (*SetOpType)(TsmNeInstr *instr, uint8_t type); - void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, - uint64_t scale_addr); //- negative axis - void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, - uint64_t scale_addr); //+ positive axis - void (*SetSparse)(TsmNeInstr *instr, uint8_t sparse_en, uint64_t sparse_addr); - void (*SetPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, - uint32_t left, uint32_t right); - void (*SetUnPads)(TsmNeInstr *instr, uint32_t top, uint32_t bottom, - uint32_t left, uint32_t right); - void (*SetKernelStrides)(TsmNeInstr *instr, uint32_t Kx, uint32_t Ky, - uint32_t Sx, uint32_t Sy); - void (*SetDilations)(TsmNeInstr *instr, uint32_t d0, uint32_t d1); - void (*EnableRelu)(TsmNeInstr *instr); - void (*EnableLeakyRelu)(TsmNeInstr *instr); - void (*DisableRelu)(TsmNeInstr *instr); - void (*DisableLeakyRelu)(TsmNeInstr *instr); - void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_pre, - uint8_t zp_cur); - /* data */ -} TsmDepthwiseConv; -typedef struct TsmGemm { - void (*AddInput)(TsmNeInstr *instr, uint64_t L_addr, uint64_t R_addr, - Data_Format in_fmt); - void (*ConfigMKN)(TsmNeInstr *instr, uint32_t M, uint32_t K, uint32_t N); - void (*ConfigBatch)(TsmNeInstr *instr, uint32_t Left_batch, - uint32_t Right_batch); - void (*AddOutput)(TsmNeInstr *instr, uint64_t Out_addr, Data_Format Out_fmt); - void (*SetPsum)(TsmNeInstr *instr, uint8_t psum_en, uint64_t psum_addr, - Data_Format fmt); - void (*SetTransflag)(TsmNeInstr *instr, uint8_t L_trans, uint8_t R_trans); - void (*SetQuant)(TsmNeInstr *instr, uint8_t q0, uint8_t q1, uint8_t zp_left, - uint8_t zp_right); - void (*AddBias)(TsmNeInstr *instr, uint8_t bias_en, uint64_t addr); - void (*SetNegativeAxisScale)(TsmNeInstr *instr, uint8_t scale_en, - uint64_t addr); - void (*SetPositiveAxisScale)(TsmNeInstr *instr, uint8_t scale_en, - uint64_t addr); - void (*EnableRelu)(TsmNeInstr *instr); - void (*EnableLeakyRelu)(TsmNeInstr *instr); - void (*DisableRelu)(TsmNeInstr *instr); - void (*DisableLeakyRelu)(TsmNeInstr *instr); - - /* data */ -} TsmGemm; -typedef struct TsmRdma { - void 
(*AddSrcDst)(TsmRdmaInstr *instr, uint64_t src, uint64_t dst,
-                    Data_Format fmt);
-  void (*ConfigStrideIteration)(TsmRdmaInstr *instr, uint32_t elem_count,
-                                uint32_t stride0, uint32_t iteration0,
-                                uint32_t stride1, uint32_t iteration1,
-                                uint32_t stride2, uint32_t iteration2);
-  void (*Rdma1d)(
-      TsmRdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count,
-      uint32_t format); // only stride0 and iteration0 (innermost loop); copied once
-} TsmRdma;
-
-typedef struct TsmWdma {
-  void (*AddSrcDst)(TsmWdmaInstr *instr, uint64_t src, uint64_t dst,
-                    Data_Format fmt);
-  void (*ConfigStrideIteration)(TsmWdmaInstr *instr, uint32_t elem_count,
-                                uint32_t stride0, uint32_t iteration0,
-                                uint32_t stride1, uint32_t iteration1,
-                                uint32_t stride2, uint32_t iteration2);
-  void (*Wdma1d)(
-      TsmWdmaInstr *instr, uint64_t src, uint64_t dst, uint32_t elem_count,
-      uint32_t format); // only stride0 and iteration0 (innermost loop); copied once
-} TsmWdma;
-
-/*=================================CGRA=================================*/
-typedef struct TsmArith {
-  void (*AbsVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr,
-                uint32_t elem_count, Data_Format fmt);
-  void (*RecipVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr,
-                  uint32_t elem_count, Data_Format fmt);
-  void (*SquareVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr,
-                   uint32_t elem_count, Data_Format fmt);
-  void (*SqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr,
-                 uint32_t elem_count, Data_Format fmt);
-  void (*RsqrtVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr,
-                  uint32_t elem_count, Data_Format fmt);
-  void (*NegVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t dst_addr,
-                uint32_t elem_count, Data_Format fmt);
-  void (*MaxVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value,
-                uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved,
-                Data_Format fmt);
-  void (*MaxVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr,
-                uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved,
-                Data_Format fmt);
-  void (*MaxVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr,
-                 uint64_t dst_addr, uint32_t elem_count,
-                 uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt);
-  void (*MaxVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr,
-                     uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count,
-                     uint32_t unit_elem_count, uint32_t full_elem_num,
-                     uint32_t full_unit_elem_num, RND_MODE reserved,
-                     Data_Format fmt);
-  void (*MinVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr,
-                uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved,
-                Data_Format fmt);
-  void (*MinVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value,
-                uint64_t dst_addr, uint32_t elem_count, RND_MODE reserved,
-                Data_Format fmt);
-  void (*MinVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr,
-                 uint64_t dst_addr, uint32_t elem_count,
-                 uint32_t unit_elem_count, RND_MODE reserved, Data_Format fmt);
-  void (*MinVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr,
-                     uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count,
-                     uint32_t unit_elem_count, uint32_t full_elem_num,
-                     uint32_t full_unit_elem_num, RND_MODE reserved,
-                     Data_Format fmt);
-  void (*AddVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr,
-                uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode,
-                Data_Format fmt);
-  void (*AddVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value,
-                uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode,
-                Data_Format
fmt); - void (*AddVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*AddVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, - Data_Format fmt); - void (*SubVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, - Data_Format fmt); - void (*SubVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, - Data_Format fmt); - void (*SubVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*SubVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, - Data_Format fmt); - void (*MulVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, - Data_Format fmt); - void (*MulVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, - Data_Format fmt); - void (*MulVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*MulVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, - Data_Format fmt); - void (*DivVV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, - Data_Format fmt); - void (*DivVS)(TsmArithInstr *instr, uint64_t src0_addr, uint32_t const_value, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode, - Data_Format fmt); - void (*DivVuV)(TsmArithInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, RND_MODE rnd_mode, Data_Format fmt); - void (*DivVuVLoop)(TsmArithInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, RND_MODE rnd_mode, - Data_Format fmt); -} TsmArith; - -typedef struct TsmRelation { - void (*EqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt); - void (*BoolEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*EqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt); - void (*BoolEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*EqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, Data_Format fmt); - void (*BoolEqualVuV)(TsmRelationInstr 
*instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*EqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num, - Data_Format fmt); - void (*BoolEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num, - Data_Format fmt); - - void (*UnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt); - void (*BoolUnEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*UnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*BoolUnEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*UnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, Data_Format fmt); - void (*BoolUnEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*UnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num, - Data_Format fmt); - void (*BoolUnEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*GreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*BoolGreaterEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*GreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*BoolGreaterEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*GreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*BoolGreaterEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*GreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*BoolGreaterEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*GreaterVV)(TsmRelationInstr 
*instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt); - void (*BoolGreaterVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*GreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*BoolGreaterVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*GreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, Data_Format fmt); - void (*BoolGreaterVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*GreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num, - Data_Format fmt); - void (*BoolGreaterVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*LessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*BoolLessEqualVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*LessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*BoolLessEqualVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*LessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*BoolLessEqualVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*LessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num, - Data_Format fmt); - void (*BoolLessEqualVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*LessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt); - void (*BoolLessThenVV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*LessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*BoolLessThenVS)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t convst_value, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*LessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t 
dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*BoolLessThenVuV)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - Data_Format fmt); - void (*LessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num, - Data_Format fmt); - void (*BoolLessThenVuVLoop)(TsmRelationInstr *instr, uint64_t src0_addr, - uint64_t src1_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); -} TsmRelation; - -typedef struct TsmLogic { - void (*NotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*AndVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*OrVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*XorVV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*AndVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, Data_Format fmt); - void (*OrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, Data_Format fmt); - void (*XorVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, Data_Format fmt); - void (*AndVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, - uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*OrVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, uint64_t unit_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - void (*XorVuVLoop)(TsmArithInstr *instr, uint64_t src_addr, - uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count, uint32_t full_elem_num, - uint32_t full_unit_elem_num, Data_Format fmt); - - void (*BoolNotV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count); - void (*BoolAndV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*BoolOrV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*BoolXorV)(TsmLogicInstr *instr, uint64_t src0_addr, uint64_t src1_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*BoolAndVuV)(TsmLogicInstr *instr, uint64_t src_addr, - uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count); - void (*BoolOrVuV)(TsmLogicInstr *instr, uint64_t src_addr, uint64_t unit_addr, - uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count); - void (*BoolXorVuV)(TsmLogicInstr *instr, uint64_t src_addr, - uint64_t unit_addr, uint64_t dst_addr, uint32_t elem_count, - uint32_t unit_elem_count); - void (*BoolAndVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, - uint64_t unit_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t 
unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num); - void (*BoolOrVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, - uint64_t unit_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num); - void (*BoolXorVuVLoop)(TsmLogicInstr *instr, uint64_t src_addr, - uint64_t unit_addr, uint64_t dst_addr, - uint32_t elem_count, uint32_t unit_elem_count, - uint32_t full_elem_num, uint32_t full_unit_elem_num); -} TsmLogic; - -typedef struct TsmTranscendental { - void (*Log2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Ln)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Pow2)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Exp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Explp)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Sin)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Cos)(TsmArithInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); -} TsmTranscendental; - -typedef struct TsmActivation { - void (*Tanh)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Sigmoid)(TsmActivationInstr *instr, uint64_t src_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Relu)(TsmActivationInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); - void (*Satrelu)(TsmActivationInstr *instr, uint64_t src_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Leakyrelu)(TsmActivationInstr *instr, uint64_t src_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); - void (*Softplus)(TsmActivationInstr *instr, uint64_t src_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt); -} TsmActivation; - -typedef struct TsmReduce { - void (*ReduceSum)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t dim, Data_Shape shape, Data_Format fmt); - void (*ReduceAvg)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t dim, Data_Shape shape, Data_Format fmt); - void (*ReduceMax)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t dim, Data_Shape shape, Data_Format fmt); - void (*ReduceMin)(TsmReduceInstr *instr, uint64_t src_addr, uint64_t dst_addr, - uint32_t dim, Data_Shape shape, Data_Format fmt); -} TsmReduce; - -typedef struct TsmPool { - void (*MaxPool)(TsmPoolInstr *instr, uint64_t src0, Data_Shape src_shape, - uint64_t dst, Data_Shape pad, Data_Shape swr_shape, - Data_Format fmt); - void (*AvgPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, - uint64_t dst_addr, Data_Shape pad, Data_Shape swr_shape, - Data_Format fmt); - void (*SumPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, - uint64_t dst_addr, Data_Shape pad, Data_Shape swr_shape, - Data_Format fmt); - void (*MinPool)(TsmPoolInstr *instr, uint64_t src0_addr, Data_Shape src_shape, - uint64_t dst_addr, Data_Shape pad, Data_Shape swr_shape, - Data_Format fmt); - void (*IndexdMinPool)(TsmPoolInstr *instr, uint64_t src0_addr, - Data_Shape src_shape, uint64_t dst_arg, - uint64_t dst_idx, Data_Shape pad, 
Data_Shape swr_shape, - Data_Format fmt); - void (*IndexdMaxPool)(TsmPoolInstr *instr, uint64_t src0_addr, - Data_Shape src_shape, uint64_t dst_arg, - uint64_t dst_idx, Data_Shape pad, Data_Shape swr_shape, - Data_Format fmt); -} TsmPool; - -typedef struct TsmUnPool { - void (*Unpool)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, - uint64_t dst_addr, Data_Shape dst_shape, Data_Shape swr_shape, - Data_Format fmt); - void (*UnpoolAvg)(TsmUnPoolInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, Data_Shape dst_shape, - Data_Shape swr_shape, Data_Format fmt); - void (*UnpoolIdx)(TsmUnPoolInstr *instr, uint64_t src0_addr, uint32_t index, - uint64_t dst_addr, Data_Shape dst_shape, - Data_Shape swr_shape, Data_Format fmt); -} TsmUnPool; - -typedef struct TsmMaskDataMove { - void (*MaskMove)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, - uint32_t mask, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt); - void (*MaskGather)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, - uint32_t index, uint64_t dst_addr, uint32_t elem_count, - Data_Format fmt); - void (*MaskGather_bV)(TsmMaskDataMoveInstr *instr, uint64_t src0_addr, - uint32_t bitindex, uint64_t dst_addr, - uint32_t elem_count, Data_Format fmt); -} TsmMaskDataMove; - -typedef struct TsmConvert { - void (*INT8_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, - uint64_t dst_addr, - uint32_t elem_count); // Data_Format fmt is INT8 - void (*INT8_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, - uint64_t dst_addr, uint32_t elem_count); - void (*INT8_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, - uint64_t dst_addr, uint32_t elem_count); - void (*INT8_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, uint32_t zp, - uint64_t dst_addr, uint32_t elem_count); - void (*INT16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); // INT16 - void (*INT16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - - void (*INT32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*INT32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - - void (*BF16_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*BF16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*BF16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*BF16_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*BF16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*BF16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); - - void (*FP16_INT8)(TsmConvertInstr *instr, uint64_t 
src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP16_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP16_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP16_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, - RND_MODE rnd_mode); // rnd_mode 0~4 - void (*FP16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*FP16_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); - - void (*FP32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, - RND_MODE rnd_mode); // rnd_mode 0~4 - void (*FP32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*FP32_TF32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - - void (*TF32_INT8)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*TF32_INT16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*TF32_INT32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); - void (*TF32_FP16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); - void (*TF32_BF16)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, - RND_MODE rnd_mode); // rnd_mode 0~4 - void (*TF32_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count); -} TsmConvert; - -typedef struct TsmPeripheral { - void (*Count)(TsmPeripheralInstr *instr, uint64_t src0_addr, - uint64_t dst_addr, uint32_t elem_count, Data_Format fmt, - uint64_t *wb_data0, uint64_t *wb_data1); - void (*Memset)(TsmDataMoveInstr *instr, uint64_t dst_addr, uint32_t value, - uint32_t elem_count, St_StrideIteration *si, - Data_Format fmt); // si.stride is byte size. 
but elem_count is
-                 // an element count only, not bytes
-  void (*Bit2Fp)(TsmPeripheralInstr *instr, uint64_t src0_addr,
-                 uint64_t dst_addr, uint32_t elem_count, Data_Format fmt);
-  void (*ArgMax)(TsmPeripheralInstr *instr, uint64_t src0_addr,
-                 uint32_t elem_count, Data_Format fmt);
-  void (*ArgMin)(TsmPeripheralInstr *instr, uint64_t src0_addr,
-                 uint32_t elem_count, Data_Format fmt);
-  void (*Bilinear)(TsmPeripheralInstr *instr, uint64_t src0_addr,
-                   uint64_t dst0_addr, Data_Shape src_shape,
-                   Data_Shape dst_shape, int32_t scale_w, int32_t scale_h,
-                   Data_Format fmt);
-  void (*Lut16)(TsmPeripheralInstr *instr, uint64_t src1_addr,
-                uint64_t dst0_addr, uint64_t lut16_addr,
-                uint32_t src_elem_count, uint32_t lut_elem_count);
-  void (*Lut32)(TsmPeripheralInstr *instr, uint64_t src1_addr,
-                uint64_t dst0_addr, uint64_t lut32_addr,
-                uint32_t src_elem_count, uint32_t lut_elem_count);
-  void (*RandGen)(TsmPeripheralInstr *instr, uint64_t src0_addr,
-                  uint64_t src1_addr, uint64_t dst_addr, uint64_t dst1_addr,
-                  uint64_t dst2_addr, uint32_t src_elem_num, Data_Format fmt);
-  void (*Factorize)(TsmPeripheralInstr *instr, uint64_t src0_addr,
-                    uint64_t dst_addr, uint64_t dst1_addr, uint64_t dst2_addr,
-                    uint32_t src_elem_num);
-  void (*ElemMask)(TsmPeripheralInstr *instr, uint64_t src0_addr,
-                   uint32_t scale, uint64_t dst_addr, uint32_t src_elem_num,
-                   Data_Format fmt, uint32_t prob, RND_MODE rnd_mode);
-} TsmPeripheral;
-
-typedef struct TsmDataMove {
-  void (*Mirror)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                 Data_Shape src_shape, uint64_t dst_addr, Data_Shape dst_shape,
-                 Data_Format fmt);
-  void (*Transpose)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                    Data_Shape src_shape, uint64_t dst_addr,
-                    Data_Shape dst_shape, Data_Format fmt);
-  void (*Rotate90)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                   Data_Shape src_shape, uint64_t dst_addr,
-                   Data_Shape dst_shape, Data_Format fmt);
-  void (*Rotate180)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                    Data_Shape src_shape, uint64_t dst_addr,
-                    Data_Shape dst_shape, Data_Format fmt);
-  void (*Rotate270)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                    Data_Shape src_shape, uint64_t dst_addr,
-                    Data_Shape dst_shape, Data_Format fmt);
-  void (*Nchw2nhwc)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                    Data_Shape src_shape, uint64_t dst_addr,
-                    Data_Shape dst_shape, Data_Format fmt);
-  void (*Nhwc2nchw)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                    Data_Shape src_shape, uint64_t dst_addr,
-                    Data_Shape dst_shape, Data_Format fmt);
-  void (*Concat)(TsmMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape0,
-                 uint64_t src1_addr, Data_Shape src_shape1, uint64_t dst_addr,
-                 Data_Shape dst_shape, uint32_t dims, Data_Format fmt);
-  void (*Pad)(TsmDataMoveInstr *instr, uint64_t src0_addr, Data_Shape src_shape,
-              uint64_t dst_addr, Data_Shape dst_shape, Data_Shape pad,
-              Data_Format fmt);
-  void (*Img2col)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                  Data_Shape src_shape, uint64_t dst_addr, Data_Shape dst_shape,
-                  uint64_t src_elem_num, uint64_t dst_elem_num, Data_Shape swr,
-                  Data_Shape pdr, Data_Format fmt);
-  void (*TensorNom)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                    Data_Shape src_shape, uint64_t dst_addr,
-                    Data_Shape dst_shape, Data_Format fmt);
-  void (*GatherScatter)(TsmDataMoveInstr *instr, uint64_t src0_addr,
-                        uint64_t dst_addr, uint32_t size,
-                        St_StrideIteration *src_si, St_StrideIteration *dst_si);
-} TsmDataMove;
-
-TsmConv *TsmNewConv();
-TsmDepthwiseConv *TsmNewDepthwiseConv();
-TsmGemm *TsmNewGemm();
-TsmRdma *TsmNewRdma();
-TsmWdma
*TsmNewWdma(); -TsmArith *TsmNewArith(); -TsmRelation *TsmNewRelation(); -TsmLogic *TsmNewLogic(); -TsmTranscendental *TsmNewTranscendental(); -TsmActivation *TsmNewActivation(); -TsmReduce *TsmNewReduce(); -TsmPool *TsmNewPool(); -TsmUnPool *TsmNewUnPool(); -TsmMaskDataMove *TsmNewMaskDataMove(); -TsmConvert *TsmNewConvert(); -TsmPeripheral *TsmNewPeripheral(); -TsmDataMove *TsmNewDataMove(); - -void TsmDeleteConv(TsmConv *obj); -void TsmDeleteDepthwiseConv(TsmDepthwiseConv *obj); -void TsmDeleteGemm(TsmGemm *obj); -void TsmDeleteRdma(TsmRdma *obj); -void TsmDeleteWdma(TsmWdma *obj); -void TsmDeleteArith(TsmArith *obj); -void TsmDeleteRelation(TsmRelation *obj); -void TsmDeleteLogic(TsmLogic *obj); -void TsmDeleteTranscendental(TsmTranscendental *obj); -void TsmDeleteActivation(TsmActivation *obj); -void TsmDeleteReduce(TsmReduce *obj); -void TsmDeletePool(TsmPool *obj); -void TsmDeleteUnPool(TsmUnPool *obj); -void TsmDeleteMaskDataMove(TsmMaskDataMove *obj); -void TsmDeleteConvert(TsmConvert *obj); -void TsmDeletePeripheral(TsmPeripheral *obj); -void TsmDeleteDataMove(TsmDataMove *obj); -/*=================================STREAM=================================*/ -typedef struct TsmStream { - uint32_t (*OnlineStream)(uint32_t core_id_this, uint32_t tile_id, - uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, - uint64_t stream_id, uint64_t stream_addr); - uint32_t (*OfflineStream)(uint32_t core_id_this, uint32_t tile_id, - uint32_t core_id, uint32_t channel_id, - uint32_t remote, uint32_t stream_type, - uint64_t stream_id, uint64_t stream_addr); - uint32_t (*WaitStream)(uint32_t core_id_this, uint32_t tile_id, - uint32_t core_id, uint32_t channel_id, uint32_t remote, - uint32_t stream_type, uint64_t stream_id, - uint64_t stream_addr); - uint32_t (*ReqStream)(uint32_t core_id_this, uint32_t tile_id, - uint32_t core_id, uint32_t channel_id, uint32_t remote, - uint32_t stream_type, uint64_t stream_id, - uint64_t stream_addr); - uint32_t (*PushStream)(uint32_t core_id_this, uint32_t tile_id, - uint32_t core_id, uint32_t channel_id, uint32_t remote, - uint32_t stream_type, uint64_t stream_id, - uint64_t stream_addr); - uint32_t (*PopStream)(uint32_t core_id_this, uint32_t tile_id, - uint32_t core_id, uint32_t channel_id, uint32_t remote, - uint32_t stream_type, uint64_t stream_id, - uint64_t stream_addr); - uint8_t (*wait_finish)(); -} TsmStream; -TsmStream *TsmNewStream(); -void TsmDeleteStream(TsmStream *obj); -/*=================================CSR=====================================*/ -uint8_t TsmWaitfinish(); -uint8_t TsmGetCsrTaskstatus(); -uint8_t TsmGetCsrIbcounter(); -uint8_t TsmGetCsrTaskstatus_bywork(size_t workerid); -uint8_t TsmWaitfinish_bywork(size_t workerid); -/*=================================CSR END=================================*/ -#ifdef __cplusplus -} -#endif - -// ==================== if you will use Tx8-Oplib -// ====================================================== #define LOG_PRINT(...) -// #define LOG_ERR(fmt, args...) -// #define TSM_FREE free -// #define TSM_MALLOC malloc -// extern void setreg(int index, uint64_t value); -// extern uint64_t getreg(int index); - -// ==================== if you run in SOC-freerots/zebu -// ================================================ #include "rce_log.h" -// #include "csi_kernel.h" -// #include "rce_pal.h" -// #define LOG_PRINT(fmt, args...) vdk_printf(fmt, ##args) -// #define LOG_ERR(fmt, args...) 
vdk_printf(fmt, ##args) -// #define TSM_FREE(target) csi_kernel_free(2, target, NULL) -// #define TSM_MALLOC(size) csi_kernel_malloc(2, size, NULL) -// #define NCC_ADDR 0x01000000 -// #define setreg(ADDR, VALUE) -// do { -// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X, -// value=0x%lX(%lu).\n", ADDR, VALUE, VALUE); -// *((volatile uint64_t *)(ADDR + NCC_ADDR)) = VALUE; -// } while (0) - -// ==================== if you run in kernel-rt(use tile-sim) -// ========================================== #define LOG_PRINT(fmt, args...) -// printf(fmt, ##args) #define LOG_ERR(fmt, args...) printf(fmt, ##args) #define -// TSM_FREE free #define TSM_MALLOC malloc #include "rce_pal_port.h" void -// setreg(int index, uint64_t value) -// { -// LOG_PRINT("setreg param: GR: index=0x%X, value=0x%lX(%lu).\n", index, -// value, value); rce_tx_pal_setreg(index, value); -// } - -// ====================if you run in kernel-rt(use riscv) -// =============================================== #define LOG_PRINT(...) -// #define LOG_ERR(fmt, args...) -// #define TSM_FREE free -// #define TSM_MALLOC malloc -// #define NCC_ADDR 0x01000000 -// #define setreg(ADDR, VALUE) -// do { -// LOG_PRINT("(FREERT)setreg params: GR: index=0x%X, -// value=0x%lX(%lu).\n", ADDR, VALUE, VALUE); -// *((volatile uint64_t *)(ADDR + NCC_ADDR)) = VALUE; -// } while (0) - -// ====================if you do not need Log -// ========================================================== -// #define LOG_PRINT(...) -// #define LOG_ERR(fmt, args...) -#endif diff --git a/third_party/tsingmicro/crt/include/Tx81/instr_def.h b/third_party/tsingmicro/crt/include/Tx81/instr_def.h deleted file mode 100644 index 2d73bf8c0..000000000 --- a/third_party/tsingmicro/crt/include/Tx81/instr_def.h +++ /dev/null @@ -1,949 +0,0 @@ -#ifndef _RCE_INSTR_DEF_H_ -#define _RCE_INSTR_DEF_H_ -#include - -#define UN_USED 0 - -// CT -#define GR_CT_CONTROL_ADDR 0x0000 -#define GR_CT_SRC0_ADDR 0x0008 * 2 -#define GR_CT_SRC1_ADDR 0x0010 * 2 -#define GR_CT_DST0_ADDR 0x0018 * 2 -#define GR_CT_DST1_ADDR 0x0020 * 2 -#define GR_CT_DST2_ADDR 0x0028 * 2 -#define GR_CT_DIMS_ADDR 0x0030 * 2 -#define GR_CT_SRC0_TFR_ADDR 0x0038 * 2 -#define GR_CT_DST_TFR_ADDR 0x0040 * 2 -#define GR_CT_PDR_ADDR 0x0048 * 2 -#define GR_CT_SWR_ADDR 0x0050 * 2 -#define GR_CT_ELEM_COUNT_ADDR 0x0058 * 2 -#define GR_CT_UNIT_ELEM_COUNT_ADDR 0x0060 * 2 -#define GR_CT_INT8_SCALE_VAL0_ADDR 0x0068 * 2 -#define GR_CT_INT8_SCALE_VECTOR_ADDR 0x0068 * 2 -#define GR_CT_INT8_SCALE_VAL1_ADDR 0x0070 * 2 -#define GR_CT_INT8_QUANT_ADDR 0x0078 * 2 -#define GR_CT_INT8_BN_ZP_ADDR 0x0080 * 2 -#define GR_CT_FULL_ELEM_COUNT_ADDR 0x0088 * 2 -#define GR_CT_FULL_UNIT_ELEM_COUNT_ADDR 0x0090 * 2 -#define GR_CT_WB_DATA0_ADDR 0x0098 * 2 -#define GR_CT_WB_DATA1_ADDR 0x00A0 * 2 -#define GR_CT_SRC0_END_ADDR 0x00A8 * 2 -#define GR_CT_SRC1_END_ADDR 0x00B0 * 2 -#define GR_CT_DST0_END_ADDR 0x00B8 * 2 -#define GR_CT_DST1_END_ADDR 0x00C0 * 2 -#define GR_CT_DST2_END_ADDR 0x00C8 * 2 - -// NE -#define GR_NE_CONTROL_ADDR 0x0100 * 2 -#define GR_NE_SRC_A_ADDR 0x0108 * 2 -#define GR_NE_SRC_W_ADDR 0x0110 * 2 -#define GR_NE_PSUM_ADDR 0x0118 * 2 -#define GR_NE_BIAS_ADDR 0x0120 * 2 -#define GR_NE_SCALE_P_ADDR 0x0128 * 2 -#define GR_NE_SCALE_N_ADDR 0x0130 * 2 -#define GR_NE_OUT_ADDR 0x0138 * 2 -#define GR_NE_SRC0_TFR_ADDR 0x0140 * 2 -#define GR_NE_SRC1_OUT_TFR_ADDR 0x0148 * 2 -#define GR_NE_PDR_ADDR 0x0150 * 2 -#define GR_NE_UNPDR_ADDR 0x0158 * 2 -#define GR_NE_SWR_ADDR 0x0160 * 2 -#define GR_NE_DILATION_ADDR 0x0168 * 2 -#define GR_NE_GEMM_LB_ADDR 0x0170 * 2 
-#define GR_NE_GEMM_RB_ADDR 0x0178 * 2 -#define GR_NE_GEMM_N_ADDR 0x0180 * 2 -#define GR_NE_GEMM_M_ADDR 0x0188 * 2 -#define GR_NE_GEMM_K_ADDR 0x0190 * 2 -#define GR_NE_GEMM_L_TRS_ADDR 0x0198 * 2 -#define GR_NE_GEMM_R_TRS_ADDR 0x01A0 * 2 -#define GR_NE_QUANT_ADDR 0x01A8 * 2 -#define GR_NE_SPARSE_INDEX_ADDR 0x01B0 * 2 -#define GR_NE_SRCA_END 0x370 -#define GR_NE_SRCW_END 0x380 -#define GR_NE_PSUM_END 0x390 -#define GR_NE_BIAS_END 0x3A0 -#define GR_NE_SCALE_P_END 0x3B0 -#define GR_NE_SCALE_N_END 0x3C0 -#define GR_NE_OUT_END 0x3D0 -#define GR_NE_SPARSE_INDEX_END 0x3E0 - -// LSU RDMA -#define GR_RD_CONTROL_ADDR 0x400 -#define GR_RD_SRC_ADDR 0x410 -#define GR_RD_DST_ADDR 0x420 -#define GR_RD_STRIDE_ITERA0_ADDR 0x430 -#define GR_RD_STRIDE_ITERA1_ADDR 0x440 -#define GR_RD_STRIDE_ITERA2_ADDR 0x450 -#define GR_RD_ELEM_COUNT_ADDR 0x460 -#define GR_RD_FORMAT_ADDR 0x470 -#define GR_RD_SRC_END 0x480 -#define GR_RD_DST_END 0x490 -// LSU WDMA -#define GR_WD_CONTROL_ADDR 0x4A0 -#define GR_WD_SRC_ADDR 0x4B0 -#define GR_WD_DST_ADDR 0x4C0 -#define GR_WD_STRIDE_ITERA0_ADDR 0x4D0 -#define GR_WD_STRIDE_ITERA1_ADDR 0x4E0 -#define GR_WD_STRIDE_ITERA2_ADDR 0x4F0 -#define GR_WD_ELEM_COUNT_ADDR 0x500 -#define GR_WD_FORMAT_ADDR 0x510 -#define GR_WD_SRC_END 0x520 -#define GR_WD_DST_END 0x530 - -// TMDA -#define GR_TD_CONTROL_ADDR 0x02A0 * 2 -#define GR_TD_SRC0_ADDR 0x02A8 * 2 -#define GR_TD_SRC1_ADDR 0x02B0 * 2 -#define GR_TD_DST_ADDR 0x02B8 * 2 -#define GR_TD_DIMS_ADDR 0x02C0 * 2 -#define GR_TD_SRC0_TFR_ADDR 0x02C8 * 2 -#define GR_TD_DST_TFR_ADDR 0x02D0 * 2 -#define GR_TD_PDR_ADDR 0x02D8 * 2 -#define GR_TD_SWR_ADDR 0x02E0 * 2 -#define GR_TD_ELEM_COUNT_ADDR 0x02E8 * 2 -#define GR_TD_SRC_STRIDE_ITERA0_ADDR 0x02F0 * 2 -#define GR_TD_SRC_STRIDE_ITERA1_ADDR 0x02F8 * 2 -#define GR_TD_SRC_STRIDE_ITERA2_ADDR 0x0300 * 2 -#define GR_TD_DST_STRIDE_ITERA0_ADDR 0x0308 * 2 -#define GR_TD_DST_STRIDE_ITERA1_ADDR 0x0310 * 2 -#define GR_TD_DST_STRIDE_ITERA2_ADDR 0x0318 * 2 -#define GR_TD_SRC0_END 0x640 -#define GR_TD_SRC1_END 0x650 -#define GR_TD_DST_END 0x660 -// SCALAR -#define GR_SCALAR_CONTROL_ADDR 0x6A0 -#define GR_SCALAR_SRC_ADDR 0x6B0 -#define GR_SCALAR_DST_ADDR 0x6C0 -// CSR -#define GR_CSR_CONTROL_ADDR 0x740 -#define GR_CSR_EXCEPTION_ADDR 0x750 -#define GR_CSR_PRIORITY_ADDR 0x760 -#define GR_CSR_EXCEPTION_MASK_ADDR 0x770 -#define GR_CSR_SERIAL_MODE_ADDR 0x780 - -// CSR end -// DTE start -#define GR_DTE_SRC_ADDR_LO 0x0 -#define GR_DTE_SRC_ADDR_HI 0x4 -#define GR_DTE_DST_ADDR_LO_0 0x8 -#define GR_DTE_DST_ADDR_HI_0 0xC -#define GR_DTE_USER_ID_0 0x10 -#define GR_DTE_MODE 0x14 -#define GR_DTE_LENGTH 0x18 -#define GR_DTE_DEST_NUM 0x1C -#define GR_DTE_STRIDE0 0x20 -#define GR_DTE_ITERATION0 0x24 -#define GR_DTE_STRIDE1 0x28 -#define GR_DTE_ITERATION1 0x2C -#define GR_DTE_STRIDE2 0x30 -#define GR_DTE_ITERATION2 0x34 -#define GR_DTE_CMD_VALID 0x38 -#define GR_DTE_DMA_STATUS 0x40 -#define GR_DTE_DST_ADDR_LO_1 0x50 -#define GR_DTE_DST_ADDR_HI_1 0x54 -#define GR_DTE_DST_ADDR_LO_2 0x58 -#define GR_DTE_DST_ADDR_HI_2 0x5C -#define GR_DTE_DST_ADDR_LO_3 0x60 -#define GR_DTE_DST_ADDR_HI_3 0x64 -#define GR_DTE_DST_ADDR_LO_4 0x68 -#define GR_DTE_DST_ADDR_HI_4 0x6C -#define GR_DTE_DST_ADDR_LO_5 0x70 -#define GR_DTE_DST_ADDR_HI_5 0x74 -#define GR_DTE_DST_ADDR_LO_6 0x78 -#define GR_DTE_DST_ADDR_HI_6 0x7C -#define GR_DTE_DST_ADDR_LO_7 0x80 -#define GR_DTE_DST_ADDR_HI_7 0x84 -#define GR_DTE_DST_ADDR_LO_8 0x88 -#define GR_DTE_DST_ADDR_HI_8 0x8C -#define GR_DTE_DST_ADDR_LO_9 0x90 -#define GR_DTE_DST_ADDR_HI_9 0x94 -#define GR_DTE_DST_ADDR_LO_10 0x98 
-#define GR_DTE_DST_ADDR_HI_10 0x9C -#define GR_DTE_DST_ADDR_LO_11 0xA0 -#define GR_DTE_DST_ADDR_HI_11 0xA4 -#define GR_DTE_DST_ADDR_LO_12 0xA8 -#define GR_DTE_DST_ADDR_HI_12 0xAC -#define GR_DTE_DST_ADDR_LO_13 0xB0 -#define GR_DTE_DST_ADDR_HI_13 0xB4 -#define GR_DTE_DST_ADDR_LO_14 0xB8 -#define GR_DTE_DST_ADDR_HI_14 0xBC -#define GR_DTE_DST_ADDR_LO_15 0xC0 -#define GR_DTE_DST_ADDR_HI_15 0xC4 -#define GR_DTE_DST_ADDR_LO_16 0xC8 -#define GR_DTE_DST_ADDR_HI_16 0xCC -#define GR_DTE_DST_ADDR_LO_17 0xD0 -#define GR_DTE_DST_ADDR_HI_17 0xD4 -#define GR_DTE_DST_ADDR_LO_18 0xD8 -#define GR_DTE_DST_ADDR_HI_18 0xD4 -#define GR_DTE_DST_ADDR_LO_19 0xE0 -#define GR_DTE_DST_ADDR_HI_19 0xE4 -#define GR_DTE_DST_ADDR_LO_20 0xE8 -#define GR_DTE_DST_ADDR_HI_20 0xEC -#define GR_DTE_DST_ADDR_LO_21 0xF0 -#define GR_DTE_DST_ADDR_HI_21 0xF4 -#define GR_DTE_DST_ADDR_LO_22 0xF8 -#define GR_DTE_DST_ADDR_HI_22 0xFC -#define GR_DTE_DST_ADDR_LO_23 0x100 -#define GR_DTE_DST_ADDR_HI_23 0x104 -#define GR_DTE_DST_ADDR_LO_24 0x108 -#define GR_DTE_DST_ADDR_HI_24 0x10C -#define GR_DTE_DST_ADDR_LO_25 0x110 -#define GR_DTE_DST_ADDR_HI_25 0x114 -#define GR_DTE_DST_ADDR_LO_26 0x118 -#define GR_DTE_DST_ADDR_HI_26 0x11C -#define GR_DTE_DST_ADDR_LO_27 0x120 -#define GR_DTE_DST_ADDR_HI_27 0x124 -#define GR_DTE_DST_ADDR_LO_28 0x128 -#define GR_DTE_DST_ADDR_HI_28 0x12C -#define GR_DTE_DST_ADDR_LO_29 0x130 -#define GR_DTE_DST_ADDR_HI_29 0x134 -#define GR_DTE_DST_ADDR_LO_30 0x138 -#define GR_DTE_DST_ADDR_HI_30 0x13C -#define GR_DTE_DST_ADDR_LO_31 0x140 -#define GR_DTE_DST_ADDR_HI_31 0x144 - -#define GR_DTE_USER_ID_1 0x148 -#define GR_DTE_USER_ID_2 0x14C -#define GR_DTE_USER_ID_3 0x150 -#define GR_DTE_USER_ID_4 0x154 -#define GR_DTE_USER_ID_5 0x158 -#define GR_DTE_USER_ID_6 0x15C -#define GR_DTE_USER_ID_7 0x160 -#define GR_DTE_USER_ID_8 0x164 -#define GR_DTE_USER_ID_9 0x168 -#define GR_DTE_USER_ID_10 0x16C -#define GR_DTE_USER_ID_11 0x170 -#define GR_DTE_USER_ID_12 0x174 -#define GR_DTE_USER_ID_13 0x178 -#define GR_DTE_USER_ID_14 0x17C -#define GR_DTE_USER_ID_15 0x180 -#define GR_DTE_USER_ID_16 0x184 -#define GR_DTE_USER_ID_17 0x188 -#define GR_DTE_USER_ID_18 0x18C -#define GR_DTE_USER_ID_19 0x190 -#define GR_DTE_USER_ID_20 0x194 -#define GR_DTE_USER_ID_21 0x198 -#define GR_DTE_USER_ID_22 0x19C -#define GR_DTE_USER_ID_23 0x1A0 -#define GR_DTE_USER_ID_24 0x1A4 -#define GR_DTE_USER_ID_25 0x1A8 -#define GR_DTE_USER_ID_26 0x1AC -#define GR_DTE_USER_ID_27 0x1B0 -#define GR_DTE_USER_ID_28 0x1B4 -#define GR_DTE_USER_ID_29 0x1B8 -#define GR_DTE_USER_ID_30 0x1BC -#define GR_DTE_USER_ID_31 0x1C0 - -#define GR_DTE_MAX_AXI_NUM 0x1D0 -#define GR_DTE_MEM_BURSTLEN 0x1D4 -#define GR_DTE_MEM_BACKPRESSURE 0x1D8 -#define GR_DTE_MEM_READ_TURBO 0x1DC -// DTE end - -// SCONFIG begin -#define GR_SCONFIG_GPR0 0x600 - -// SCONFIG end - -// NCC PMU begin -#define GR_PMU_EN 0x0 -#define GR_PMU_CLR 0x4 -#define GR_PMU_STATISTICS_WINDOW 0x8 -#define GR_PMU_CT_INST_NUMS 0x10 -#define GR_PMU_NE_INST_NUMS 0x14 -#define GR_PMU_RDMA_INST_NUMS 0x18 -#define GR_PMU_WDMA_INST_NUMS 0x1C -#define GR_PMU_TDMA_INST_NUMS 0x20 -#define GR_PMU_SCALAR_INST_NUMS 0x24 -#define GR_PMU_CT_BLOCKING_TIME 0x28 -#define GR_PMU_NE_BLOCKING_TIME 0x2C -#define GR_PMU_RDMA_BLOCKING_TIME 0x30 -#define GR_PMU_WDMA_BLOCKING_TIME 0x34 -#define GR_PMU_TDMA_BLOCKING_TIME 0x38 -#define GR_PMU_SCALAR_BLOCKING_TIME 0x3c - -#define GR_PMU_FU_EXE_TIME 0x13c -#define GR_PMU_CT_EXE_TIME 0x144 -#define GR_PMU_NE_EXE_TIME 0x14c -#define GR_PMU_RDMA_EXE_TIME 0x154 -#define GR_PMU_WDMA_EXE_TIME 0x15c -#define 
GR_PMU_TDMA_EXE_TIME 0x164
-#define GR_PMU_SCALAR_EXE_TIME 0x16c
-// NCC PMU end
-
-// DTE PMU begin
-#define DTE_PMU_EN 0x800
-#define DTE_PMU_CLR 0x804
-
-#define DTE_PMU_CH0_L_EXE_TIME 0x858
-#define DTE_PMU_CH0_H_EXE_TIME 0x85C
-#define DTE_PMU_CH1_L_EXE_TIME 0x860
-#define DTE_PMU_CH1_H_EXE_TIME 0x864
-// DTE PMU end
-
-typedef enum OP_INSTR_TYPE {
-  I_CGRA,
-  I_NEUR,
-  I_RDMA,
-  I_WDMA,
-  I_TDMA,
-  I_SCALAR,
-  I_DTE,
-  I_CSR,
-} OP_INSTR_TYPE;
-// instr_type = I_CGRA | I_WORKER1
-typedef enum OP_INSTR_WORKER {
-  I_WORKER0 = 0x0000,
-  I_WORKER1 = 0x0100,
-  I_WORKER2 = 0x0200,
-} OP_INSTR_WORKER;
-
-typedef enum RND_MODE {
-  RND_NEAREST_EVEN,
-  RND_ZERO,
-  RND_POS_INF,
-  RND_NEG_INF,
-  RND_STOCHASTIC
-} RND_MODE;
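/* Editor's sketch (assumed usage, not part of the patch): per the
 * "instr_type = I_CGRA | I_WORKER1" note above, an instruction's inter_type
 * combines an OP_INSTR_TYPE with an OP_INSTR_WORKER, e.g.:
 *
 *   CT_Param p = {0};                     // CT_Param is defined below
 *   p.inter_type = I_CGRA | I_WORKER1;    // CGRA op issued on worker 1
 *   p.ctrl.rnd_mode = RND_NEAREST_EVEN;   // rounding mode from RND_MODE
 *   p.ctrl.cmd_valid = 1;                 // self-clearing kick-off bit
 */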
-                     // When the function has two return values the second one
-                     // is written to this register; with only one return value
-                     // this register is unused
-  uint32_t src0_end; // SPM address (end address of src0); xxx_end = src/dst +
-                     // the storage extent of that operand in SPM
-  uint32_t src1_end;
-  uint32_t dst0_end;
-  uint32_t dst1_end;
-  uint32_t dst2_end;
-  uint8_t dims; // 000:C 001:W 010:H 011:N 100:HW 101:HWC
-} Ncc_CT_GR_Param_Regs;
-
-typedef struct CT_Param {
-  uint32_t inter_type;
-  Ncc_CT_GR_Ctl_Regs ctrl;
-  Ncc_CT_GR_Param_Regs param;
-} CT_Param;
-
-#define TsmArithInstr CT_Param
-#define TsmPoolInstr CT_Param
-#define TsmMoveInstr CT_Param
-#define TsmUnPoolInstr CT_Param
-#define TsmMaskDataMoveInstr CT_Param
-#define TsmConvertInstr CT_Param
-#define TsmPeripheralInstr CT_Param
-#define TsmRelationInstr CT_Param
-#define TsmLogicInstr CT_Param
-#define TsmTranscendentalInstr CT_Param
-#define TsmActivationInstr CT_Param
-#define TsmReduceInstr CT_Param
-
-typedef struct Ncc_NE_GR_Ctl_Regs {
-  uint8_t sparse_en;
-  uint8_t cmd_valid;
-  uint8_t inpsum_format;
-  uint8_t output_format;
-  uint8_t input_format;
-  uint8_t inpsum_en;
-  uint8_t lrelu_en; // either relu or lrelu
-  uint8_t relu_en;  // When relu_en/lrelu_en/bias_en/scale_en are all 0, the
-                    // output is psum
-  uint8_t scale_en;
-  uint8_t bias_en;
-  uint8_t dilation_conv; // valid for conv / backward conv
-  uint8_t type; // 0:conv 1:depthwise conv 2:backward conv 3:gemm
-} Ncc_NE_GR_Ctl_Regs;
-
-typedef struct Ncc_NE_GR_Param_Regs {
-  uint32_t src_a;   // SPM address (activation / left matrix)
-  uint32_t src_w;   // SPM address (weight / right matrix)
-  uint32_t psum;    // SPM address (input psum)
-  uint32_t bias;    // SPM address (bias)
-  uint32_t scale_p; // SPM address (positive-axis scale)
-  uint32_t scale_n; // SPM address (negative-axis scale)
-  uint32_t out;     // SPM address (output psum)
-  uint64_t tfr_0; // src0 nhwc, [15:0] tensor batch/h/w (range 1~4096);
-                  // tensor channel count (range 1~16384)
-  uint64_t tfr_1; // conv: out nhwc, same layout as tfr_0
-  uint64_t pdr;   // pad [15:0] top bottom left right; rows/columns of padding
-                  // on each side (range 0~1023)
-  uint64_t unpdr; // unpad [15:0] top bottom left right
-  uint64_t swr; // [15:0] Kx (range 1~255) Ky (range 1~255) Sx (range 1~1023)
-                // Sy (range 1~1023)
-  uint64_t dilation; // [15:0] dilation size in x (range 1-1023),
-                     // [15:0] dilation size in y (range 1-1023)
-
-  uint16_t gemm_lb; // [15:0] batch of the left matrix (range: 1~4096)
-  uint16_t gemm_rb; // [15:0] batch of the right matrix (range: 1~4096)
-  uint16_t gemm_n;  // Matrix size parameters of the matrix multiply
-  uint16_t gemm_m;  // mk*kn---->mn
-  uint16_t gemm_k;  // (range: 1~16384)
-  uint8_t gemm_l_trs; // Transpose the left matrix
-  uint8_t gemm_r_trs; // Transpose the right matrix
-  /*
-    Quant formula----A_int8: Left input, B_int8: Right input
-    Left input 8bit to 9bit:  A_int9 = A_int8 - ZP_A_int8
-    Right input 8bit to 9bit: B_int9 = B_int8 - ZP_B_int8
-    do conv  : O_int32 = Sum_{A_int9 * B_int9}
-    do scale : O_int16 = Clip_int16(O_int32 >> q1)
-    do scale : O_int9 = Clip_int9((O_int16 * S_int16) >> q2)
-    out 9bit to 8bit : O_int8 = O_int9 + ZP_O_int8
-  */
-  uint8_t quant_zp_cur;   // Output zero point (0-255). [39:32]
-  uint8_t quant_reserved; // (0-255).
-                          // [31:24] conv:unused gemm:right_zp
-  uint8_t quant_zp_pre; // Input zero point (0-255), [23:16] conv:act_zp
-                        // gemm:left_zp (range: 0-255)
-  uint8_t quant_q1; // q1, (range: 0-31), [15:8]
-  uint8_t quant_q0; // q2, (range: 0-31), [7:0]
-
-  uint32_t sparse_index; // SPM address (sparsity index)
-  uint32_t srca_end; // xxx_end = src/dst + the storage extent of that operand
-                     // in SPM
-  uint32_t srcw_end;
-  uint32_t psum_end;
-  uint32_t bias_end;
-  uint32_t scale_p_end;
-  uint32_t scale_n_end;
-  uint32_t out_end;
-  uint32_t sparse_end;
-} Ncc_NE_GR_Param_Regs;
-
-typedef struct TsmNeInstr {
-  uint32_t inter_type;
-  Ncc_NE_GR_Ctl_Regs ctrl;
-  Ncc_NE_GR_Param_Regs param;
-} TsmNeInstr;
-
-// RDMA / WDMA
-typedef struct Ncc_DMA_GR_Ctl_Regs {
-  uint8_t cmd_valid;
-} Ncc_DMA_GR_Ctl_Regs;
-
-typedef struct Ncc_DMA_GR_Param_Regs {
-  uint64_t dst; // DDR address
-  uint64_t src; // SPM address
-  /*
-    for(i = 0; i < itera2; i++)
-      for(j = 0; j < itera1; j++)
-        for(k = 0; k < itera0; k++)
-          for(l = 0; l < elem_count; l++)
-            dst[l + elem_count * k + elem_count * src_itera0 * j +
-                elem_count * src_itera0 * src_itera1 * i] =
-                src[l + k * src_stride0 + j * src_stride1 + i * src_stride2];
-  */
-  uint32_t stride0;    // Address stride
-  uint32_t iteration0; // Number of data blocks
-  uint32_t stride1;
-  uint32_t iteration1;
-  uint32_t stride2;
-  uint32_t iteration2;
-  uint32_t elem_count; // Number of elements moved per transfer in the
-                       // innermost dimension
-  uint8_t format;      // Data type
-  uint64_t src_end; // src_end = src + length of the data stored at src
-  uint64_t dst_end; // dst_end = dst + length of the data stored at dst
-} Ncc_DMA_GR_Param_Regs;
-
-typedef struct DMA_Param {
-  uint32_t inter_type;
-  Ncc_DMA_GR_Ctl_Regs ctrl;
-  Ncc_DMA_GR_Param_Regs param;
-} DMA_Param;
-
-#define TsmRdmaInstr DMA_Param
-#define TsmWdmaInstr DMA_Param
-
-typedef struct Ncc_TDMA_GR_Ctl_Regs {
-  uint8_t cmd_valid;   // [12]
-  uint8_t src0_format; // [11:8]
-  uint8_t opcode;      // [7:0]
-} Ncc_TDMA_GR_Ctl_Regs;
-
-typedef struct Ncc_TDMA_GR_Param_Regs {
-  uint32_t src0;
-  uint32_t src1;
-  uint32_t dst;
-  uint64_t src0_tfr;   // nhwc c:15~0
-  uint64_t dst_tfr;    // nhwc
-  uint64_t pdr;        // top bottom left right
-  uint64_t swr;        // kx ky sx sy
-  uint32_t elem_count; // Number of elements of the vector operation.
-                       // For the memset and gatherscatter instructions this is
-                       // a byte count
-  /* for(i=0;i< ... */
-} Ncc_TDMA_GR_Param_Regs;
-
-typedef struct Ncc_DTE_GR_Param_Regs {
-  uint32_t mode;    // [0:0] 0->gather/scatter, unicast, 1->scatter,
-                    // broadcast, 3->shuffle (3D gather). [8:8] sg_flag:
-                    // 0->scatter, 1->gather, [16:16] dim_flag: only unicast
-                    // mode(=0), 0->1D transport, 1->2D transport
-  uint32_t length;  // Data byte count
-  uint8_t dest_num; // If mode[0:0] is 0, its value is 1; otherwise its value
-                    // is between 1 and 31.
-  uint32_t stride0; // If mode[0:0] is 0, the stride can be set; unit is bytes.
-  uint32_t iteration0; // 0: means 1 section; 1: means 2 sections, and so on.
-  uint32_t stride1;
-  uint32_t iteration1;
-  uint32_t stride2;
-  uint32_t iteration2;
-  uint16_t max_axi_num; // [7:0] axi_write_outstanding,
-                        // [15:8] axi_read_outstanding
-  uint8_t cmd_valid;    // 1: activate the DMA, 0: no action.
-  // uint8_t dma_status; // [0:0] 0->unfinished, 1->finished. [8:8] 0/1, records
-  // errors of the AXI bus or other DMA transmission errors.
-  uint16_t mem_burstlen; // [7:0] mem_burst_len_write, default value: 0x10;
-                         // [15:8] mem_burst_len_read, default value: 0x10
-  uint8_t mem_backpressure; // 0x1
-  uint8_t mem_read_turbo;   // [1:0], 0~2, default value: 0, only block0 valid.
-} Ncc_DTE_GR_Param_Regs; - -typedef enum OP_FUNC_CGRA { - // Arithmetic Operators - OP_FUNC_CGRATensor_ArithOp_V_V_abs = 0, - OP_FUNC_CGRATensor_ArithOp_V_V_recip = 1, - OP_FUNC_CGRATensor_ArithOp_V_V_square = 2, - OP_FUNC_CGRATensor_ArithOp_V_V_sqrt = 3, - OP_FUNC_CGRATensor_ArithOp_V_V_rsqrt = 4, - OP_FUNC_CGRATensor_ArithOp_V_V_neg = 5, - OP_FUNC_CGRATensor_ArithOp_V_VV_max = 6, - OP_FUNC_CGRATensor_ArithOp_V_VS_max = 7, - OP_FUNC_CGRATensor_ArithOp_V_VuV_max = 8, - OP_FUNC_CGRATensor_ArithOp_V_VuV_max_loop = 9, - OP_FUNC_CGRATensor_ArithOp_V_VV_min = 10, - OP_FUNC_CGRATensor_ArithOp_V_VS_min = 11, - OP_FUNC_CGRATensor_ArithOp_V_VuV_min = 12, - OP_FUNC_CGRATensor_ArithOp_V_VuV_min_loop = 13, - OP_FUNC_CGRATensor_ArithOp_V_VV_add = 14, - OP_FUNC_CGRATensor_ArithOp_V_VS_add = 15, - OP_FUNC_CGRATensor_ArithOp_V_VuV_add = 16, - OP_FUNC_CGRATensor_ArithOp_V_VuV_add_loop = 17, - OP_FUNC_CGRATensor_ArithOp_V_VV_sub = 18, - OP_FUNC_CGRATensor_ArithOp_V_VS_sub = 19, - OP_FUNC_CGRATensor_ArithOp_V_VuV_sub = 20, - OP_FUNC_CGRATensor_ArithOp_V_VuV_sub_loop = 21, - OP_FUNC_CGRATensor_ArithOp_V_VV_mul = 22, - OP_FUNC_CGRATensor_ArithOp_V_VS_mul = 23, - OP_FUNC_CGRATensor_ArithOp_V_VuV_mul = 24, - OP_FUNC_CGRATensor_ArithOp_V_VuV_mul_loop = 25, - OP_FUNC_CGRATensor_ArithOp_V_VV_div = 26, - OP_FUNC_CGRATensor_ArithOp_V_VS_div = 27, - OP_FUNC_CGRATensor_ArithOp_V_VuV_div = 28, - OP_FUNC_CGRATensor_ArithOp_V_VuV_div_loop = 29, - - // Relational Operators - OP_FUNC_CGRATensor_RelaOp_V_VV_eq = 30, - OP_FUNC_CGRATensor_RelaOp_bV_VV_eq = 31, - OP_FUNC_CGRATensor_RelaOp_V_VS_eq = 32, - OP_FUNC_CGRATensor_RelaOp_bV_VS_eq = 33, - OP_FUNC_CGRATensor_RelaOp_V_VuV_eq = 34, - OP_FUNC_CGRATensor_RelaOp_V_VuV_eq_loop = 35, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq = 36, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_eq_loop = 37, - - OP_FUNC_CGRATensor_RelaOp_V_VV_ne = 38, - OP_FUNC_CGRATensor_RelaOp_bV_VV_ne = 39, - OP_FUNC_CGRATensor_RelaOp_V_VS_ne = 40, - OP_FUNC_CGRATensor_RelaOp_bV_VS_ne = 41, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ne = 42, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ne_loop = 43, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne = 44, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ne_loop = 45, - - OP_FUNC_CGRATensor_RelaOp_V_VV_ge = 46, - OP_FUNC_CGRATensor_RelaOp_bV_VV_ge = 47, - OP_FUNC_CGRATensor_RelaOp_V_VS_ge = 48, - OP_FUNC_CGRATensor_RelaOp_bV_VS_ge = 49, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ge = 50, - OP_FUNC_CGRATensor_RelaOp_V_VuV_ge_loop = 51, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge = 52, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_ge_loop = 53, - - OP_FUNC_CGRATensor_RelaOp_V_VV_gt = 54, - OP_FUNC_CGRATensor_RelaOp_bV_VV_gt = 55, - OP_FUNC_CGRATensor_RelaOp_V_VS_gt = 56, - OP_FUNC_CGRATensor_RelaOp_bV_VS_gt = 57, - OP_FUNC_CGRATensor_RelaOp_V_VuV_gt = 58, - OP_FUNC_CGRATensor_RelaOp_V_VuV_gt_loop = 59, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt = 60, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_gt_loop = 61, - - OP_FUNC_CGRATensor_RelaOp_V_VV_le = 62, - OP_FUNC_CGRATensor_RelaOp_bV_VV_le = 63, - OP_FUNC_CGRATensor_RelaOp_V_VS_le = 64, - OP_FUNC_CGRATensor_RelaOp_bV_VS_le = 65, - OP_FUNC_CGRATensor_RelaOp_V_VuV_le = 66, - OP_FUNC_CGRATensor_RelaOp_V_VuV_le_loop = 67, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_le = 68, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_le_loop = 69, - - OP_FUNC_CGRATensor_RelaOp_V_VV_lt = 70, - OP_FUNC_CGRATensor_RelaOp_bV_VV_lt = 71, - OP_FUNC_CGRATensor_RelaOp_V_VS_lt = 72, - OP_FUNC_CGRATensor_RelaOp_bV_VS_lt = 73, - OP_FUNC_CGRATensor_RelaOp_V_VuV_lt = 74, - OP_FUNC_CGRATensor_RelaOp_V_VuV_lt_loop = 75, - OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt = 76, - 
-  OP_FUNC_CGRATensor_RelaOp_bV_VuV_lt_loop = 77,
-
-  OP_FUNC_CGRATensor_LogicOp_V_V_not = 78,
-  OP_FUNC_CGRATensor_LogicOp_V_VV_and = 79,
-  OP_FUNC_CGRATensor_LogicOp_V_VV_or = 80,
-  OP_FUNC_CGRATensor_LogicOp_V_VV_xor = 81,
-  OP_FUNC_CGRATensor_LogicOp_V_VuV_and = 82,
-  OP_FUNC_CGRATensor_LogicOp_V_VuV_or = 83,
-  OP_FUNC_CGRATensor_LogicOp_V_VuV_xor = 84,
-  OP_FUNC_CGRATensor_LogicOp_V_VuV_and_loop = 85,
-  OP_FUNC_CGRATensor_LogicOp_V_VuV_or_loop = 86,
-  OP_FUNC_CGRATensor_LogicOp_V_VuV_xor_loop = 87,
-
-  OP_FUNC_CGRATensor_LogicOp_bV_bV_not = 88,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVbV_and = 89,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVbV_or = 90,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVbV_xor = 91,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and = 92,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or = 93,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor = 94,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_and_loop = 95,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_or_loop = 96,
-  OP_FUNC_CGRATensor_LogicOp_bV_bVubV_xor_loop = 97,
-
-  // Transcendental Operator
-  OP_FUNC_CGRATensor_TransOp_V_V_log2 = 98,
-  OP_FUNC_CGRATensor_TransOp_V_V_ln = 99,
-  OP_FUNC_CGRATensor_TransOp_V_V_pow2 = 100,
-  OP_FUNC_CGRATensor_TransOp_V_V_exp = 101,
-  OP_FUNC_CGRATensor_TransOp_V_V_exp_lp = 102,
-  OP_FUNC_CGRATensor_TransOp_V_V_sin = 103,
-  OP_FUNC_CGRATensor_TransOp_V_V_cos = 104,
-
-  // Activation Operator
-  OP_FUNC_CGRATensor_ActOp_V_V_tanh = 105,
-  OP_FUNC_CGRATensor_ActOp_V_V_sigmoid = 106,
-  OP_FUNC_CGRATensor_ActOp_V_V_relu = 107,
-  OP_FUNC_CGRATensor_ActOp_V_V_satrelu = 108,
-  OP_FUNC_CGRATensor_ActOp_V_V_leakyrelu = 109,
-  OP_FUNC_CGRATensor_ActOp_V_V_softplus = 110,
-
-  // Reduce Operator
-  OP_FUNC_CGRATensor_ReduceOp_T_T_sum = 111,
-  OP_FUNC_CGRATensor_ReduceOp_T_T_avg = 112,
-  OP_FUNC_CGRATensor_ReduceOp_T_T_max = 113,
-  OP_FUNC_CGRATensor_ReduceOp_T_T_min = 114,
-
-  // Pool Operator
-  OP_FUNC_CGRATensor_PoolOp_T_T_avg = 115,
-  OP_FUNC_CGRATensor_PoolOp_T_T_sum = 116,
-  OP_FUNC_CGRATensor_PoolOp_T_T_max = 117,
-  OP_FUNC_CGRATensor_PoolOp_T_T_indexedmax = 118,
-  OP_FUNC_CGRATensor_PoolOp_T_T_min = 119,
-  OP_FUNC_CGRATensor_PoolOp_T_T_indexedmin = 120,
-
-  // DataMove
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool = 121,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_unpool_avg = 122,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_maskunpool = 123,
-  // reshape
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_mirror = 124,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_transpose = 125,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate90 = 126,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate180 = 127,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_rotate270 = 128,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_nchw2nhwc = 129,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_nhwc2nchw = 130,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_concat = 131,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_pad = 132,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_channelnorm = 133,
-  // datamove
-  OP_FUNC_CGRATensor_DataMoveOp_V_V_maskmove = 134,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_gatherscatter = 135,
-  OP_FUNC_CGRATensor_DataMoveOp_V_V_maskgather = 136,
-  OP_FUNC_CGRATensor_DataMoveOp_V_bV_maskgather = 137,
-  OP_FUNC_CGRATensor_DataMoveOp_T_T_img2col = 138,
-
-  // Convert Operator
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp16 = 139,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_bf16 = 140,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_fp32 = 141,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int8_tf32 = 142,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp16 = 143,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_bf16 = 144,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_fp32 = 145,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int16_tf32 = 146,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp16 = 147,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_bf16 = 148,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_fp32 = 149,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_int32_tf32 = 150,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int8 = 151,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int16 = 152,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_int32 = 153,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp16 = 154,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_fp32 = 155,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_bf16_tf32 = 156,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int8 = 157,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int16 = 158,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_int32 = 159,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_bf16 = 160,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_fp32 = 161,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp16_tf32 = 162,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int8 = 163,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int16 = 164,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_int32 = 165,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_fp16 = 166,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_bf16 = 167,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_fp32_tf32 = 168,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int8 = 169,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int16 = 170,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_int32 = 171,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp16 = 172,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_bf16 = 173,
-  OP_FUNC_CGRATensor_ConvertOp_V_V_tf32_fp32 = 174,
-
-  // Peripheral Operator
-  OP_FUNC_CGRATensor_PeriOp_S_V_count = 175,
-  OP_FUNC_CGRATensor_PeriOp_S_bV_bitcount = 176,
-  OP_FUNC_CGRATensor_PeriOp_V_V_argmax = 177,
-  OP_FUNC_CGRATensor_PeriOp_V_V_argmin = 178,
-  OP_FUNC_CGRATensor_PeriOp_T_memset = 179,
-  OP_FUNC_CGRATensor_PeriOp_V_V_fp32_factorize = 180,
-  OP_FUNC_CGRATensor_PeriOp_V_V_bit2fp = 181,
-  OP_FUNC_CGRATensor_PeriOp_T_T_bilinear = 182,
-  OP_FUNC_CGRATensor_PeriOp_V_V_lut16 = 183,
-  OP_FUNC_CGRATensor_PeriOp_V_V_lut32 = 184,
-  OP_FUNC_CGRATensor_PeriOp_V_rand_gen = 185,
-  OP_FUNC_CGRATensor_PeriOp_V_V_elem_mask = 186,
-} OP_FUNC_CGRA;
-
-typedef enum CGRA_INSTR_TYPE {
-  CGRA_INSTR_TYPE0,
-  CGRA_INSTR_TYPE1,
-  CGRA_INSTR_TYPE2,
-  CGRA_INSTR_TYPE3,
-} CGRA_INSTR_TYPE;
-
-typedef struct Op_fu_head {
-  uint8_t fu;
-  uint8_t opcode;
-} Op_fu_head;
-
-typedef struct FU_gemm_head {
-  uint8_t fu;
-  uint8_t gemm;
-} FU_gemm_head;
-
-typedef struct opfunc_cgra_info {
-  char name[64];  // CGRATensor_ArithOp_V_V_abs
-  int32_t opcode; // 8'b0000_0000
-  int32_t type;   // CGRA_Tensor_type0
-} opfunc_cgra_info;
-
-// Neural
-typedef enum Data_Format {
-  Fmt_INT8,
-  Fmt_INT16,
-  Fmt_FP16,
-  Fmt_BF16,
-  Fmt_INT32,
-  Fmt_FP32,
-  Fmt_TF32,
-  Fmt_BOOL, // 1/8 BYTE
-  Fmt_UINT8,
-  Fmt_UINT16,
-  Fmt_UINT32,
-  Fmt_INT64,
-  Fmt_UINT64,
-  Fmt_UNUSED,
-} Data_Format;
-
-typedef enum Tensor_Fmt {
-  T_GemmM = 0, /*M K*/
-  T_ConvA = 1, /*H W C*/
-  T_ConvW = 2, /*Kx Ky F C*/
-  T_Vec = 3,
-  T_ConvNA = 4,
-  T_ConvNW = 5,
-} Tensor_Fmt;
-
-/*
-  SumReduce on a tensor supports the following dimensions:
-  reduce along C:  result is HW (C=1), dim=0
-  reduce along W:  result is H (W=1) C, dim=1
-  reduce along H:  result is (H=1) W C, dim=2
-  reduce along HW: result is (H=1)(W=1) C, dim=4
-*/
-typedef enum Reduce_Dim {
-  Reduce_C = 0,
-  Reduce_W = 1,
-  Reduce_H = 2,
-  Reduce_HW = 4,
-} Reduce_Dim;
-
-typedef struct NCC_CSR {
-  uint64_t ib_status; // [7:0] IB_COUNTER: number of instructions remaining in
-                      // the instruction buffer, [8] TASK_DONE, 1: task
-                      // finished, 0: task still running, [63:9] Reserved
-  uint64_t exception; // [7:0]SCALAR_EXCEPTION, [15:8]CT_EXCEPTION,
-                      // [23:16]NE_EXCEPTION, [31:24]RDMA_EXCEPTION,
-                      // [39:32]WDMA_EXCEPTION,
-                      // [47:40]TDMA_EXCEPTION, [63:48]Reserved
-  uint64_t priority; // [7:0]PRIORITY, priority of the current worker,
-                     // [63:8]Reserved
-  uint64_t exception_mask; // [47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE,
-                           // [49]EXCEPTION_CLEAR, [63:49]Reserved
-  uint64_t serial_mode; // [0]SERIAL_MODE, 1: serial mode, 0: parallel mode,
-                        // [63:1]Reserved
-} NCC_CSR;
-
-typedef struct EXCEP_SERI {
-  uint64_t exception_mask; // [47:0]EXCEPTION_MASK, [48]EXCEPTION_UPDATE_ENABLE,
-                           // [49]EXCEPTION_CLEAR, [63:49]Reserved
-  uint64_t serial_mode; // [0]SERIAL_MODE, 1: serial mode, 0: parallel mode,
-                        // [63:1]Reserved
-} EXCEP_SERI;
-#endif
diff --git a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h
deleted file mode 100644
index 4ee0d0e30..000000000
--- a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_common.h
+++ /dev/null
@@ -1,488 +0,0 @@
-/*
- * Copyright (C) 2024 Tsing Micro Intelligent Technology Co.,Ltd. All rights
- * reserved.
- *
- * This file is the property of Tsing Micro Intelligent Technology Co.,Ltd. This
- * file may only be distributed to: (i) a Tsing Micro party having a legitimate
- * business need for the information contained herein, or (ii) a non-Tsing Micro
- * party having a legitimate business need for the information contained herein.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- */
-
-#ifndef __HOST_RUNTIME_COM_H__
-#define __HOST_RUNTIME_COM_H__
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#ifndef MAX_SHAPE_DIM
-#define MAX_SHAPE_DIM 6
-#endif
-
-#ifndef MAX_MODEL_NUM
-#define MAX_MODEL_NUM 32
-#endif
-
-typedef uint64_t TsmDevicePtr;
-typedef uint64_t TsmHostPtr;
-
-#define CHIP_MAX_NUM 32
-#define TILE_MAX_NUM 16
-#define CACHE_ALIGN_4k 4096
-
-typedef void *(*THREAD_PROC_FUNC)(void *);
-
-enum TSM_RETCODE {
-  RET_SUCCESS,
-  RET_ERROR,
-  RET_PARAM1_ERROR,
-  RET_PARAM2_ERROR,
-  RET_PARAM3_ERROR,
-  RET_DEVICE_OFFLINE,
-  RET_DEVICE_NOMEM,
-  RET_DEVICE_IN_IDLE,
-  RET_DEVICE_IN_ATTACH,
-  RET_DEVICE_ATTACH_SUCCESS,
-  RET_DEVICE_ATTACH_READY,
-  RET_DEVICE_LOSE_CONNECT,
-  RET_ENV_CLEAN_UP,
-};
-
-typedef enum HostLogLevel {
-  LOG_DEBUG,
-  LOG_INFO,
-  LOG_WARNING,
-  LOG_ERROR,
-  LOG_FATAL,
-  LOG_MAX
-} HostLogLevel;
-
-typedef enum TsmModuleType {
-  TSM_RUNTIME,
-  TSM_XLA,     // Frontend
-  TSM_TXNN,    // Inference engine
-  TSM_ENGTEST, // On-board test suite
-  TSM_HOSTSIM, // Simulator test suite
-  TSM_CMODEL,  // Simulator API
-  TSM_RT_TEST, // Runtime component test suite
-} TsmModuleType;
-
-typedef enum TsmProfAction { TSM_PROF_START, TSM_PROF_STOP } TsmProfAction;
-constexpr uint16_t PROF_TYPE_NCC = 0x1;
-constexpr uint16_t PROF_TYPE_SPM = 0x2;
-constexpr uint16_t PROF_TYPE_DTE = 0x4;
-
-typedef enum DTYPE {
-  FMT_INT8,
-  FMT_INT16,
-  FMT_FP16,
-  FMT_BF16,
-  FMT_INT32,
-  FMT_FP32,
-  FMT_TF32,
-  FMT_BOOL, // 1/8 BYTE
-  FMT_UINT8,
-  FMT_UINT16,
-  FMT_UINT32,
-  FMT_INT64,
-  FMT_UINT64,
-  FMT_UNUSED,
-} DTYPE;
-
-uint8_t hrt_get_dtype_size(DTYPE dtype);
-
-enum DynDataType {
-  PKT_FINAL_TYPE = 0,
-  CFG_PMU_TYPE,
-  KCORE_CFG_TYPE,
-  EXPORT_SPM_TYPE,
-  DISABLE_CALC_TYPE,
-  PROF_CFG_TYPE,
-  DYNLIB_LOAD,
-  DYNLIB_RUN,
-  DYNLIB_UNLOAD,
-  MEMCPY_D2D,
-  P2P_SEND,
-  P2P_RECV,
-  DATA_TYPE_MAX,
-};
-
-typedef struct DynTLV_Terminate {
-  uint32_t type; // DynDataType
-  uint32_t len;
-  uint64_t is_final;
-} DynTLV_Terminate;
-
-typedef struct DynTLV {
-  uint32_t type; // DynDataType
-  uint32_t len;
-} DynTLV;
-
-typedef struct Cfg_Pmu_Info {
-  uint32_t tile_bitmap[16];
-  uint32_t mac_use_rate;
-  uint32_t chip_id;
-  uint32_t cycles;
-  uint64_t in_ddr;
-  uint64_t param_ddr;
-  uint64_t out_ddr;
-  uint32_t reserved;
-} Cfg_Pmu_Info;
-
-typedef struct DynTLV_Cfgpmu {
-  uint32_t type; // DynDataType
-  uint32_t len;
-  Cfg_Pmu_Info cfg_pmu;
-} DynTLV_Cfgpmu;
-
-typedef struct DynTLV_KcoreCfg {
-  uint32_t type;
-  uint32_t len;
-  uint64_t snap_addr[TILE_MAX_NUM];
-  uint64_t console_addr[TILE_MAX_NUM];
-  uint64_t spm_dump_addr[TILE_MAX_NUM];
-  uint64_t spm_dump_size;
-  uint32_t log_level;
-  uint32_t enable_monitor;
-} DynTLV_KcoreCfg;
-
-typedef struct DynTLV_KcoreCalc {
-  uint32_t type;
-  uint32_t len;
-  uint32_t disable_kcore_calc;
-} DynTLV_KcoreCalc;
-
-typedef struct DynTLV_ProfCfg {
-  uint32_t type;
-  uint32_t len;
-  uint64_t addrs[TILE_MAX_NUM];
-  uint32_t size;
-  uint16_t enable;
-  uint16_t prof_type;
-} DynTLV_ProfCfg;
-
-// #define TILE_NUM 16
-typedef struct DynModule {
-  char module_name[128];
-  char module_symbol[128]; // typedef void (*entry_func_t)(void *)
-  uint32_t module_size[TILE_MAX_NUM];
-  uint64_t module_addr[TILE_MAX_NUM]; // Device address
-} DynModule;
-
-typedef struct DynMods {
-  uint16_t module_num;
-  struct DynModule modules[0];
-} DynMods; // Structure shared with the host; its base address is passed over
-
-typedef struct DynTLV_DynMods {
-  uint32_t type; // DynDataType
-  uint32_t len;
-  uint64_t ext_addr;
-  uint64_t dyn_mods_addr; // Points to DynMods
-} DynTLV_DynMods;
-
-typedef struct TileDteCfg {
-  uint16_t status;             // Whether this tile takes part in the transfer
-  uint16_t remote_tile_id;     // Peer tile_id
-  uint32_t element_count;      // Cache-line size per transfer, 4K by default
-  uint32_t stride;             // Stride
-  uint32_t left_element_count; // Remaining length to move after the whole
-                               // cache lines are done
-  uint64_t iteration;          // Number of cache-line transfers
-  uint64_t src_addr;      // Source address of the cache-line transfer (physical)
-  uint64_t dst_addr;      // Destination address of the cache-line transfer (physical)
-  uint64_t left_src_addr; // Source address of the remainder transfer (physical)
-  uint64_t left_dst_addr; // Destination address of the remainder transfer (physical)
-} TileDteCfg;
-typedef struct DynTLV_DteCfg {
-  uint32_t type;
-  uint32_t len;
-  TileDteCfg tile_dte_cfg[TILE_MAX_NUM];
-  uint64_t barrier_addr;
-  uint32_t row_card_num;
-  uint32_t reserved;
-} DynTLV_DteCfg;
-
-enum Tensor_Type {
-  INPUT_DATA,
-  OUTPUT_DATA,
-  PARAM_DATA,
-  CHACHE_DATA,
-  DEV_DDR_DATA,
-};
-
-typedef struct tensor_info {
-  int32_t inplace;
-  uint32_t dim;
-  uint32_t dtype;
-  uint32_t layout;
-  uint32_t shape[MAX_SHAPE_DIM];
-} tensor_info_t;
-
-typedef struct Json_common_info_t {
-  uint32_t input_num;
-  uint32_t output_num;
-  uint32_t param_num;
-  uint32_t tile_num;
-
-  std::string case_name;
-  std::string card_name;
-
-  std::vector> input;
-  std::vector> output;
-
-  std::vector input_file;
-  std::vector output_file;
-  std::vector param_file;
-
-  std::vector input_size;
-  std::vector output_size;
-  std::vector param_size;
-  uint64_t imm_size;
-
-} Json_common_info_t;
-
-typedef struct chip_common_info {
-  uint32_t input_num;
-  uint32_t output_num;
-  uint32_t param_num;
-  uint32_t tile_num;
-  uint32_t tile_x;
-  uint32_t tile_y;
-  std::vector> input;
-  std::vector> output;
-
-  // char card_name[100];
-  std::string card_name;
-  std::vector input_file;
-  std::vector output_file;
-  std::vector output_ref_file;
-  std::vector param_file;
-
-  std::vector input_size;
-  std::vector output_size;
-  std::vector param_size;
-
-  std::vector input_host_addr;
-  std::vector input_dev_addr;
-  std::vector output_host_addr;
-  std::vector output_dev_addr;
-  std::vector param_host_addr;
-  std::vector param_dev_addr;
-
-  uint64_t imm_size;
-} chip_common_info_t;
-
-typedef struct json_common_info_multi_card {
-  uint32_t chip_num;
-  uint32_t chip_x;
-  uint32_t chip_y;
-  std::string case_name;
-  uint32_t loop_num;
-  std::vector> chip_infos;
-} json_common_info_multi_card_t;
-
-typedef struct CompileOption {
-  bool comp_enable = false;
-  std::string rtt_tool_path;
-  std::string compile_path;
-  bool check_enable = false;
-  uint32_t chip_x;
-  uint32_t chip_y;
-  bool enable_kcore_bin;
-  bool enable_kcore_so;
-} CompileOption;
-
-// Boot Param Table
-typedef struct BootParamHead {
-  uint32_t MaxLen; // BootParamHead + n * BootParamDyninfo, n = inputnum +
-                   // outputnum + paramnum
-  uint32_t LdmemLen;
-  uint32_t InputNum;
-  uint32_t OutputNum;
-  uint32_t ParamNum;
-  uint32_t reserved;
-  uint64_t CacheMemLen;
-  uint64_t CacheMemAddr; // device
-  uint32_t Datalen;
-  uint32_t reserved1;
-  uint64_t DataAddr; // device
-} BootParamHead;
-
-typedef struct BootParamDyninfo {
-  uint64_t addr; // device
-  uint64_t size;
-  uint32_t dtype;
-  uint32_t dim;
-  uint32_t shape[6]; // #define MAX_SHAPE_DIM 6 //n, h, w, c, x, x
-} BootParamDyninfo;
-
-class HrtBootParam {
-public:
-  HrtBootParam(uint32_t i_num, uint32_t o_num, uint32_t p_num)
-      : i_num(i_num), o_num(o_num), p_num(p_num) {
-    uint32_t bufsize = (sizeof(BootParamHead) +
-                        (i_num + o_num + 1) * sizeof(BootParamDyninfo));
-    buffer = (void *)malloc(bufsize);
-    memset(buffer, 0, bufsize);
-    BootParamHead *head = static_cast<BootParamHead *>(buffer);
-    head->MaxLen = bufsize;
-    head->LdmemLen = 0x200000;
-    head->InputNum = i_num;
-    head->OutputNum = o_num;
-    head->ParamNum = p_num;
-  }
-  ~HrtBootParam() {
-    if (buffer != nullptr) {
-      free(buffer);
-    }
-  }
-  std::vector dyninfo;
-  uint32_t get_maxlen();
-  void *get_bootpmbuffer();
-  BootParamHead *get_headptr();
-  BootParamDyninfo *get_inputptr(uint32_t index);
-  BootParamDyninfo *get_outputptr(uint32_t index);
-  BootParamDyninfo *get_paramptr(uint32_t index);
-  void set_dev_cache(uint64_t dev_addr, uint64_t size);
-  void set_dev_cache_mem_addr(uint64_t dev_addr, uint64_t size);
-  void set_dev_dyndata(uint64_t dev_addr, uint32_t size);
-  void set_dev_dyndata_mem_addr(uint64_t dev_addr, uint32_t size);
-  void set_dev_input(uint32_t idx, uint64_t dev_addr, uint64_t size);
-  void set_dev_input_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
-  void set_dev_input_tensor(uint32_t idx,
-                            std::shared_ptr tensor);
-  void set_dev_output(uint32_t idx, uint64_t dev_addr, uint64_t size);
-  void set_dev_output_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
-  void set_dev_param(uint32_t idx, uint64_t dev_addr, uint64_t size);
-  void set_dev_param_mem_addr(uint32_t idx, uint64_t dev_addr, uint64_t size);
-  std::shared_ptr get_dev_output_tensor_after_run(uint32_t idx);
-
-private:
-  uint32_t i_num;
-  uint32_t o_num;
-  uint32_t p_num;
-  void *buffer;
-};
-/* End of boot parameters */
-
-/* Object produced by the compiler that stores the ELF and param addresses */
-class HostParamElem {
-public:
-  HostParamElem() : dataPtr(nullptr), size(0) {}
-  ~HostParamElem();
-  // Simulator: load a bin from a file
-  HostParamElem(const std::string &filepath);
-
-  uint8_t *loadBinaryFile(const std::string filepath, uint64_t &fsize);
-  uint8_t *dataPtr; // host
-  uint64_t size;    // bytes
-};
-
-class ChipModelInfo {
-public:
-  ChipModelInfo();
-  ChipModelInfo(uint32_t id);
-  ~ChipModelInfo();
-
-  uint32_t getChipId() { return chip_id; }
-  // support multi chip
-  std::vector> elfs; // Compiled ELF files
-  std::vector> bins; // Compiled bin files
-  std::vector> params;
-
-private:
-  uint32_t chip_id;
-};
-
-/*
- * Model object produced by the compiler. At launch time the elf/bin pointers
- * are passed into the SoC interface. (A PCIe transfer of non-contiguous
- * memory triggers multiple copies, so assembling a contiguous region is left
- * to the SoC.)
- */
-class TsmModel {
-public:
-  TsmModel(); // org_model
-  ~TsmModel();
-  TsmModel(const std::string &filepath);
-
-  std::vector> chip_infos;
-  THREAD_PROC_FUNC proc_func;
-  std::string case_name;
-  std::string case_dir;
-  std::shared_ptr so_list[MAX_MODEL_NUM]
-                         [TILE_MAX_NUM]; // Compiled .so files
-  std::string module_name;
-  struct txmodel *model[MAX_MODEL_NUM];
-};
-
-typedef struct TsmDevice {
-  char res_path[128];
-  uint32_t chip_id;
-  uint32_t tile_num = 16;
-  void *soc_device;
-} TsmDevice_t;
-
-class TsmTensorData {
-public:
-  TsmTensorData() : host_addr(0), device_addr(0), length(0) {}
-  ~TsmTensorData() {};
-
-  TsmHostPtr host_addr;
-  TsmDevicePtr device_addr;
-  uint64_t length;
-  uint32_t data_type;
-  Tensor_Type tensor_type;
-};
-
-typedef void *tsmStream_t;
-typedef void *tsmEvent_t;
-typedef struct txcclComm *txcclComm_t;
-typedef enum { txcclDataDefault = 0 } txcclDataType_t; // Reserved; to be discussed
-
-enum device_status {
-  FULLGOOD = 0,
-  PARTIALGOOD = 1,
-};
-
-constexpr uint32_t PARTIALGOOD_NUM = 8;
-constexpr uint32_t FULLGOOD_NUM = 16;
-
-struct CardComputeInfo {
-  uint32_t card_id;
-  enum device_status device_status;
-  uint32_t all_tile_num;
-  double all_tile_compute;
-  uint32_t left_tile_num;
-  double left_tile_compute;
-};
-
-struct TsmDeviceInfo {
-  uint32_t card_num;
-  uint32_t card_x;
-  uint32_t card_y;
-  CardComputeInfo card_compute_info[CHIP_MAX_NUM];
-};
-
-int32_t readDataFromFile(uint8_t *buffer, std::string file, uint32_t size);
-uint8_t *read_file_data(std::string file, uint64_t &size);
-
-std::shared_ptr
-get_multi_card_common_info_from_file(std::string file);
-std::string get_docker_verison();
-TSM_RETCODE set_multi_graph(TsmModel *&kmodel,
-                            std::shared_ptr &hostboot,
-                            const TsmDevicePtr &dev_dyn_mods_ptr,
-                            const TsmDevicePtr &dev_tlv_ptr,
-                            TsmDevicePtr ext_ptr);
-#endif /* __HOST_RUNTIME_COM_H__ */
diff --git a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h b/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h
deleted file mode 100644
index ce726beca..000000000
--- a/third_party/tsingmicro/crt/include/Tx81/runtime/hrt_interface.h
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Copyright (C) 2024 Tsing Micro Intelligent Technology Co.,Ltd. All rights
- * reserved.
- *
- * This file is the property of Tsing Micro Intelligent Technology Co.,Ltd. This
- * file may only be distributed to: (i) a Tsing Micro party having a legitimate
- * business need for the information contained herein, or (ii) a non-Tsing Micro
- * party having a legitimate business need for the information contained herein.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- */
-
-#ifndef __HOST_RUNTIME_INTERFACE_H__
-#define __HOST_RUNTIME_INTERFACE_H__
-
-#include
-#include
-
-#include "hrt_common.h"
-
-/*
- * The interfaces below depend on the lifetime of the Runtime instance: call
- * them after TsmInitRuntime and before TsmDeInitRuntime.
- */
-TSM_RETCODE TsmInitRuntime(void);
-TSM_RETCODE TsmDeInitRuntime(void);
-TSM_RETCODE TsmDeInitRuntimeLegacy(void);
-TSM_RETCODE TsmSetDevice(uint32_t first_phy_id, uint32_t card_x,
-                         uint32_t card_y, std::vector &devs);
-TSM_RETCODE TsmSetDeviceOld(
-    uint32_t chip_id,
-    TsmDevice *dev); /* Transitional interface provided for MLIR; other
-                        components must not call it */
-TSM_RETCODE TsmDeviceMalloc(TsmDevice *dev, TsmDevicePtr &ptr, uint64_t size);
-TSM_RETCODE TsmDeviceMemset(TsmDevicePtr &ptr, uint32_t ch, uint64_t size);
-TSM_RETCODE TsmDeviceFree(TsmDevicePtr ptr);
-TSM_RETCODE TsmDeviceSynchronize(TsmDevice *dev);
-TSM_RETCODE TsmInitDevice(TsmDevice *dev);
-TSM_RETCODE TsmCompile(std::vector devs, TsmModel &kmodel,
-                       std::string option, CompileOption compl_op);
-TSM_RETCODE TsmCompileMultiGraph(std::vector devs,
-                                 TsmModel &kmodel, std::string option,
-                                 CompileOption compl_op);
-TSM_RETCODE TsmLaunch(TsmDevice *dev, TsmModel &kmodel);
-TSM_RETCODE TsmLoadKernel(TsmDevice *dev, std::vector &kmodel_vec,
-                          char *module_symbol);
-TSM_RETCODE TsmUnloadKernel(TsmDevice *dev,
-                            std::vector &kmodel_vec);
-TSM_RETCODE TsmRun(TsmDevice *dev, TsmDevicePtr bootpm_dev);
-TSM_RETCODE TsmAsyncRun(tsmStream_t stream, TsmDevice *dev,
-                        TsmDevicePtr bootpm_dev);
-TSM_RETCODE TsmSetTerminate(TsmDevice *dev, tsmStream_t stream = nullptr);
-TSM_RETCODE TsmGetDeviceInfo(TsmDeviceInfo *info);
-TSM_RETCODE TsmTerminate(TsmDevice *dev, TsmDevicePtr bootpm_dev);
-TSM_RETCODE TsmMemcpyH2D(TsmDevicePtr dst, const void *src,
-                         uint64_t byte_count);
-TSM_RETCODE TsmMemcpyD2H(const void *dst, TsmDevicePtr src,
-                         uint64_t byte_count);
-TSM_RETCODE TsmMemcpyOffsetH2D(TsmDevicePtr dst, const void *src,
-                               uint64_t offset, uint64_t byte_count);
-TSM_RETCODE TsmMemcpyOffsetD2H(const void *dst, TsmDevicePtr src,
-                               uint64_t offset, uint64_t byte_count);
-TSM_RETCODE TsmMemcpyD2D(const void *dst, TsmDevice *dst_dev, const void *src,
-                         TsmDevice *src_dev, uint64_t byte_count);
-TSM_RETCODE TsmSend(const void *sendbuff, size_t count,
-                    txcclDataType_t datatype, TsmDevice *dev, int peer,
-                    txcclComm_t comm, tsmStream_t stream);
-TSM_RETCODE TsmRecv(void *recvbuff, size_t count, txcclDataType_t datatype,
-                    TsmDevice *dev, int peer, txcclComm_t comm,
-                    tsmStream_t stream);
-TSM_RETCODE TsmResetDevice(TsmDevice *dev);
-TSM_RETCODE TsmReleaseDevice(TsmDevice *dev);
-TSM_RETCODE TsmMemGetInfo(TsmDevicePtr ptr, uint32_t &card_id, uint64_t &addr,
-                          uint64_t &size);
-TSM_RETCODE TsmEventCreate(tsmEvent_t *pEvent);
-TSM_RETCODE TsmEventDestroy(tsmEvent_t event);
-TSM_RETCODE TsmEventRecord(tsmEvent_t event, tsmStream_t stream);
-TSM_RETCODE TsmEventWait(tsmEvent_t event, tsmStream_t stream);
-TSM_RETCODE TsmStreamCreate(tsmStream_t *pStream, TsmDevice *dev);
-TSM_RETCODE TsmStreamSynchronize(tsmStream_t stream);
-TSM_RETCODE TsmStreamDestroy(tsmStream_t stream);
-TSM_RETCODE TsmDeviceSerialize(const TsmDevice *const &dev, void *&buffer,
-                               size_t &size);
-TSM_RETCODE TsmDeviceDeSerialize(TsmDevice *&dev, const void *const &buffer);
-TSM_RETCODE TsmSetMonitorInfo(TsmDevice *dev);
-TSM_RETCODE TsmProcessProfData(TsmDevice *dev, TsmProfAction prof_action,
-                               uint16_t prof_type);
-TSM_RETCODE TsmHostH2D(TsmDevice *dev, uint64_t input_host_addr,
                       uint64_t input_size, int32_t index);
-TSM_RETCODE TsmHostFlush(TsmDevice *dev, uint64_t boot_param_ptr,
-                         uint8_t *host_buffer, size_t size);
-TSM_RETCODE TsmSetRankSize(uint32_t x_size, uint32_t y_size);
-TSM_RETCODE TsmSetRankId(uint32_t x, uint32_t y);
-TSM_RETCODE TsmGetPhyRankId(uint32_t *x, uint32_t *y);
-
-/*
- * The interfaces below are stateless: they do not depend on a Runtime
- * instance and can be used on their own.
- */
-TSM_RETCODE TsmGetDeviceNum(uint32_t &dev_num);
-
-/*
- * To keep the host log format uniform, the Runtime provides a unified logging
- * interface, which each component uses as follows:
- * #define rt_log(level, format, ...) tsm_log(__FILE__, __func__, __LINE__,
- *     TSM_RUNTIME, level, format, ##__VA_ARGS__)
- *
- * void func() {
- *   rt_log(LOG_DEBUG, "....\n");
- *   rt_log(LOG_INFO, "....\n");
- *   rt_log(LOG_WARNING, "....\n");
- *   rt_log(LOG_ERROR, "....\n");
- * }
- * The default log level is INFO; set HOST_LOG_LEVEL to change it (INFO and
- * DEBUG are the usual choices). Note: rt_log is a name customized per
- * component and must not be duplicated. TSM_RUNTIME is the module ID; each
- * module finds its own macro in hrt_common.h, and a module without one can
- * ask the runtime team to add it.
- */
-void tsm_log(const char *file_name, const char *func_name, uint32_t line_number,
-             TsmModuleType module_type, HostLogLevel level, const char *format,
-             ...);
-#endif
diff --git a/third_party/tsingmicro/crt/include/Tx81/tx81.h b/third_party/tsingmicro/crt/include/Tx81/tx81.h
index 6176349d2..b0af0b73d 100644
--- a/third_party/tsingmicro/crt/include/Tx81/tx81.h
+++ b/third_party/tsingmicro/crt/include/Tx81/tx81.h
@@ -9,6 +9,7 @@
 #include "instr_adapter.h"
 #include "instr_def.h"
+#include "lib_log.h"
 #include
 #include
 #include
@@ -19,4 +20,14 @@ enum MemorySpace : int32_t {
   DDR = 2,
 };
+// Neural engine activation mode
+enum ActFuncMode : int32_t {
+  None = 0,
+  ENRelu = 1,
+  ENLeakRelu = 2,
+};
+
+inline uint64_t spm_print_offset(uint64_t addr) {
+  return (uint64_t)addr + 0x030400000;
+}
 #endif // CRT_TARGET_TX81_H
diff --git a/third_party/tsingmicro/crt/lib/Tx81/abs.c b/third_party/tsingmicro/crt/lib/Tx81/abs.c
new file mode 100644
index 000000000..3b79549b1
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/abs.c
@@ -0,0 +1,32 @@
+//===------------------------- abs.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::AbsVVOp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __AbsVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmArith *cmd = TsmNewArith();
+  TsmArithInstr inst = {I_CGRA,
+                        {
+                            0,
+                        },
+                        {
+                            0,
+                        }};
+
+  cmd->AbsVV(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteArith(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/arith.c b/third_party/tsingmicro/crt/lib/Tx81/arith.c
index 9eb7ed3b8..a0040e82e 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/arith.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/arith.c
@@ -186,3 +186,47 @@ void __DivVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count,
   // Destroy the command buffer.
   TsmDeleteArith(cmd);
 }
+
+void __MaxVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count,
+             RND_MODE reserved, uint16_t fmt) {
+  // Create command buffer.
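+  // Every wrapper in this runtime library follows the same four-step pattern:
+  // create a Tsm* command object, zero-initialize the instruction's ctrl and
+  // param register blocks, let the command object fill them in, then dispatch
+  // the instruction with TsmExecute and destroy the command object.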
+ TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->MaxVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + reserved, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} + +void __MinVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + RND_MODE reserved, uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->MinVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + reserved, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteArith(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/count.c b/third_party/tsingmicro/crt/lib/Tx81/count.c index 855ed95e7..9f37bdaa0 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/count.c +++ b/third_party/tsingmicro/crt/lib/Tx81/count.c @@ -11,8 +11,7 @@ #include "tx81.h" -void __Count(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt, - uint64_t *p_wb_data0, uint64_t *p_wb_data1) { +void __Count(uint64_t *src, uint32_t elem_count, uint16_t fmt) { // Create command buffer. TsmPeripheral *cmd = TsmNewPeripheral(); TsmPeripheralInstr inst = {I_CGRA, @@ -24,8 +23,7 @@ void __Count(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt, }}; ; - cmd->Count(&inst, (uint64_t)src, (uint64_t)dst, elem_count, (Data_Format)fmt, - p_wb_data0, p_wb_data1); + cmd->Count(&inst, (uint64_t)src, elem_count, (Data_Format)fmt); // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/gemm.c b/third_party/tsingmicro/crt/lib/Tx81/gemm.c index a0f14d83a..ff0c5114f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/gemm.c +++ b/third_party/tsingmicro/crt/lib/Tx81/gemm.c @@ -12,12 +12,12 @@ #include "tx81.h" // The arguments list is aligned with TsmConv in Tx81Ops.td -void __Gemm(int64_t *srcA, int64_t *srcB, int64_t *srcBias, int64_t *zeros, - int64_t *dims, bool enPsum, int64_t *psum, bool enTransA, +void __Gemm(int64_t *srcA, int64_t *srcB, int64_t *srcBias, int64_t *dst, + int32_t *dims, bool enPsum, int64_t *psum, bool enTransA, bool enTransB, int64_t batchSizeA, int64_t batchSizeB, - bool enLeakyRelu, bool enBias, bool enNegScale, int64_t *negScale, - bool enPosScale, int64_t *posScale, int64_t srcFmt, int64_t dstFmt, - int64_t *dst) { + int32_t reluMode, bool enBias, bool enNegScale, int64_t *negScale, + bool enPosScale, int64_t *posScale, int64_t srcFmt, + int64_t dstFmt) { // Create gemm command buffer. 
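+  // reluMode carries an ActFuncMode value from tx81.h (None = 0, ENRelu = 1,
+  // ENLeakRelu = 2). It replaces the old enLeakyRelu flag so that "no
+  // activation" is also expressible; the switch below configures the neural
+  // engine accordingly.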
TsmGemm *gemm = TsmNewGemm(); TsmNeInstr inst = {I_NEUR, @@ -40,10 +40,16 @@ void __Gemm(int64_t *srcA, int64_t *srcB, int64_t *srcBias, int64_t *zeros, gemm->AddBias(&inst, enBias, (uint64_t)srcBias); gemm->SetNegativeAxisScale(&inst, enNegScale, (uint64_t)negScale); gemm->SetPositiveAxisScale(&inst, enPosScale, (uint64_t)posScale); - if (enLeakyRelu) - gemm->EnableLeakyRelu(&inst); - else + switch (reluMode) { + case ENRelu: gemm->EnableRelu(&inst); + break; + case ENLeakRelu: + gemm->EnableLeakyRelu(&inst); + break; + default: + break; + } // Dispatch the command to accelerator TsmExecute(&inst); diff --git a/third_party/tsingmicro/crt/lib/Tx81/logic.c b/third_party/tsingmicro/crt/lib/Tx81/logic.c new file mode 100644 index 000000000..62e36b694 --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/logic.c @@ -0,0 +1,78 @@ +//===------------------------ logic.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::LogicOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __AndVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmLogic *cmd = TsmNewLogic(); + TsmLogicInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->AndVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteLogic(cmd); +} + +void __OrVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmLogic *cmd = TsmNewLogic(); + TsmLogicInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->OrVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteLogic(cmd); +} + +void __XorVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmLogic *cmd = TsmNewLogic(); + TsmLogicInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->XorVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+  TsmDeleteLogic(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/memset.c b/third_party/tsingmicro/crt/lib/Tx81/memset.c
index 86edde788..f21651138 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/memset.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/memset.c
@@ -24,7 +24,19 @@ void __Memset(uint64_t *dst, uint32_t value, uint32_t elem_count, uint32_t s0,
         0,
     }};
-  St_StrideIteration si = {s0, i0, s1, i1, s2, i2};
+  // TODO: Use the real strides and iterations; for now fold all iterations into elem_count.
+  int stride0 = 0;
+  int stride1 = 0;
+  int stride2 = 0;
+
+  int iteration0 = 1;
+  int iteration1 = 1;
+  int iteration2 = 1;
+
+  elem_count *= i0 * i1 * i2;
+
+  St_StrideIteration si = {stride0, iteration0, stride1,
+                           iteration1, stride2, iteration2};
   cmd->Memset(&inst, (uint64_t)dst, value, elem_count, &si, (Data_Format)fmt);
   // Dispatch the command to accelerator
diff --git a/third_party/tsingmicro/crt/lib/Tx81/rdma.c b/third_party/tsingmicro/crt/lib/Tx81/rdma.c
index a72f964c2..f000df052 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/rdma.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/rdma.c
@@ -15,6 +15,10 @@ void __Rdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w,
             int shape_c, int stride_n, int stride_h, int stride_w,
             uint32_t fmt) {
+  // With dynamic shapes the kernel implementation may produce a shape of 0; skip the transfer.
+  if (shape_n == 0 || shape_h == 0 || shape_w == 0 || shape_c == 0)
+    return;
+
   // Create rdma command buffer.
   TsmRdma *rdma = TsmNewRdma();
   TsmRdmaInstr inst = {I_RDMA,
                       {
                           0,
                       },
                       {
                           0,
                       }};
 
   rdma->AddSrcDst(&inst, (uint64_t)src, (uint64_t)dst, (Data_Format)fmt);
-
   rdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h,
                               shape_h, stride_n, shape_n);
-  // rdma->Rdma1d(&inst, (uint64_t)src, (uint64_t)dst, shape_c,
-  // (Data_Format)fmt);
-
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/crt/lib/Tx81/relation.c b/third_party/tsingmicro/crt/lib/Tx81/relation.c
new file mode 100644
index 000000000..10069deb6
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/relation.c
@@ -0,0 +1,144 @@
+//===------------------------ relation.c-----------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::RelationOp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __BoolEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst,
+                   uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmRelation *cmd = TsmNewRelation();
+  TsmRelationInstr inst = {I_CGRA,
+                           {
+                               0,
+                           },
+                           {
+                               0,
+                           }};
+
+  cmd->BoolEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst,
+                   elem_count, (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteRelation(cmd);
+}
+
+void __BoolUnEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst,
+                     uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
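+  // The Bool* wrappers correspond to the OP_FUNC_CGRATensor_RelaOp_bV_*
+  // opcodes in instr_def.h; the bV variants write their result as a bool
+  // vector (Fmt_BOOL, 1/8 byte per element) instead of in the source format.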
+ TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolUnEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolGreaterEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolGreaterEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolGreaterVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolGreaterVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolLessEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolLessEqualVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} + +void __BoolLessThenVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, + uint32_t elem_count, uint16_t fmt) { + // Create command buffer. + TsmRelation *cmd = TsmNewRelation(); + TsmRelationInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->BoolLessThenVV(&inst, (uint64_t)src0, (uint64_t)src1, (uint64_t)dst, + elem_count, (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. + TsmDeleteRelation(cmd); +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c b/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c new file mode 100644 index 000000000..d0468966d --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c @@ -0,0 +1,34 @@ +//===------------------------ rsqrt.c -------------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::RsqrtVVOp see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +#include "tx81.h" + +void __RsqrtVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, + uint16_t fmt) { + // Create command buffer. + TsmArith *cmd = TsmNewArith(); + TsmArithInstr inst = {I_CGRA, + { + 0, + }, + { + 0, + }}; + + cmd->RsqrtVV(&inst, (uint64_t)src, (uint64_t)dst, elem_count, + (Data_Format)fmt); + + // Dispatch the command to accelerator + TsmExecute(&inst); + + // Destroy the command buffer. 
+  TsmDeleteArith(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/sqrt.c b/third_party/tsingmicro/crt/lib/Tx81/sqrt.c
new file mode 100644
index 000000000..5701513cb
--- /dev/null
+++ b/third_party/tsingmicro/crt/lib/Tx81/sqrt.c
@@ -0,0 +1,33 @@
+//===------------------------ sqrt.c --------------------------------------===//
+//
+// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd
+// All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+// Runtime API of MLIR operation tx::SqrtVVOp see Tx81Ops.td for detail.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tx81.h"
+
+void __SqrtVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) {
+  // Create command buffer.
+  TsmArith *cmd = TsmNewArith();
+  TsmArithInstr inst = {I_CGRA,
+                        {
+                            0,
+                        },
+                        {
+                            0,
+                        }};
+
+  cmd->SqrtVV(&inst, (uint64_t)src, (uint64_t)dst, elem_count,
+              (Data_Format)fmt);
+
+  // Dispatch the command to accelerator
+  TsmExecute(&inst);
+
+  // Destroy the command buffer.
+  TsmDeleteArith(cmd);
+}
diff --git a/third_party/tsingmicro/crt/lib/Tx81/wdma.c b/third_party/tsingmicro/crt/lib/Tx81/wdma.c
index 1fee20152..93bfe6e89 100644
--- a/third_party/tsingmicro/crt/lib/Tx81/wdma.c
+++ b/third_party/tsingmicro/crt/lib/Tx81/wdma.c
@@ -15,6 +15,11 @@ void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w,
             int shape_c, int stride_n, int stride_h, int stride_w,
             uint32_t fmt) {
+
+  // With dynamic shapes the kernel implementation may produce a shape of 0; skip the transfer.
+  if (shape_n == 0 || shape_h == 0 || shape_w == 0 || shape_c == 0)
+    return;
+
   // Create wdma command buffer.
   TsmWdma *wdma = TsmNewWdma();
   TsmWdmaInstr inst = {I_WDMA,
@@ -30,9 +35,6 @@ void __Wdma(uint64_t *src, uint64_t *dst, int shape_n, int shape_h, int shape_w,
   wdma->ConfigStrideIteration(&inst, shape_c, stride_w, shape_w, stride_h,
                               shape_h, stride_n, shape_n);
-  // wdma->Wdma1d(&inst, (uint64_t)src, (uint64_t)dst, shape_c,
-  // (Data_Format)fmt);
-
   // Dispatch the command to accelerator
   TsmExecute(&inst);
diff --git a/third_party/tsingmicro/examples/bare_matmul.py b/third_party/tsingmicro/examples/bare_matmul.py
new file mode 100644
index 000000000..84b9c9a87
--- /dev/null
+++ b/third_party/tsingmicro/examples/bare_matmul.py
@@ -0,0 +1,52 @@
+# This benchmark multiplies square matrices with the maximum block size
+# to check the performance of the tl.dot operation.
+
+import torch
+import triton
+import triton.language as tl
+import benchmark
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+
+@triton.jit
+def bare_matmul(X, Y, Z, M, N, K, BLOCK_SIZE: tl.constexpr):
+    pid_x = tl.program_id(0)  # block row id
+    pid_y = tl.program_id(1)  # block column id
+
+    offs_x = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_y = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    x = tl.load(X + offs_x[:, None] * K + offs_y[None, :])
+    y = tl.load(Y + offs_x[:, None] * N + offs_y[None, :])
+
+    z = tl.dot(x, y)
+
+    tl.store(Z + offs_x[:, None] * N + offs_y[None, :], z)
+
+
+# @benchmark.measure()
+def bench_matmul(N, provider):
+    device = 'cpu'
+    dtype = torch.float32
+    a = torch.randint(0, 10, (N, N), dtype=torch.int32).to(dtype)
+    b = torch.randint(0, 10, (N, N), dtype=torch.int32).to(dtype)
+    # a = torch.randn((N, N), device=device, dtype=dtype)
+    # b = torch.randn((N, N), device=device, dtype=dtype)
+    c = torch.empty((N, N), device=device, dtype=dtype)
+    if provider == 'torch' or provider == 'test':
+        c_ref = torch.matmul(a, b)
+        # print("====cref:", c_ref)
+    if provider == 'triton' or provider == 'test':
+        bare_matmul[(1, )](a, b, c, N, N, N, N)
+    if provider == 'test':
+        torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=0)
+        print("expected", c_ref)
+        print("actual", c)
+        print("======test====")
+
+
+if __name__ == "__main__":
+    # benchmark.select_cpu_backend()
+    for provider in ['test']:
+        bench_matmul(16, provider)
diff --git a/third_party/tsingmicro/examples/benchmark.py b/third_party/tsingmicro/examples/benchmark.py
new file mode 100644
index 000000000..4a2284dd2
--- /dev/null
+++ b/third_party/tsingmicro/examples/benchmark.py
@@ -0,0 +1,65 @@
+import time
+import numpy as np
+from functools import wraps
+import triton
+
+# Unfortunately, we can't use triton.testing.perf_report and triton.testing.do_bench for the CPU backend because
+# they are very specific to CUDA.
+
+
+def measure(repeats=20, percentiles=(), timers={'Wall': time.perf_counter, 'CPU': time.process_time}):
+    """
+    Decorator to benchmark a function.
+
+    Parameters:
+    - repeats (int): The number of times the function should be executed for each set of parameters.
+    - percentiles (tuple): The percentiles to compute on the execution times (e.g., (50, 90, 99)).
+    - timers (dict): A dictionary where keys are timer names (e.g., 'Wall', 'CPU') and values are timer functions
+      that measure elapsed time. By default:
+        * 'Wall': Uses time.perf_counter for high-resolution wall-clock time.
+        * 'CPU': Uses time.process_time for CPU time spent by the process.
+
+    Returns:
+    - A decorated function that prints:
+        * Average execution time.
+        * Standard deviation of the execution times.
+        * Minimum and maximum times.
+        * Computed percentiles for each timer.
+    """
+
+    def decorator(func):
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            print(f"{func.__name__}{args} {kwargs}, {repeats} times, all results in seconds")
+            times = {}
+            for t, _ in timers.items():
+                times[t] = []
+
+            for _ in range(repeats):
+                starts = {}
+                for t, f in timers.items():
+                    starts[t] = f()
+
+                result = func(*args, **kwargs)
+
+                for t, f in timers.items():
+                    times[t].append(f() - starts[t])
+
+            for t, _ in timers.items():
+                average_time = np.mean(times[t])
+                min_time = np.min(times[t])
+                max_time = np.max(times[t])
+                computed_percentiles = np.percentile(times[t], percentiles)
+                std_dev_time = np.std(times[t])
+
+                print(f"{t}: Avg={average_time:.6f}, min={min_time:.6f}, std={std_dev_time:.6f},", end=" ")
+                for p, value in zip(percentiles, computed_percentiles):
+                    print(f"{p}pp={value:.6f},", end=" ")
+                print(f"max={max_time:.6f}")
+
+            return result
+
+        return wrapper
+
+    return decorator
diff --git a/third_party/tsingmicro/examples/test_vec_add.py b/third_party/tsingmicro/examples/test_vec_add.py
new file mode 100644
index 000000000..b75f5aa42
--- /dev/null
+++ b/third_party/tsingmicro/examples/test_vec_add.py
@@ -0,0 +1,90 @@
+import torch
+
+import triton
+import triton.language as tl
+import benchmark
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+
+@triton.jit
+def add_kernel(x_ptr,  # *Pointer* to first input vector.
+               y_ptr,  # *Pointer* to second input vector.
+               output_ptr,  # *Pointer* to output vector.
+               n_elements,  # Size of the vector.
+               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+               # NOTE: `constexpr` so it can be used as a shape value.
+               ):
+    # There are multiple 'programs' processing different data. We identify which program
+    # we are here:
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
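+    # For the size-1024 test below, BLOCK_SIZE is also 1024, so the grid
+    # collapses to a single program (pid == 0) covering the whole vector.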
+    # This program will process inputs that are offset from the initial data.
+    # For instance, if you had a vector of length 256 and block_size of 64, the programs
+    # would each access the elements [0:64, 64:128, 128:192, 192:256].
+    # Note that offsets is a list of pointers:
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Create a mask to guard memory operations against out-of-bounds accesses.
+    mask = offsets < n_elements
+    # Load x and y from DRAM, masking out any extra elements in case the input is not a
+    # multiple of the block size.
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    # Write x + y back to DRAM.
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+def add(x: torch.Tensor, y: torch.Tensor):
+    # We need to preallocate the output.
+    output = torch.empty_like(x)
+    # assert x.is_cuda and y.is_cuda and output.is_cuda
+    n_elements = output.numel()
+    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
+    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
+    # In this case, we use a 1D grid where the size is the number of blocks:
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
+    # NOTE:
+    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
+    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
+    #  - Don't forget to pass meta-parameters as keyword arguments.
+    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+    # We return a handle to `output` but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
+    # running asynchronously at this point.
+    return output
+
+
+def test(device):
+    # torch.manual_seed(0)
+    size = 1024
+    x = torch.rand(size, device="cpu")
+    y = torch.rand(size, device="cpu")
+    output_torch = x + y
+    x = x.to(device)
+    y = y.to(device)
+    output_triton = add(x, y)
+    # TODO: check some conditions on the outputs; otherwise the prints below
+    # make no difference for the test.
+    print("expected", output_torch)
+    output_triton = output_triton.to("cpu")
+    print("actual", output_triton)
+    print(f"The maximum difference between torch and triton is "
+          f"{torch.max(torch.abs(output_torch - output_triton))}")
+
+
+@benchmark.measure()
+def bench_vecadd(size, provider):
+    a = torch.rand(size, device='cpu', dtype=torch.float32)
+    b = torch.rand(size, device='cpu', dtype=torch.float32)
+    if provider == 'torch':
+        a + b
+    if provider == 'triton':
+        a = a.to(DEVICE)
+        b = b.to(DEVICE)
+        add(a, b)
+
+
+if __name__ == "__main__":
+    # test(DEVICE)
+    for X in [2**i for i in range(8, 25, 1)]:
+        for provider in ['torch', 'triton']:
+            bench_vecadd(X, provider)
diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h
index 6e9b8147a..96173716c 100644
--- a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h
+++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.h
@@ -20,6 +20,9 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+// Declare spmPointer.
+extern uint64_t spmPointer; + namespace mlir { namespace triton { diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt index f8257f56b..626484155 100644 --- a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/CMakeLists.txt @@ -1,3 +1,7 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name Tx81ToLLVM) add_public_tablegen_target(Tx81ToLLVMConversionPassIncGen) + +set(LLVM_TARGET_DEFINITIONS KernelArgBufferPass.td) +mlir_tablegen(KernelArgBufferPass.h.inc -gen-pass-decls --name KernelArgBufferPass) +add_public_tablegen_target(KernelArgBufferPassIncGen) diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h index 3ee5ebdef..f4de9dcaf 100644 --- a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h @@ -20,12 +20,16 @@ namespace mlir { class ModuleOp; class Pass; +namespace triton { /// Creates a pass that transforms kernel functions by replacing multiple /// arguments with a single void* buffer argument. std::unique_ptr createKernelArgBufferPass(); -#define GEN_PASS_DECL_KERNELARGBUFFERPASS -#include "KernelArgBufferPass.h.inc" +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h.inc" + +} // namespace triton } // namespace mlir #endif // MLIR_KERNEL_ARG_BUFFER_PASS_H diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td index 0d63527ab..a47c45d07 100644 --- a/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.td @@ -25,7 +25,7 @@ def KernelArgBufferPass : Pass<"kernel-arg-buffer", "ModuleOp"> { Where the args buffer contains pointers to arg1 and arg2, followed by the scalar values size, gridX, and x. Each scalar value occupies 8 bytes in the buffer. 
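    As an illustration (a sketch only; slot order follows the original
    argument list, and every slot occupies 8 bytes as described above):

      byte offset  0: pointer to arg1
      byte offset  8: pointer to arg2
      byte offset 16: size  (scalar, widened to 8 bytes)
      byte offset 24: gridX (scalar, widened to 8 bytes)
      byte offset 32: x     (scalar, widened to 8 bytes)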
}]; - let constructor = "mlir::createKernelArgBufferPass()"; + let constructor = "mlir::triton::createKernelArgBufferPass()"; let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::func::FuncDialect"]; } diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td index 476817a79..71899adf6 100644 --- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td @@ -65,7 +65,7 @@ def RdmaOp : Tx81Op<"rdma", [ ins MemRefOrInt:$source, // The source address in DDR MemRefOrInt:$target, // The target address in SPM - Variadic:$shape, // HHWC shape + Variadic:$shape, // NHWC shape Variadic:$strides, // 3 dim strides I32Attr:$fmt ); @@ -99,7 +99,7 @@ def WdmaOp : Tx81Op<"wdma", [ ins MemRefOrInt:$source, // The source address in DDR MemRefOrInt:$target, // The target address in SPM - Variadic:$shape, // HHWC shape + Variadic:$shape, // NHWC shape Variadic:$strides, // 3 dim strides I32Attr:$fmt ); @@ -189,9 +189,9 @@ def GemmOp : Tx81Op<"gemm", []> { MemRefOrInt:$src_a, // Input matrix A addr in SPM MemRefOrInt:$src_b, // Input matrix B addr in SPM MemRefOrInt:$src_bias, // The address of bias in SPM - // Zeroes buffer which can be used to fill $dst + // Output and initial zeroes buffer // FIXME: Whether need add side effect to source operands? - Arg:$zeroes, + Arg:$dst, I32ArrayAttr:$dims, // The dimensions of M, K, N BoolAttr:$en_psum, // Enable psum? TODO: Production sum? MemRefOrInt:$psum_addr, // The address of psum in SPM, TODO: psum? @@ -199,7 +199,7 @@ def GemmOp : Tx81Op<"gemm", []> { BoolAttr:$trans_src_b, // Should matrix B be transposed I32Attr:$batch_src_a, // The batch of matrix A I32Attr:$batch_src_b, // The batch of matrix B - BoolAttr:$en_leaky_relu,// Enable LeakyRelu or normal Relu + I32Attr:$relu_mode, // Enable LeakyRelu or normal Relu or none BoolAttr:$en_bias, // Enable bias add BoolAttr:$en_neg_scale, // Enable negative axis scale MemRefOrInt:$src_neg_scale, // The address of negative scale data in SPM @@ -211,14 +211,30 @@ def GemmOp : Tx81Op<"gemm", []> { ); // Output matrix C addr in SPM - let results = (outs Variadic:$dst); + let results = (outs Variadic:$output); } // ============================================================================= // 4.10. TsmArith // ============================================================================= -def AbsVVOp : Tx81Op<"absvv", [Pure, Elementwise]> {} +class UnaryOp traits = []> : + Tx81Op { + let arguments = (ins + MemRefOrInt:$input, // Input vector address + Arg:$out, // Out vector address + MemRefOrInt:$elem_count, // Number of input elements + I16Attr:$fmt // The data format of src & dst + ); + let results = (outs Variadic:$dst); +} + +def AbsVVOp : UnaryOp<"absvv"> { + let summary = "Absolute value of input vector"; +} +def SqrtVVOp : UnaryOp<"sqrtvv", [Pure, Elementwise]> {} +def RsqrtVVOp : UnaryOp<"rsqrtvv", [Pure, Elementwise]> {} +def NegVVOp : UnaryOp<"negvv", [Pure, Elementwise]> {} def RecipVVOp : Tx81Op<"recipvv", [Pure, Elementwise]> {} def SquareVVOp : Tx81Op<"squarevv", [Pure, Elementwise]> {} @@ -241,6 +257,8 @@ def AddVVOp : BinaryVVOp<"addvv"> { def SubVVOp : BinaryVVOp<"subvv">; def MulVVOp : BinaryVVOp<"mulvv">; def DivVVOp : BinaryVVOp<"divvv">; +def MaxVVOp : BinaryVVOp<"maxvv">; +def MinVVOp : BinaryVVOp<"minvv">; class BinaryVSOp traits = []> : Tx81Op { @@ -265,38 +283,97 @@ def DivVSOp : BinaryVSOp<"divvs">; // ... 
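// For reference, the *VV/*VS arithmetic ops above share one operand
// convention: SPM addresses for the input(s) and the output, an element
// count, and round-mode/format codes. A hypothetical generic form (attribute
// names and assembly syntax here are illustrative, not the generated format;
// fmt = 5 denotes f32 as in the comments elsewhere in this backend):
//
//   %t = tx81.addvv %src0, %src1, %dst, %n
//            {round_mode = 0 : i16, fmt = 5 : i16} : (i64, i64, i64, i64) -> i64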
// =============================================================================
-// 4.13. TsmTranscendental
+// 4.11. TsmRelation
// =============================================================================

-class TranscendentalOp<string mnemonic, list<Trait> traits> :
-    Tx81Op<mnemonic, traits> {
+class BoolRelationVVOp<string mnemonic, list<Trait> traits = []> :
+    Tx81Op<mnemonic, traits> {
  let arguments = (ins
-    MemRefOrInt:$src, // Input vector address
-    I32Attr:$elem_count, // Number of input elements
-    I16Attr:$fmt // The data format of src & dst
+    MemRefOrInt:$input0, // First input vector address
+    MemRefOrInt:$input1, // Second input vector address
+    Arg:$out, // Output vector address
+    MemRefOrInt:$elem_count, // Number of input elements
+    I16Attr:$fmt // The data format of src & dst
  );
-  let results = (outs I64:$dst);
+  let results = (outs Variadic:$dst);
}

-def Log2 : TranscendentalOp<"log2", []> {
+def BoolEqualVV : BoolRelationVVOp<"boolequalvv"> {
+  let summary = "Compare the two input values; if they are equal, return true";
+}
+
+def BoolUnEqualVV : BoolRelationVVOp<"boolunequalvv"> {
+  let summary = "Compare the two input values; if they are unequal, return true";
+}
+
+def BoolGreaterEqualVV : BoolRelationVVOp<"boolgreaterequalvv"> {
+  let summary = "Compare the two input values; if src0 >= src1, return true";
+}
+
+def BoolGreaterVV : BoolRelationVVOp<"boolgreatervv"> {
+  let summary = "Compare the two input values; if src0 > src1, return true";
+}
+
+def BoolLessEqualVV : BoolRelationVVOp<"boollessequalvv"> {
+  let summary = "Compare the two input values; if src0 <= src1, return true";
+}
+
+def BoolLessThenVV : BoolRelationVVOp<"boollessthenvv"> {
+  let summary = "Compare the two input values; if src0 < src1, return true";
+}
+
+// ...
+// =============================================================================
+// 4.12. TsmLogic
+// =============================================================================
+
+class BinaryLogicVVOp<string mnemonic, list<Trait> traits = []> :
+    Tx81Op<mnemonic, traits> {
+  let arguments = (ins
+    MemRefOrInt:$input0, // First input vector address
+    MemRefOrInt:$input1, // Second input vector address
+    Arg:$out, // Output vector address
+    MemRefOrInt:$elem_count, // Number of input elements
+    I16Attr:$fmt // The data format of src & dst
+  );
+  let results = (outs Variadic:$dst);
+}
+
+def AndVV : BinaryLogicVVOp<"andvv"> {
+  let summary = "AND of the elements at the same position. A non-zero element is treated as 1.";
+}
+
+def OrVV : BinaryLogicVVOp<"orvv"> {
+  let summary = "OR of the elements at the same position. A non-zero element is treated as 1.";
+}
+
+def XorVV : BinaryLogicVVOp<"xorvv"> {
+  let summary = "XOR of the elements at the same position. A non-zero element is treated as 1.";
+}
+
+// =============================================================================
+// 4.13. 
TsmTranscendental
+// =============================================================================
+
+def Log2Op : UnaryOp<"log2", []> {
  let summary = "Logarithm base 2";
}
-def Ln : TranscendentalOp<"ln", []> {
+def LnOp : UnaryOp<"ln", []> {
  let summary = "Logarithm base e";
}
-def Pow2 : TranscendentalOp<"pow2", []> {
+def Pow2Op : UnaryOp<"pow2", []> {
  let summary = "2 ** x";
}
-def Exp : TranscendentalOp<"exp", []> {
+def ExpOp : UnaryOp<"exp", []> {
  let summary = "Exponential with high precision";
}
-def Explp : TranscendentalOp<"explp", []> {
+def ExplpOp : UnaryOp<"explp", []> {
  let summary = "Exponential with low precision";
}
-def Sin : TranscendentalOp<"sin", []> {
+def SinOp : UnaryOp<"sin", []> {
  let summary = "Sine";
}
-def Cos : TranscendentalOp<"cos", []> {
+def CosOp : UnaryOp<"cos", []> {
  let summary = "Cosine";
}
@@ -558,19 +635,33 @@ def CountOp : Tx81Op<"count", [Pure]> {
  let results = (outs MemRefOrInt:$dst);
}

-def MemsetOp : Tx81Op<"memset", []> {
+def MemsetOp : Tx81Op<"memset", [
+    AttrSizedOperandSegments,
+    PredOpTrait<"Constrain shape to 4d.",
+                CPred<"cast<MemsetOp>($_op).getShape().size() == 4">>,
+    PredOpTrait<"Constrain strides to 3d.",
+                CPred<"cast<MemsetOp>($_op).getStrides().size() == 3">>
+  ]> {
  let summary = "Write the given `value` to a range of addresses in SPM (SRAM)";

  let arguments = (
    ins
-    MemRefOrInt:$src, // SPM address to be memset
-    I32:$value, // Value to be written
-    AnySignlessIntegerOrIndex:$elem_count,
-    I32ArrayAttr:$strides,
-    I32ArrayAttr:$iterations,
+    MemRefOrInt:$src, // SPM address to be memset
+    I32:$value, // Value to be written
+    Variadic:$shape, // NHWC shape
+    Variadic:$strides, // 3 dim strides
    I16Attr:$fmt
  );

+  let builders = [
+    OpBuilder<(ins "MemRefType":$resultType,
+      "Value":$source, "Value":$value,
+      "ArrayRef":$shape,
+      "ArrayRef":$strides,
+      "IntegerAttr":$fmt
+    )>
+  ];
+
  // The address updated by memset in SPM
  let results = (outs MemRefOrInt:$dst);
}
@@ -634,8 +725,7 @@ def Lut16Op : Tx81Op<"lut16", []> {
  let summary = "16-bit lookup table";

  let arguments = (ins
-    // FIXME: AnyVector is not defined
-    // AnyVector:$src, // Vector offset with respect to LUT
+    MemRefOrInt:$src, // Vector offset with respect to LUT
    UI64:$lut16,
    I32Attr:$src_elem_count, // Number of elements in vector offset
    I32Attr:$lut_elem_count // Number of elements in LUT
@@ -648,8 +738,7 @@ def Lut32Op : Tx81Op<"lut32", []> {
  let summary = "32-bit lookup table";

  let arguments = (ins
-    // FIXME: AnyVector is not defined
-    // AnyVector:$src, // Vector offset with respect to LUT
+    MemRefOrInt:$src, // Vector offset with respect to LUT
    UI64:$lut32,
    I32Attr:$src_elem_count, // Number of elements in vector offset
    I32Attr:$lut_elem_count // Number of elements in LUT
diff --git a/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
index a5efbc096..aa4904bdf 100644
--- a/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
+++ b/third_party/tsingmicro/lib/Analysis/OpFoldResultUtils.cpp
@@ -16,13 +16,8 @@
 namespace mlir {
 
 std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
-  // Check if ofr is an Attribute
-  if (auto attr = dyn_cast<Attribute>(ofr)) {
-    // Check if it's specifically an IntegerAttr
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-      return intAttr.getInt();
-    }
-  }
+  if (isa<Attribute>(ofr) && isa<IntegerAttr>(cast<Attribute>(ofr)))
+    return dyn_cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
 
   return std::nullopt;
 }
diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp
index 356d16cf6..3b51faf43 
100644
--- a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp
+++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp
@@ -12,7 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "tsingmicro-tx81/Conversion/MKToTx81/MKToTx81.h"
-#include "Tx81/instr_def.h"
+#include "Tx81/tx81.h"
 #include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -69,32 +69,23 @@ class MKToTx81TypeConverter : public TypeConverter {
 
 // Get format code for tensor element type
 // This maps MLIR types to Tx81 format codes
-Data_Format getFormatCode(Type type) {
-  if (type.isF32()) {
+Data_Format getFormatCode(MemRefType type) {
+  auto elemType = type.getElementType();
+  if (elemType.isF32()) {
     return Fmt_FP32;
-  } else if (type.isF16()) {
+  } else if (elemType.isF16()) {
     return Fmt_FP16;
-  } else if (type.isBF16()) {
+  } else if (elemType.isBF16()) {
     return Fmt_BF16;
-  } else if (type.isInteger(8)) {
+  } else if (elemType.isInteger(8)) {
     return Fmt_INT8;
+  } else {
+    llvm_unreachable("Tx81 does not support this element type\n");
   }
-  // Default to F32 format
   return Fmt_FP32;
 }
 
-// Get element count from shape
-int32_t getElementCount(ArrayRef shape) {
-  int32_t elementCount = 1;
-  for (auto dim : shape) {
-    if (dim > 0) {
-      elementCount *= dim;
-    }
-  }
-  return elementCount;
-}
-
 // Helper function to extract shape from tensor type
 SmallVector<int32_t> getShapeFromTensorType(TensorType type) {
   SmallVector<int32_t> shape;
@@ -106,25 +97,39 @@
 // Helper function to extract dimensions from memref or tensor type
 SmallVector<int32_t> getDimsFromType(Type type) {
   SmallVector<int32_t> dims;
-  if (auto memrefType = dyn_cast<MemRefType>(type)) {
+  if (auto memrefType = mlir::dyn_cast<MemRefType>(type)) {
     for (auto dim : memrefType.getShape())
       dims.push_back(static_cast<int32_t>(dim));
-  } else if (auto tensorType = dyn_cast<TensorType>(type)) {
+  } else if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
     for (auto dim : tensorType.getShape())
       dims.push_back(static_cast<int32_t>(dim));
   }
   return dims;
 }
 
-Value createAddressFromMemref(ConversionPatternRewriter &rewriter, Location loc,
-                              Value memref) {
+static uint64_t getElemByte(Type type) {
+  static DataLayout dataLayout;
+  auto typeSize = dataLayout.getTypeSize(type);
+  if (!typeSize.isFixed()) {
+    llvm::llvm_unreachable_internal("All element types should have a fixed size.");
+  }
+  return typeSize.getFixedValue();
+}
+
+static Value createAddressFromMemref(ConversionPatternRewriter &rewriter,
+                                     Location loc, Value memref) {
   auto stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
   Value indexBasePtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer());
+  auto elemType = mlir::cast<MemRefType>(memref.getType()).getElementType();
+  Value elemByte =
+      rewriter.create<arith::ConstantIndexOp>(loc, getElemByte(elemType));
  Value offset = stridedMetadata.getOffset();
+  Value byteOffset =
+      rewriter.create<arith::MulIOp>(loc, offset.getType(), offset, elemByte);
  Value offsetPtr = rewriter.create<arith::AddIOp>(loc, indexBasePtr.getType(),
-                                                   indexBasePtr, offset);
+                                                   indexBasePtr, byteOffset);
  Value i64SPMPtr = rewriter.create<arith::IndexCastOp>(
      loc, rewriter.getI64Type(), offsetPtr);
  return i64SPMPtr;
}
@@ -137,9 +142,14 @@ createMetadata(ConversionPatternRewriter &rewriter, Location loc,
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, operand);
  Value indexBasePtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, rewriter.getIndexType(), stridedMetadata.getBaseBuffer());
+  auto elemType = mlir::cast<MemRefType>(operand.getType()).getElementType();
+  Value elemByte =
+      rewriter.create<arith::ConstantIndexOp>(loc, 
getElemByte(elemType));
  Value offset = stridedMetadata.getOffset();
+  Value byteOffset =
+      rewriter.create<arith::MulIOp>(loc, offset.getType(), offset, elemByte);
  Value offsetPtr = rewriter.create<arith::AddIOp>(loc, indexBasePtr.getType(),
-                                                   indexBasePtr, offset);
+                                                   indexBasePtr, byteOffset);
  Value i64SPMPtr = rewriter.create<arith::IndexCastOp>(
      loc, rewriter.getI64Type(), offsetPtr);

@@ -147,8 +157,8 @@
  return {i64SPMPtr, stridedMetadata.getSizes(), stridedMetadata.getStrides()};
}

-static SmallVector<Value> padSizesToNHWC(ConversionPatternRewriter &rewriter,
-                                         Location loc, ValueRange sizes) {
+static SmallVector<Value> padSizesToNHWC(ConversionPatternRewriter &rewriter,
+                                         Location loc, ValueRange sizes) {
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  int numPad = 4 - sizes.size();
  SmallVector<Value> nhwcShape;
@@ -162,8 +172,9 @@
}

// The last stride is always 1, skip it, nhwcStrides.size() will be 3.
-static SmallVector<Value> padStridesToNHWC(ConversionPatternRewriter &rewriter,
-                                           Location loc, ValueRange strides) {
+static SmallVector<Value>
+padStridesToNHWC(ConversionPatternRewriter &rewriter, Location loc,
+                 ValueRange strides) {
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  int numPad = 4 - strides.size();
  SmallVector<Value> nhwcStrides;
@@ -179,6 +190,11 @@
 
 static Value calculateElemCount(ConversionPatternRewriter &rewriter,
                                 Location loc, ValueRange sizes) {
+  // For scalar data, sizes is empty; return 1
+  if (sizes.empty()) {
+    return rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  }
+
  Value elemCount = sizes[0];
  for (int i = 1; i < sizes.size(); i++) {
    elemCount = rewriter.create<arith::MulIOp>(loc, elemCount.getType(),
@@ -194,11 +210,79 @@ template <typename T> llvm::SmallVector<Operation *> getRegionOps(T linalgOp) {
                      [](Operation &op) { return &op; });
}

+// Convert an integer type to a float type for the CGRA instruction.
+// Returns the converted float format code.
+// TODO: Directly convert the memref type?
+Data_Format insertConvertTypeOp(Value valuePtr, MemRefType valueType,
+                                Value elemCount,
+                                ConversionPatternRewriter &rewriter,
+                                Location loc) {
+
+  // TODO: Other integer types. May need realloc the memory
+  auto elemType = valueType.getElementType();
+
+  if (!isa<IntegerType>(elemType))
+    return getFormatCode(valueType);
+
+  Data_Format fmt = Fmt_FP32;
+  // Get the bit width from the element type
+  auto bitWidth = elemType.getIntOrFloatBitWidth();
+  switch (bitWidth) {
+  case 16: { // 16 bit integer
+    rewriter.create(loc, rewriter.getI64Type(), valuePtr,
+                    valuePtr, elemCount);
+    fmt = Fmt_FP16;
+    break;
+  }
+  case 32: { // 32 bit integer
+    rewriter.create(loc, rewriter.getI64Type(), valuePtr,
+                    valuePtr, elemCount,
+                    rewriter.getI16IntegerAttr(0));
+    break;
+  }
+  default: {
+    llvm_unreachable("Unsupported integer type\n");
+  }
+  }
+  return fmt;
+}
+
+// Restore the float type back to the integer type for the CGRA instruction
+Value insertRestoreTypeOp(Value valuePtr, MemRefType valueType, Value elemCount,
+                          ConversionPatternRewriter &rewriter, Location loc) {
+  // TODO: Other integer types. 
May need realloc the memory + auto elemType = valueType.getElementType(); + auto newValue = valuePtr; + if (!isa(elemType)) + return newValue; + + // Get the bit width from the element type + auto bitWidth = elemType.getIntOrFloatBitWidth(); + switch (bitWidth) { + case 16: { // 16 bit integer + newValue = rewriter.create( + loc, rewriter.getI64Type(), valuePtr, valuePtr, elemCount, + rewriter.getI16IntegerAttr(0)); + break; + } + case 32: { // 32 bit integer + newValue = rewriter.create( + loc, rewriter.getI64Type(), valuePtr, valuePtr, elemCount, + rewriter.getI16IntegerAttr(0)); + break; + } + default: { + llvm_unreachable("Unsupported integer type\n"); + } + } + return newValue; +} + class MemoryCopyConvertPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - // Workaround: Avoid analyzing control flow as much as possible。 + // Workaround: Avoid analyzing control flow as much as possible bool isOperandMemorySpaceSPM(Value operand) const { while (auto op = operand.getDefiningOp()) { @@ -212,8 +296,10 @@ class MemoryCopyConvertPattern : public OpConversionPattern { LogicalResult matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - bool isSrcSPM = isOperandMemorySpaceSPM(adaptor.getSource()); - bool isDstSPM = isOperandMemorySpaceSPM(adaptor.getTarget()); + assert(op->hasAttr("srcSpm") && op->hasAttr("dstSpm") && + "Can't get memory space attribute\n"); + bool isSrcSPM = mlir::cast(op->getAttr("srcSpm")).getInt(); + bool isDstSPM = mlir::cast(op->getAttr("dstSpm")).getInt(); // DDR to DDR if (!isSrcSPM && !isDstSPM) @@ -237,10 +323,9 @@ class MemoryCopyConvertPattern : public OpConversionPattern { rewriter.create( op->getLoc(), rewriter.getI64Type(), srcPtr, constValue, dstPtr, - elemCount, // Element count - rewriter.getI16IntegerAttr(0), // Round mode - rewriter.getI16IntegerAttr(getFormatCode( - inputType)) // Format (5 = f32, assuming f32 for now) + elemCount, // Element count + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(getFormatCode(inputType)) // Format ); } else if (isDstSPM) { auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), srcSizes); @@ -248,10 +333,9 @@ class MemoryCopyConvertPattern : public OpConversionPattern { auto rdmaOp = rewriter.create( op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr, - nhwcShape, // NHWC shape - nhwcStrides, // NHWC stride - rewriter.getI32IntegerAttr(getFormatCode( - inputType)) // Format (5 = f32, assuming f32 for now) + nhwcShape, // NHWC shape + nhwcStrides, // NHWC stride + rewriter.getI32IntegerAttr(getFormatCode(inputType)) // Format ); } else { auto nhwcShape = padSizesToNHWC(rewriter, op->getLoc(), dstSizes); @@ -259,10 +343,9 @@ class MemoryCopyConvertPattern : public OpConversionPattern { auto wdmaOp = rewriter.create( op.getLoc(), rewriter.getI64Type(), srcPtr, dstPtr, - nhwcShape, // NHWC shape - nhwcStrides, // NHWC stride - rewriter.getI32IntegerAttr(getFormatCode( - inputType)) // Format (5 = f32, assuming f32 for now) + nhwcShape, // NHWC shape + nhwcStrides, // NHWC stride + rewriter.getI32IntegerAttr(getFormatCode(inputType)) // Format ); } @@ -285,30 +368,50 @@ class LinalgFillOpConversion : public OpConversionPattern { if (op.getOutputs().size() != 1) return rewriter.notifyMatchFailure(op, "Only support single output\n"); - // Convert the fill value to int64 - if (fillValue.getType().isF32()) { - // If it's a float constant, bitcast it to int - fillValue = rewriter.create( + auto [srcPtr, 
srcSizes, srcStrides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+    auto inputType = op.getInputs()[0].getType();
+    auto bitWidth = op.getInputs()[0].getType().getIntOrFloatBitWidth();
+    assert((bitWidth == 16 || bitWidth == 32) &&
+           "Only supports 16/32-bit fill values\n");
+
+    // AddVS needs the same fmt as its input and only supports float types
+    Data_Format fmt = bitWidth == 16 ? Fmt_FP16 : Fmt_FP32;
+
+    if (inputType.isInteger()) {
+      auto floatType =
+          bitWidth == 16 ? rewriter.getF16Type() : rewriter.getF32Type();
+      fillValue =
+          rewriter.create<arith::SIToFPOp>(op.getLoc(), floatType, fillValue);
+    }
+
+    auto bitcastType =
+        bitWidth == 16 ? rewriter.getI16Type() : rewriter.getI32Type();
+    fillValue =
+        rewriter.create<arith::BitcastOp>(op.getLoc(), bitcastType, fillValue);
+
+    if (bitWidth == 16) {
+      fillValue = rewriter.create<arith::ExtUIOp>(
          op.getLoc(), rewriter.getI32Type(), fillValue);
-    } else if (fillValue.getType().isF16()) {
-      auto extf = rewriter.create<arith::ExtFOp>(
-          op.getLoc(), rewriter.getF32Type(), fillValue);
-      fillValue = rewriter.create<arith::BitcastOp>(
-          op.getLoc(), rewriter.getI32Type(), extf);
    }

-    auto [srcPtr, srcSizes, srcStrides] =
-        createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]);
+    // TODO: For scalar data, instead of a function call, we should convert
+    // linalg.fill to memref.store directly to get better performance.
+
+    // Use xor + addvs to simulate a memset. Only fp32 and fp16 are
+    // supported:
+    // 1. xor srcPtr with itself to get zero
+    // 2. addvs srcPtr with value to get the fill value
    auto elemCount = calculateElemCount(rewriter, op->getLoc(), srcSizes);

-    // Create a MemsetOp to fill the SPM buffer
-    // TODO: Support format code for different element types
-    auto memsetOp = rewriter.create<tx::MemsetOp>(
-        op.getLoc(), rewriter.getI64Type(), srcPtr, fillValue, elemCount,
-        rewriter.getI32ArrayAttr({}), // Strides (empty for simple fill)
-        rewriter.getI32ArrayAttr({}), // Iterations (empty for simple fill)
-        rewriter.getI16IntegerAttr(5) // Format (5 = f32, assuming f32 for now)
-    );
+    auto init =
+        rewriter.create<tx::XorVV>(op.getLoc(), rewriter.getI64Type(), srcPtr,
+                                   srcPtr, srcPtr, elemCount,
+                                   rewriter.getI16IntegerAttr(fmt));
+    auto resultOp = rewriter.create<tx::AddVSOp>(
+        op.getLoc(), rewriter.getI64Type(), srcPtr, fillValue, srcPtr,
+        elemCount,
+        rewriter.getI16IntegerAttr(0), // round_mode
+        rewriter.getI16IntegerAttr(fmt));

    rewriter.eraseOp(op);
@@ -322,6 +425,21 @@
 
 class MKDotToTx81GemmOpConversion
     : public OpConversionPattern<mlir::mk::DotOp> {
+
+  void fp32ToTF32(ConversionPatternRewriter &rewriter, Location loc,
+                  ValueRange sizes, Value spmAddr) const {
+    // Warn that the neural engine does not support fp32
+    llvm::errs()
+        << "\nNeural engine does not support FP32. Convert FP32 to TF32 for "
+           "tx.Gemm Op\n";
+    auto elemCount = calculateElemCount(rewriter, loc, sizes);
+    rewriter.create(
+        loc, rewriter.getI64Type(), spmAddr, spmAddr,
+        elemCount,                    // element_count
+        rewriter.getI16IntegerAttr(0) // round_mode
+    );
+  }
+
 public:
  using OpConversionPattern::OpConversionPattern;

@@ -329,8 +447,14 @@
  matchAndRewrite(mlir::mk::DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Extract dimensions from tensor types
-    MemRefType aTensorType = cast<MemRefType>(op.getA().getType());
-    MemRefType bTensorType = cast<MemRefType>(op.getB().getType());
+    MemRefType aTensorType = mlir::cast<MemRefType>(op.getA().getType());
+    MemRefType bTensorType = mlir::cast<MemRefType>(op.getB().getType());
+    assert(aTensorType.getElementType() == bTensorType.getElementType() &&
+           "a and b must have the same element type");
+    MemRefType zeroTensorType =
+        mlir::cast<MemRefType>(op.getZeroes().getType());
+    Data_Format srcFmt = getFormatCode(aTensorType);
+    Data_Format dstFmt = getFormatCode(zeroTensorType);

    // Get converted operands
    auto loc = op.getLoc();
@@ -347,37 +471,48 @@
    auto dims = rewriter.getI32ArrayAttr({M, K, N});

    // Get operand ptr
-    auto a = createAddressFromMemref(rewriter, loc, adaptor.getA());
-    auto b = createAddressFromMemref(rewriter, loc, adaptor.getB());
-    auto c = createAddressFromMemref(rewriter, loc, adaptor.getC());
-    auto zeros = createAddressFromMemref(rewriter, loc, adaptor.getZeroes());
+    auto [aPtr, aSizes, aStrides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getA());
+    auto [bPtr, bSizes, bStrides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getB());
+    auto [cPtr, cSizes, cStrides] =
+        createMetadata(rewriter, op->getLoc(), adaptor.getC());
+    // Assume the input types are the same. The Tx neural engine does not
+    // support fp32 inputs
+    if (aTensorType.getElementType().isF32()) {
+      srcFmt = Data_Format::Fmt_TF32;
+      fp32ToTF32(rewriter, op->getLoc(), aSizes, aPtr);
+      fp32ToTF32(rewriter, op->getLoc(), bSizes, bPtr);
+      fp32ToTF32(rewriter, op->getLoc(), cSizes, cPtr);
+    }
+
+    auto dst = createAddressFromMemref(rewriter, loc, adaptor.getZeroes());
+
+    auto zero = rewriter.create<arith::ConstantIntOp>(op.getLoc(), 0,
+                                                      rewriter.getI64Type());

    // Create GemmOp
    rewriter.create<tx::GemmOp>(
        op.getLoc(), rewriter.getI64Type(),
-        a, // src_a (Matrix A in SPM)
-        b, // src_b (Matrix B in SPM)
-        c, // src_bias (optional accumulation)
-        zeros, // zeroes,
+        aPtr, // src_a (Matrix A in SPM)
+        bPtr, // src_b (Matrix B in SPM)
+        cPtr, // src_bias (optional accumulation)
+        dst,  // dst,
        dims, // dimensions [M,K,N]
        rewriter.getBoolAttr(false), // en_psum
-        zeros, // WORKAROUND: psum_addr (using zeroes buffer)
-        rewriter.getBoolAttr(false), // trans_src_a
-        rewriter.getBoolAttr(false), // trans_src_b
-        rewriter.getI32IntegerAttr(1), // batch_src_a
-        rewriter.getI32IntegerAttr(1), // batch_src_b
-        rewriter.getBoolAttr(false), // en_leaky_relu
-        rewriter.getBoolAttr(op.getC() != nullptr), // en_bias
-        rewriter.getBoolAttr(false), // en_neg_scale
-        rewriter
-            .create<arith::ConstantIntOp>(op.getLoc(), 0, rewriter.getI64Type())
-            .getResult(), // src_neg_scale
-        rewriter.getBoolAttr(false), // en_pos_scale
-        rewriter
-            .create<arith::ConstantIntOp>(op.getLoc(), 0, rewriter.getI64Type())
-            .getResult(), // src_pos_scale
-        rewriter.getI32IntegerAttr(3), // src_fmt (3 = f32)
-        rewriter.getI32IntegerAttr(3) // dst_fmt (3 = f32)
+        dst, // WORKAROUND: psum_addr (using dst buffer)
+        rewriter.getBoolAttr(false), // trans_src_a
+        // NOTE: a (N, K) layout is treated as untransposed by the hardware
+        rewriter.getBoolAttr(true), // trans_src_b.
+        rewriter.getI32IntegerAttr(1), // batch_src_a
+        rewriter.getI32IntegerAttr(1), // batch_src_b
+        rewriter.getI32IntegerAttr(ActFuncMode::None), // relu_mode
+ rewriter.getBoolAttr(op.getC() != nullptr), // en_bias + rewriter.getBoolAttr(false), // en_neg_scale + zero, // src_neg_scale + rewriter.getBoolAttr(false), // en_pos_scale + zero, // src_pos_scale + rewriter.getI32IntegerAttr(srcFmt), // src_fmt + rewriter.getI32IntegerAttr(dstFmt) // dst_fmt ); // Op has no result value rewriter.eraseOp(op); @@ -389,6 +524,32 @@ class MKDotToTx81GemmOpConversion struct ElementwiseConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; + template + LogicalResult convertUnaryOp(linalg::GenericOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input = createAddressFromMemref(rewriter, loc, adapter.getInputs()[0]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adapter.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = dyn_cast(op.getInputs()[0].getType()); + auto outputType = dyn_cast(op.getOutputs()[0].getType()); + // Data format after conversion + Data_Format srcFmt = + insertConvertTypeOp(input, inputType, elemCount, rewriter, loc); + Data_Format dstFmt = + insertConvertTypeOp(output, outputType, elemCount, rewriter, loc); + // Create the unary operation + rewriter.create(loc, rewriter.getI64Type(), input, output, elemCount, + rewriter.getI16IntegerAttr(srcFmt)); + insertRestoreTypeOp(input, inputType, elemCount, rewriter, loc); + insertRestoreTypeOp(output, outputType, elemCount, rewriter, loc); + + rewriter.eraseOp(op); + return success(); + } + template LogicalResult convertBinaryOp(linalg::GenericOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -402,14 +563,23 @@ struct ElementwiseConversion : public OpConversionPattern { auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); auto inputType = dyn_cast(op.getInputs()[0].getType()); + // Data format after conversion + Data_Format srcFmt = + insertConvertTypeOp(input0, inputType, elemCount, rewriter, loc); + insertConvertTypeOp(input1, inputType, elemCount, rewriter, loc); + insertConvertTypeOp(output, inputType, elemCount, rewriter, loc); + // Create the elementwise operation // TODO: Fix attribute - rewriter.create( - loc, rewriter.getI64Type(), input0, input1, output, elemCount, - rewriter.getI16IntegerAttr(0), // Round mode - rewriter.getI16IntegerAttr( - getFormatCode(inputType)) // Format (5 = f32, assuming f32 for now) - ); + rewriter.create(loc, rewriter.getI64Type(), input0, input1, output, + elemCount, + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(srcFmt)); + + insertRestoreTypeOp(input0, inputType, elemCount, rewriter, loc); + insertRestoreTypeOp(input1, inputType, elemCount, rewriter, loc); + insertRestoreTypeOp(output, inputType, elemCount, rewriter, loc); + rewriter.eraseOp(op); return success(); } @@ -450,6 +620,58 @@ struct ElementwiseConversion : public OpConversionPattern { return success(); } + template + LogicalResult BoolRelationVVOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input0 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto input1 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = 
dyn_cast(op.getInputs()[0].getType()); + + // Create the elementwise operation + // TODO: Fix attribute + rewriter.create( + loc, rewriter.getI64Type(), input0, input1, output, elemCount, + rewriter.getI16IntegerAttr(getFormatCode(inputType)) // Format + ); + + rewriter.eraseOp(op); + return success(); + } + + LogicalResult FmaConvertOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto input0 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[0]); + auto input1 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[1]); + auto input2 = + createAddressFromMemref(rewriter, loc, adaptor.getInputs()[2]); + auto [output, sizes, strides] = + createMetadata(rewriter, op->getLoc(), adaptor.getOutputs()[0]); + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto inputType = dyn_cast(op.getInputs()[0].getType()); + + auto mulResult = rewriter.create( + loc, rewriter.getI64Type(), input0, input1, output, elemCount, + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(getFormatCode(inputType))); + auto addResult = rewriter.create( + loc, rewriter.getI64Type(), output, input2, output, elemCount, + rewriter.getI16IntegerAttr(0), // Round mode + rewriter.getI16IntegerAttr(getFormatCode(inputType))); + rewriter.eraseOp(op); + return success(); + } + LogicalResult matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -460,7 +682,16 @@ struct ElementwiseConversion : public OpConversionPattern { if (op.getIteratorTypesArray().front() != utils::IteratorType::parallel) return rewriter.notifyMatchFailure(op, "Only support elementwise op."); + if (regionOps.size() != 1) { + if (failed(linalg::linalgOpToLoops(rewriter, op))) + return rewriter.notifyMatchFailure(op, + "Element-wise op not yet supported"); + rewriter.eraseOp(op); + return success(); + } + auto elemWiseOp = regionOps[0]; + auto resultType = elemWiseOp->getResult(0).getType(); return llvm::TypeSwitch(elemWiseOp) .Case([&](auto elemWiseOp) { return convertBinaryOp(op, adaptor, rewriter); @@ -475,15 +706,53 @@ struct ElementwiseConversion : public OpConversionPattern { [&](auto elemWiseOp) { return convertBinaryOp(op, adaptor, rewriter); }) + .Case([&](auto elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertBinaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) + .Case([&](auto elemWiseOp) { + return convertUnaryOp(op, adaptor, rewriter); + }) .Case([&](auto elemWiseOp) { return NormalConvertOp(op, adaptor, rewriter); }) + .Case([&](auto elemWiseOp) { + return FmaConvertOp(op, adaptor, rewriter); + }) .Case([&](auto elemWiseOp) { // TODO: Need add more int to fp convert. 
-          auto inputType =
-              cast(op.getInputs()[0].getType()).getElementType();
-          auto outputType =
-              cast(op.getOutputs()[0].getType()).getElementType();
+          auto inputType = mlir::cast<MemRefType>(op.getInputs()[0].getType())
+                               .getElementType();
+          auto outputType = mlir::cast<MemRefType>(op.getOutputs()[0].getType())
+                                .getElementType();
          if (inputType.isInteger(16) && outputType.isF32()) {
            return RoundConvertOp(op, adaptor, rewriter);
          } else if (inputType.isInteger(16) && outputType.isF16()) {
@@ -500,10 +769,10 @@
        })
        .Case([&](auto elemWiseOp) {
          // TODO: Add more int-to-fp conversions.
-          auto inputType =
-              cast(op.getInputs()[0].getType()).getElementType();
-          auto outputType =
-              cast(op.getOutputs()[0].getType()).getElementType();
+          auto inputType = mlir::cast<MemRefType>(op.getInputs()[0].getType())
+                               .getElementType();
+          auto outputType = mlir::cast<MemRefType>(op.getOutputs()[0].getType())
+                                .getElementType();
          if (inputType.isF16() && outputType.isInteger(8)) {
            return RoundConvertOp(op, adaptor, rewriter);
          } else if (inputType.isF16() && outputType.isInteger(16)) {
@@ -522,12 +791,50 @@
                    "integer conversion");
          }
        })
+// FIXME: BoolLessThenVV currently fails on the board. Need more op information
+// from Tx81
+#if 0
+        .Case<arith::CmpIOp>([&](auto elemWiseOp) {
+          arith::CmpIPredicate predicate = elemWiseOp.getPredicate();
+          switch (predicate) {
+          case arith::CmpIPredicate::eq:
+            return BoolRelationVVOp<tx::BoolEqualVV>(op, adaptor, rewriter);
+          case arith::CmpIPredicate::ne:
+            return BoolRelationVVOp<tx::BoolUnEqualVV>(op, adaptor, rewriter);
+          case arith::CmpIPredicate::sge:
+            return BoolRelationVVOp<tx::BoolGreaterEqualVV>(op, adaptor,
+                                                            rewriter);
+          case arith::CmpIPredicate::sgt:
+            return BoolRelationVVOp<tx::BoolGreaterVV>(op, adaptor, rewriter);
+          case arith::CmpIPredicate::sle:
+            return BoolRelationVVOp<tx::BoolLessEqualVV>(op, adaptor, rewriter);
+          case arith::CmpIPredicate::slt:
+            return BoolRelationVVOp<tx::BoolLessThenVV>(op, adaptor, rewriter);
+          default:
+            llvm_unreachable("Not yet supported");
+            break;
+          }
+        })
+#endif
+        .Case<arith::TruncFOp>([&](auto elemWiseOp) {
+          if (resultType.isF16())
+            return RoundConvertOp(op, adaptor, rewriter);
+          else if (resultType.isBF16())
+            return RoundConvertOp(op, adaptor, rewriter);
+          else
+            return rewriter.notifyMatchFailure(
+                op, "Unsupported input/output type combination for trunc "
+                    "conversion");
+        })
        .Default([&](auto elemWiseOp) {
          // WORKAROUND: Used to handle tl.arange(0, BLOCK_SIZE) which will
          // lower to linalg.generic + linalg.index + arith.index_cast and
          // other currently unsupported cases (e.g. arith::extf)
          // TODO: Lower ops to tx81 if supported
-          if (failed(linalg::linalgOpToAffineLoops(rewriter, op)))
+
+          // Affine ops should have been handled before this pass. So here we
+          // lower to scf.for
+          if (failed(linalg::linalgOpToLoops(rewriter, op)))
            return rewriter.notifyMatchFailure(
                op, "Element-wise op not yet supported");
          rewriter.eraseOp(op);
diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp
index aa0c51160..371c2faeb 100644
--- a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81Pass.cpp
@@ -43,8 +43,55 @@ class MKToTx81Pass : public triton::impl::MKToTx81Base {
                    mk::MagicKernelDialect, tx::Tx81Dialect>();
  }

+  bool isOperandMemorySpaceSPM(Value operand) {
+    Operation *lastOp = operand.getDefiningOp();
+    Operation *op = lastOp;
+
+    do {
+      if (isa<memref::AllocOp>(op))
+        return true;
+      else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+        // We assume yieldResults (inner loop region) and loopResults (outer
+        // loop region) correspond one-to-one, so an outer-region value can be
+        // mapped back to its inner-region defining op.
+        // FIXME: Refactor this using the standard loop analyses.
+
+        auto yieldResults = forOp.getYieldedValues();
+        mlir::ResultRange loopResults = forOp.getLoopResults().value();
+        assert(yieldResults.size() == loopResults.size());
+
+        auto idx = std::distance(
+            loopResults.begin(),
+            std::find(loopResults.begin(), loopResults.end(), operand));
+        operand = yieldResults[idx];
+
+      } else {
+        operand = op->getOperand(0);
+      }
+      lastOp = op;
+      op = operand.getDefiningOp();
+    } while (op);
+    return false;
+  }
+
  void runOnOperation() override {
    auto moduleOp = getOperation();
+
+    // Tag memref::CopyOp with memory-space attributes for the tx dialect
+    // lowering
+    moduleOp->walk([&](Operation *op) {
+      if (isa<memref::CopyOp>(op)) {
+        auto copyOp = cast<memref::CopyOp>(op);
+        op->setAttr("srcSpm",
+                    IntegerAttr::get(IntegerType::get(op->getContext(), 32),
+                                     llvm::APInt(32, isOperandMemorySpaceSPM(
+                                                         copyOp.getSource()))));
+        op->setAttr("dstSpm",
+                    IntegerAttr::get(IntegerType::get(op->getContext(), 32),
+                                     llvm::APInt(32, isOperandMemorySpaceSPM(
+                                                         copyOp.getTarget()))));
+      }
+    });
+
    RewritePatternSet patterns(&getContext());
    ConversionTarget target(getContext());
@@ -66,6 +113,22 @@
    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
    }
+
+    // linalg::linalgOpToLoops will generate memref::LoadOp/memref::StoreOp
+    // around the arith computation.
+    // Used to decide whether to add the SPM mapping offset when lowering
+    // memref::LoadOp/memref::StoreOp
+    moduleOp->walk([&](Operation *op) {
+      if (isa<memref::LoadOp, memref::StoreOp>(op)) {
+        bool isSpm = isa<memref::LoadOp>(op)
+                         ? isOperandMemorySpaceSPM(op->getOperand(0))
+                         : isOperandMemorySpaceSPM(op->getOperand(1));
+
+        op->setAttr("isSpm",
+                    IntegerAttr::get(IntegerType::get(op->getContext(), 32),
+                                     llvm::APInt(32, isSpm)));
+      }
+    });
  }
};
diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
index cecd8e7b0..637732fc7 100644
--- a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
+++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
@@ -45,7 +45,8 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
    RewritePatternSet &patterns) {

  if (pidsToFuncArgs) {
-    patterns.add(
+    // Need to use the tx interface to get the pid.
+    patterns.add(
        patterns.getContext());
  }
  if (addptrToLinalg) {
diff --git a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
index 530c3242b..bae1bd6ba 100644
--- a/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
@@ -150,7 +150,8 @@ class TritonArithToLinalgPass
    });

    if (pidsToFuncArgs) {
-      target.addIllegalOp();
+      target
+          .addIllegalOp();
    }

    if (addptrToLinalg) {
diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp
index 7f45b79c4..a857fbb3e 100644
--- a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp
+++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVM.cpp
@@ -20,9 +20,11 @@ using namespace mlir;
 #define GEN_PASS_CLASSES
 #include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h.inc"
 
-namespace {
+// Used for allocating SPM memory
+uint64_t spmPointer = 0x10000;
 
-// Used to kcore load/store data from/to spm
+namespace {
+// Used by kcore to load/store data from/to SPM
 const int64_t spmMappingOffset = 0x30400000;
 
 //===----------------------------------------------------------------------===//
@@ -37,8 +39,6 @@ struct TsmMemRefAllocOpLowering : public AllocLikeOpLLVMLowering {
  std::tuple<Value, Value>
  allocateBufferFromSPM(ConversionPatternRewriter &rewriter, Location loc,
                        Operation *op) const {
-    static uint64_t spmPointer = 0x10000;
-
    // Create a GEPOp for the SPM address.
    MemRefType memRefType = getMemRefResultType(op);
    Value spmOffsetOp = rewriter.create(
@@ -99,28 +99,42 @@ struct MemrefLoadOrStoreOpLowering : public ConvertOpToLLVMPattern {
 
    Value ptrValue =
        rewriter.create<LLVM::PtrToIntOp>(op.getLoc(), intPtrType, dataPtr);
 
-    // FIXME: Can only need create once since offset is a const op?
-    auto spmMemoryOffset = rewriter.create<LLVM::ConstantOp>(
-        op.getLoc(), rewriter.getI64Type(),
-        rewriter.getI64IntegerAttr(spmMappingOffset));
-    auto spmMemoryAddr = rewriter.create<LLVM::AddOp>(
-        op.getLoc(), rewriter.getI64Type(),
-        SmallVector<Value>({ptrValue, spmMemoryOffset}));
-
-    auto ptrTy = LLVM::LLVMPointerType::get(
-        rewriter.getContext(),
-        *ConvertToLLVMPattern::getTypeConverter()->getMemRefAddressSpace(type));
-    auto spmMemoryAddrPtr =
-        rewriter.create<LLVM::IntToPtrOp>(op.getLoc(), ptrTy, spmMemoryAddr);
+    // Workaround: a proper memory-space analysis pass should be added. 
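+    // The contract with MKToTx81Pass is: every memref.load/store carries an
+    // "isSpm" attribute, and SPM accesses are rebased into the kcore window.
+    // In effect, this pattern computes (a sketch of the generated IR):
+    //
+    //   effectiveAddr = isSpm ? rawAddr + spmMappingOffset  // 0x30400000
+    //                         : rawAddr;                    // DDR, unchanged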
+    Operation *opBase = op;
+    if (!opBase->hasAttr("isSpm")) {
+      return rewriter.notifyMatchFailure(
+          op, "Load/Store should have isSpm attribute.");
+    }
+    int isSpm =
+        cast<IntegerAttr>(opBase->getAttr("isSpm")).getValue().getSExtValue();
+
+    Value adjustedPtr = dataPtr;
+    if (isSpm) {
+      auto spmMemoryOffset = rewriter.create<LLVM::ConstantOp>(
+          op.getLoc(), rewriter.getI64Type(),
+          rewriter.getI64IntegerAttr(spmMappingOffset));
+      auto spmMemoryAddr = rewriter.create<LLVM::AddOp>(
+          op.getLoc(), rewriter.getI64Type(),
+          SmallVector<Value>({ptrValue, spmMemoryOffset}));
+
+      auto ptrTy = LLVM::LLVMPointerType::get(
+          rewriter.getContext(),
+          *ConvertToLLVMPattern::getTypeConverter()->getMemRefAddressSpace(
+              type));
+      auto spmMemoryAddrPtr =
+          rewriter.create<LLVM::IntToPtrOp>(op.getLoc(), ptrTy, spmMemoryAddr);
+
+      adjustedPtr = spmMemoryAddrPtr;
+    }

    // Whether a memory-space cast is needed
    if constexpr (std::is_same()) {
-      rewriter.replaceOpWithNewOp(
-          op, op.getType(), spmMemoryAddrPtr, 0, false, op.getNontemporal());
+      rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getType(), adjustedPtr,
+                                                0, false, op.getNontemporal());
    } else {
      rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
-          op, adaptor.getValue(), dataPtr, 0, false, op.getNontemporal());
+          op, adaptor.getValue(), adjustedPtr, 0, false, op.getNontemporal());
    }

    return success();
@@ -199,7 +213,7 @@ struct MemRefReinterpretCastOpLowering
 
    // Create descriptor.
    Location loc = castOp.getLoc();
-    MemRefDescriptor desc(*descriptor);
+    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
 
    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
diff --git a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp
index d37c2824a..7eaf1da0d 100644
--- a/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp
+++ b/third_party/tsingmicro/lib/Conversion/Tx81MemrefToLLVM/Tx81MemrefToLLVMPass.cpp
@@ -78,6 +78,11 @@ class Tx81MemrefToLLVMPass
    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
    }
+
+    // Record SPM usage. 
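+    // The final bump-pointer value is published on the module so later stages
+    // (presumably the compiler driver) can see how much SPM the kernel uses.
+    // A consumer would read it back roughly like this (sketch only):
+    //
+    //   if (auto spmUse =
+    //           moduleOp->getAttrOfType<IntegerAttr>("triton_tsm.spm_use"))
+    //     uint64_t usedBytes = spmUse.getInt();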
+ moduleOp->setAttr("triton_tsm.spm_use", + mlir::IntegerAttr::get( + mlir::IntegerType::get(context, 32), spmPointer)); } }; diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt index 68517188e..441858a94 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/CMakeLists.txt @@ -1,8 +1,10 @@ add_triton_library(Tx81ToLLVM Tx81ToLLVM.cpp + KernelArgBufferPass.cpp DEPENDS Tx81ToLLVMConversionPassIncGen + KernelArgBufferPassIncGen MLIRMemRefToLLVM LINK_LIBS PUBLIC diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp index a024e67c8..9458fdc95 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/KernelArgBufferPass.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinOps.h" @@ -20,10 +21,19 @@ using namespace mlir; +namespace mlir { +namespace triton { +#define GEN_PASS_CLASSES +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h.inc" +} // namespace triton +} // namespace mlir + namespace { class KernelArgBufferPass - : public PassWrapper> { + : public mlir::triton::KernelArgBufferPassBase { + using KernelArgBufferPassBase::KernelArgBufferPassBase; + public: StringRef getArgument() const final { return "kernel-arg-buffer"; } StringRef getDescription() const final { @@ -37,135 +47,32 @@ class KernelArgBufferPass void runOnOperation() override; private: - // Identifies if a function should be processed - bool isKernelFunction(func::FuncOp func); - - // Creates a new function with a single void* argument - func::FuncOp createBufferizedFunction(OpBuilder &builder, - func::FuncOp originalFunc); - - // Rewrites the function body to use the argument buffer - void rewriteFunctionBody(func::FuncOp originalFunc, func::FuncOp newFunc); + // Insert load op to get real kernel args from new buffered argument + // Side effect: calculate offset and create ops + Value insertKernelArgLoad(OpBuilder &builder, Location loc, Value argsBuffer, + Type argType, int64_t ¤tOffset); }; -bool KernelArgBufferPass::isKernelFunction(func::FuncOp func) { - // For this example, we'll identify kernel functions by their name - // containing "_kernel". In a real implementation, you might use attributes - // or more sophisticated detection. 
- return func.getName().contains("_kernel"); -} - -func::FuncOp -KernelArgBufferPass::createBufferizedFunction(OpBuilder &builder, - func::FuncOp originalFunc) { - // Create a new function type with a single void* argument - auto voidPtrType = LLVM::LLVMPointerType::get(builder.getContext()); - auto newFuncType = - FunctionType::get(originalFunc.getContext(), {voidPtrType}, - originalFunc.getFunctionType().getResults()); - - // Create the new function with the same name but new type - auto newFunc = func::FuncOp::create(originalFunc.getLoc(), - originalFunc.getName(), newFuncType); - - // Copy over all attributes except those related to the function type - for (const auto &attr : originalFunc->getAttrs()) { - if (attr.getName() != "function_type" && attr.getName() != "arg_attrs" && - attr.getName() != "res_attrs") { - newFunc->setAttr(attr.getName(), attr.getValue()); - } - } - - return newFunc; -} - -void KernelArgBufferPass::rewriteFunctionBody(func::FuncOp originalFunc, - func::FuncOp newFunc) { - if (originalFunc.empty()) - return; - - Block &oldEntryBlock = originalFunc.getBlocks().front(); - Block &newEntryBlock = newFunc.getBlocks().front(); - - OpBuilder builder(&newEntryBlock, newEntryBlock.begin()); - Location loc = originalFunc.getLoc(); - - Value argsBuffer = newEntryBlock.getArgument(0); - SmallVector extractedArgs; - - // Offset tracking for buffer access - int64_t currentOffset = 0; - // Size of scalar values in bytes (specified as 8 bytes) - const int64_t scalarSize = 8; - - // Process each original argument - for (auto argIndex : llvm::seq(0, originalFunc.getNumArguments())) { - Type argType = originalFunc.getArgument(argIndex).getType(); - Value loadedArg; - - // Handle pointer types (like uint64_t*) - if (auto ptrType = dyn_cast(argType)) { - // For pointer types, we load the pointer value itself from the buffer - auto offsetValue = builder.create( - loc, builder.getI64Type(), builder.getI64IntegerAttr(currentOffset)); - - // Get pointer to the current position in args buffer - auto elementPtr = builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), argsBuffer, - ArrayRef{offsetValue}); - - // Cast to pointer-to-pointer type - auto castedPtr = builder.create( - loc, LLVM::LLVMPointerType::get(ptrType), elementPtr); - - // Load the pointer - loadedArg = builder.create(loc, castedPtr); - - // Increment offset (pointers are 8 bytes) - currentOffset += scalarSize; - } - // Handle scalar types (like int64_t, int) - else { - auto offsetValue = builder.create( - loc, builder.getI64Type(), builder.getI64IntegerAttr(currentOffset)); - - // Get pointer to the current position in args buffer - auto elementPtr = builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), argsBuffer, - ArrayRef{offsetValue}); - - // Cast to appropriate pointer type - auto castedPtr = builder.create( - loc, LLVM::LLVMPointerType::get(argType), elementPtr); - - // Load the scalar value - loadedArg = builder.create(loc, castedPtr); - - // Increment offset (all scalars use 8 bytes as specified) - currentOffset += scalarSize; - } - - extractedArgs.push_back(loadedArg); - } - - // Clone the original function body, replacing uses of old arguments - auto &oldRegion = originalFunc.getBody(); - auto &newRegion = newFunc.getBody(); - - // Move operations from old entry block to new entry block - for (auto &op : oldEntryBlock.getOperations()) { - if (&op == &oldEntryBlock.back() && op.hasTrait()) { - builder.clone(op); - } else { - auto clonedOp = builder.clone(op); - - // Replace 
uses of old arguments with new extracted values
-      for (unsigned i = 0; i < originalFunc.getNumArguments(); ++i) {
-        Value oldArg = oldEntryBlock.getArgument(i);
-        clonedOp->replaceUsesOfWith(oldArg, extractedArgs[i]);
-      }
-    }
-  }
+Value KernelArgBufferPass::insertKernelArgLoad(OpBuilder &builder, Location loc,
+                                               Value argsBuffer, Type argType,
+                                               int64_t &currentOffset) {
+  // Get a pointer to the current position in the args buffer
+  auto offsetValue = builder.create<LLVM::ConstantOp>(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(currentOffset));
+
+  // NOTE: GEPOp would need to distinguish scalar and pointer types, so compute
+  // ptr + offset with integer arithmetic here instead
+  Value elementPtr =
+      builder.create<LLVM::PtrToIntOp>(loc, builder.getI64Type(), argsBuffer);
+  elementPtr = builder.create<LLVM::AddOp>(loc, builder.getI64Type(),
+                                           elementPtr, offsetValue);
+  elementPtr = builder.create<LLVM::IntToPtrOp>(
+      loc, LLVM::LLVMPointerType::get(builder.getContext()), elementPtr);
+
+  // Increment the offset. All args are assumed to occupy 8 bytes
+  currentOffset += sizeof(int64_t);
+
+  // Load the real kernel arg value
+  return builder.create<LLVM::LoadOp>(loc, argType, elementPtr);
}

void KernelArgBufferPass::runOnOperation() {
@@ -173,38 +80,64 @@ void KernelArgBufferPass::runOnOperation() {
  OpBuilder builder(module.getContext());

  // Collect functions to process
-  SmallVector kernelFuncs;
-  for (auto func : module.getOps()) {
-    if (isKernelFunction(func)) {
-      kernelFuncs.push_back(func);
-    }
+  SmallVector<LLVM::LLVMFuncOp> kernelFuncs;
+  for (auto func : module.getOps<LLVM::LLVMFuncOp>()) {
+    kernelFuncs.push_back(func);
  }
+  // NOTE: This pass runs before the tx81-to-llvm pass, so we assume there is
+  // exactly one func op and that it is the Triton kernel
+  assert(kernelFuncs.size() == 1 && "Only one kernel function expected");

  // Process each kernel function
+  // TODO: Delete the for loop if the assert is always true for all examples
  for (auto func : kernelFuncs) {
    // Create new function with bufferized signature
    builder.setInsertionPointAfter(func);
-    auto newFunc = createBufferizedFunction(builder, func);
-
-    // Add entry block to the new function
-    newFunc.addEntryBlock();
-
-    // Rewrite function body to use the argument buffer
-    rewriteFunctionBody(func, newFunc);
-
-    // Replace the old function with the new one
-    func.erase();
+    // Save the old block arguments
+    SmallVector<BlockArgument, 8> blockArguments =
+        llvm::to_vector<8>(func.getArguments());
+    auto numArguments = blockArguments.size();
+
+    // New bufferized arg type
+    auto voidPtrType = LLVM::LLVMPointerType::get(builder.getContext());
+
+    // New bufferized function type
+    auto newFuncType = LLVM::LLVMFunctionType::get(
+        func.getFunctionType().getReturnType(), voidPtrType);
+    func.setFunctionType(newFuncType);
+    SmallVector<DictionaryAttr> newArgAttrs({DictionaryAttr()});
+    func.setAllArgAttrs(newArgAttrs);
+
+    // Add the new bufferized argument
+    Location loc = func.getLoc();
+    Block &entryBlock = func.getBlocks().front();
+    entryBlock.insertArgument((unsigned)0, voidPtrType, func.getLoc());
+
+    OpBuilder builder(&entryBlock, entryBlock.begin());
+    // Get the bufferized argument
+    Value argsBuffer = entryBlock.getArgument(0);
+
+    // Offset tracking for buffer access
+    int64_t currentOffset = 0;
+
+    // Process each original argument
+    for (auto argIndex : llvm::seq(0, numArguments)) {
+      auto oldArg = blockArguments[argIndex];
+      Type argType = oldArg.getType();
+      Value loadedArg = insertKernelArgLoad(builder, func.getLoc(), argsBuffer,
+                                            argType, currentOffset);
+
+      if (blockArguments[argIndex].use_empty())
+        continue;
+      oldArg.replaceAllUsesWith(loadedArg);
+    }
+    // Remove the old arguments once their uses have been replaced
+ 
entryBlock.eraseArguments(1, numArguments); } } } // namespace -std::unique_ptr createKernelArgBufferPass() { +std::unique_ptr triton::createKernelArgBufferPass() { return std::make_unique(); } - -// Pass registration -namespace { -#define GEN_PASS_REGISTRATION -#include "KernelArgBufferPass.h.inc" -} // namespace diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp index 667b0a8fa..68c7e75ca 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp @@ -36,6 +36,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" #include "llvm/ADT/TypeSwitch.h" @@ -55,6 +56,15 @@ const char addVVFuncName[] = "__AddVV"; const char subVVFuncName[] = "__SubVV"; const char mulVVFuncName[] = "__MulVV"; const char divVVFuncName[] = "__DivVV"; +const char absVVFuncName[] = "__AbsVV"; +const char rsqrtVVFuncName[] = "__RsqrtVV"; +const char sqrtVVFuncName[] = "__SqrtVV"; +const char lnFuncName[] = "__Ln"; +const char log2FuncName[] = "__Log2"; +const char expFuncName[] = "__Exp"; +const char pow2FuncName[] = "__Pow2"; +const char sinFuncName[] = "__Sin"; +const char cosFuncName[] = "__Cos"; const char addVSFuncName[] = "__AddVS"; const char subVSFuncName[] = "__SubVS"; const char mulVSFuncName[] = "__MulVS"; @@ -73,6 +83,20 @@ const char fp16ToInt32FuncName[] = "__FP16_INT32"; const char fp32ToInt8FuncName[] = "__FP32_INT8"; const char fp32ToInt16FuncName[] = "__FP32_INT16"; const char fp32ToInt32FuncName[] = "__FP32_INT32"; +const char boolEqualVVFuncName[] = "__BoolEqualVV"; +const char boolUnEqualVVFuncName[] = "__BoolUnEqualVV"; +const char boolGreaterEqualVVFuncName[] = "__BoolGreaterEqualVV"; +const char boolGreaterVVFuncName[] = "__BoolGreaterVV"; +const char boolLessEqualVVFuncName[] = "__BoolLessEqualVV"; +const char boolLessVVFuncName[] = "__BoolLessThenVV"; +const char fp32ToFp16FuncName[] = "__FP32_FP16"; +const char fp32ToBf16FuncName[] = "__FP32_BF16"; +const char fp32ToTF32FuncName[] = "__FP32_TF32"; +const char andVVFuncName[] = "__AndVV"; +const char orVVFuncName[] = "__OrVV"; +const char xorVVFuncName[] = "__XorVV"; +const char MaxVVFuncName[] = "__MaxVV"; +const char MinVVFuncName[] = "__MinVV"; // Function to declare Tx81 runtime function Value declareTx81Function(ModuleOp module, OpBuilder &builder, Location loc, @@ -141,15 +165,15 @@ struct ConstantOpConversion : public OpConversionPattern { auto resultType = getTypeConverter()->convertType(op.getResult().getType()); // Handle different attribute types - if (auto intAttr = dyn_cast(constAttr)) { + if (auto intAttr = mlir::dyn_cast(constAttr)) { // Convert integer attribute rewriter.replaceOpWithNewOp(op, resultType, intAttr); return success(); - } else if (auto floatAttr = dyn_cast(constAttr)) { + } else if (auto floatAttr = mlir::dyn_cast(constAttr)) { // Convert float attribute rewriter.replaceOpWithNewOp(op, resultType, floatAttr); return success(); - } else if (auto boolAttr = dyn_cast(constAttr)) { + } else if (auto boolAttr = mlir::dyn_cast(constAttr)) { // Convert bool attribute to i1 rewriter.replaceOpWithNewOp( op, resultType, @@ -173,23 +197,25 @@ struct IndexCastOpConversion : public OpConversionPattern { auto dstType = getTypeConverter()->convertType(op.getResult().getType()); // Convert from 
index to specific integer type - if (isa(srcType) && isa(dstType)) { + if (mlir::isa(srcType) && + mlir::isa(dstType)) { rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn()); return success(); } // Convert from specific integer type to index - if (isa(srcType) && isa(dstType)) { + if (mlir::isa(srcType) && + mlir::isa(dstType)) { rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn()); return success(); } // Handle integer to integer casts - if (isa(srcType) && isa(dstType)) { - unsigned srcWidth = cast(srcType).getWidth(); - unsigned dstWidth = cast(dstType).getWidth(); + if (mlir::isa(srcType) && mlir::isa(dstType)) { + unsigned srcWidth = mlir::cast(srcType).getWidth(); + unsigned dstWidth = mlir::cast(dstType).getWidth(); if (srcWidth < dstWidth) { // Sign extend if source is signed, zero extend otherwise @@ -472,10 +498,11 @@ struct ReduceOpConversion : public OpConversionPattern { // uint16_t src_h, uint16_t src_w, uint16_t src_c, uint16_t fmt) auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); // Types for function declaration - SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i32Ty, - i32Ty, i32Ty, i32Ty, i32Ty}; + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty, + i16Ty, i16Ty, i16Ty, i16Ty}; Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), funcPrefix, i8PtrTy, argTypes); @@ -495,17 +522,17 @@ struct ReduceOpConversion : public OpConversionPattern { // Convert shape attribute to Value Value shape_n = - rewriter.create(op.getLoc(), i32Ty, op.getShape()[0]); + rewriter.create(op.getLoc(), i16Ty, op.getShape()[0]); Value shape_h = - rewriter.create(op.getLoc(), i32Ty, op.getShape()[1]); + rewriter.create(op.getLoc(), i16Ty, op.getShape()[1]); Value shape_w = - rewriter.create(op.getLoc(), i32Ty, op.getShape()[2]); + rewriter.create(op.getLoc(), i16Ty, op.getShape()[2]); Value shape_c = - rewriter.create(op.getLoc(), i32Ty, op.getShape()[3]); + rewriter.create(op.getLoc(), i16Ty, op.getShape()[3]); // Handle format attribute Value fmt = rewriter.create( - op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); // Create the call auto call = rewriter.create( @@ -520,7 +547,7 @@ struct ReduceOpConversion : public OpConversionPattern { } }; -// Convert tx81.binary op to LLVM call +// Convert tx81.elementwise op to LLVM call template struct ElementWiseOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -582,6 +609,58 @@ struct ElementWiseOpConversion : public OpConversionPattern { } }; +template +struct UnaryOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void* __Abs(void* src, void* dst, uint32_t elem_count, + // uint16_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert 
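Taken together with the `argTypes` vectors built in the surrounding patterns, the newly registered runtime symbols presumably have C prototypes along the following lines. This is a sketch inferred from the conversions, not copied from the crt headers; in particular, `__ReduceSum` is a representative name and the meaning of the reduce entry points' remaining i32 parameter is not spelled out in this hunk:

```cpp
#include <cstdint>

// Prototypes inferred from the argTypes in the conversion patterns; the
// authoritative declarations live in the Tx81 crt headers.
extern "C" {
// Binary logic VV ops (__AndVV/__OrVV/__XorVV): fmt travels as uint32_t.
void *__XorVV(void *src0, void *src1, void *dst, uint32_t elem_count,
              uint32_t fmt);
// Unary ops (__AbsVV/__Exp/__Ln/...): fmt travels as uint16_t.
void *__Exp(void *src, void *dst, uint32_t elem_count, uint16_t fmt);
// Bool relation VV ops: three buffers, elem_count, uint16_t fmt.
void *__BoolLessEqualVV(void *src0, void *src1, void *dst,
                        uint32_t elem_count, uint16_t fmt);
// Reduce ops after this change: the NHWC extents and fmt are narrowed to
// uint16_t; the third parameter is kept as declared (meaning assumed).
void *__ReduceSum(void *src, void *dst, uint32_t dim, uint16_t src_n,
                  uint16_t src_h, uint16_t src_w, uint16_t src_c,
                  uint16_t fmt);
}
```
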
operands + Value input = adaptor.getInput(); + // Need to bitcast src to i8* + input = rewriter.create(op.getLoc(), i8PtrTy, input); + Value out = adaptor.getOut(); + // Need to bitcast out to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{input, out, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + // FIXME: Use trait to refactor the BinaryVSOpConversion and // ElementWiseOpConversion template @@ -643,6 +722,123 @@ struct BinaryVSOpConversion : public OpConversionPattern { } }; +template +struct BinaryLogicVVOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Tx81Op::Adaptor; + + LogicalResult + matchAndRewrite(Tx81Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void* __XorVV(void* a, void* b, void* out, uint32_t + // elem_count, uint32_t fmt); + auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = { + i8PtrTy, // src0_addr + i8PtrTy, // src1_addr + i8PtrTy, // dst_addr + i32Ty, // elem_count + i32Ty // fmt + }; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + Value srcB = adaptor.getInput1(); + // Need to bitcast src to i8* + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand, convert Index to I32 + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + +template +struct BoolRelationVVOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename BoolRelationVVOp::Adaptor; + // using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(BoolRelationVVOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->template getParentOfType(); + + // Declare the runtime function if not already declared + // Signature: void __BoolLessEqualVV(uint64_t *src0, uint64_t *src1, + // uint64_t *dst, uint32_t elem_count, uint16_t fmt); + auto i8PtrTy = 
LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i32Ty = rewriter.getI32Type(); + auto i16Ty = rewriter.getI16Type(); + + // Types for function declaration + SmallVector argTypes = {i8PtrTy, i8PtrTy, i8PtrTy, i32Ty, i16Ty}; + + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + funcPrefix, i8PtrTy, argTypes); + + // Convert operands + Value srcA = adaptor.getInput0(); + // Need to bitcast src to i8* + srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); + Value srcB = adaptor.getInput1(); + // Need to bitcast src to i8* + srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); + Value out = adaptor.getOut(); + // Need to bitcast src to i8* + out = rewriter.create(op.getLoc(), i8PtrTy, out); + + // Get elem_count operand + Value elemCount = op.getElemCount(); + elemCount = castIndexToInt32(rewriter, op.getLoc(), elemCount); + + // Handle format attribute + Value fmt = rewriter.create( + op.getLoc(), i16Ty, rewriter.getI16IntegerAttr(op.getFmt())); + + // Create the call + auto call = rewriter.create( + op.getLoc(), i8PtrTy, funcPrefix, // funcPtr, + ArrayRef{srcA, srcB, out, elemCount, fmt}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + // Convert tx81.NormalConvertOp op to LLVM template struct NormalConvertOpConversion : public OpConversionPattern { @@ -656,8 +852,8 @@ struct NormalConvertOpConversion : public OpConversionPattern { auto module = op->template getParentOfType(); // Declare the runtime function if not already declared - // Signature: void (*FP16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, - // uint64_t dst_addr, uint32_t elem_count); + // Signature: void __FP16_FP32(uint64_t *src, uint64_t *dst, uint32_t + // elem_count); auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); auto i32Ty = rewriter.getI32Type(); @@ -702,8 +898,8 @@ struct RoundConvertOpConversion : public OpConversionPattern { auto module = op->template getParentOfType(); // Declare the runtime function if not already declared - // Signature: void (*INT16_FP32)(TsmConvertInstr *instr, uint64_t src0_addr, - // uint64_t dst_addr, uint32_t elem_count, RND_MODE rnd_mode); + // Signature: void __INT16_FP32(uint64_t *src, uint64_t *dst, uint32_t + // elem_count, RND_MODE round); auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); auto i32Ty = rewriter.getI32Type(); auto i16Ty = rewriter.getI16Type(); @@ -749,12 +945,11 @@ struct GemmOpConversion : public OpConversionPattern { auto module = op->getParentOfType(); // Declare the __Gemm runtime function if not already declared - // Signature: void* __Gemm(void* a, void* b, void* bias, int32_t* dims, - // void* psum, bool trans_a, bool trans_b, - // uint32_t batch_a, uint32_t batch_b, bool en_relu, - // bool en_bias, bool en_neg_scale, void* neg_scale, - // bool en_pos_scale, void* pos_scale, - // uint32_t src_fmt, uint32_t dst_fmt); + // Signature: void __Gemm(int64_t* srcA, int64_t *srcB, int64_t * srcBias, + // int64_t *dst, int32_t *dims, bool enPsum, int64_t *psum, bool enTransA, + // bool enTransB, int64_t batchSizeA, int64_t batchSizeB, bool enLeakyRelu, + // bool enBias,bool enNegScale, int64_t *negScale, bool enPosScale, int64_t + // *posScale, int64_t srcFmt, int64_t dstFmt) auto i8PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); auto i32Ty = rewriter.getI32Type(); auto i64Ty = rewriter.getI64Type(); @@ -764,9 +959,26 @@ struct GemmOpConversion : public OpConversionPattern { // Types for function declaration 
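All of these element-wise conversions share one shape: bitcast the buffers to i8*, widen the index-typed elem_count to i32, materialize fmt as a constant, and emit a single runtime call. On the runtime side, an entry point's scalar behavior is conceptually just a loop over elem_count. Here is an illustrative fp32-only sketch; the real crt implementations in crt/lib/Tx81/ dispatch on fmt and drive hardware instructions, so this is not their actual source:

```cpp
#include <cmath>
#include <cstdint>

// Illustrative scalar model of a unary entry point such as __Exp: consume
// elem_count elements from src, write results to dst, return dst. The name
// is deliberately not the real symbol.
extern "C" void *ExpScalarSketch(void *src, void *dst, uint32_t elem_count,
                                 uint16_t /*fmt*/) {
  const float *in = static_cast<const float *>(src);
  float *out = static_cast<float *>(dst);
  for (uint32_t i = 0; i < elem_count; ++i)
    out[i] = std::exp(in[i]);
  return dst;
}
```
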
SmallVector argTypes = { - i8PtrTy, i8PtrTy, i8PtrTy, i32PtrTy, i8PtrTy, i1Ty, - i1Ty, i32Ty, i32Ty, i1Ty, i1Ty, i1Ty, - i8PtrTy, i1Ty, i8PtrTy, i32Ty, i32Ty}; + i8PtrTy, // srcA + i8PtrTy, // srcB + i8PtrTy, // srcBias + i8PtrTy, // dst + i32PtrTy, // dims + i1Ty, // enPsum + i8PtrTy, // psum + i1Ty, // enTransA + i1Ty, // enTransB + i32Ty, // batchSizeA + i32Ty, // batchSizeB + i32Ty, // reluMode + i1Ty, // enBias + i1Ty, // enNegScale + i8PtrTy, // negScale + i1Ty, // enPosScale + i8PtrTy, // posScale + i32Ty, // srcFmt + i32Ty // dstFmt + }; // Declare the function Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), "__Gemm", @@ -776,6 +988,8 @@ struct GemmOpConversion : public OpConversionPattern { Value srcA = adaptor.getSrcA(); Value srcB = adaptor.getSrcB(); Value srcBias = adaptor.getSrcBias(); + Value dst = adaptor.getDst(); + Value psumAddr = adaptor.getPsumAddr(); Value srcNegScale = adaptor.getSrcNegScale(); Value srcPosScale = adaptor.getSrcPosScale(); @@ -784,6 +998,7 @@ struct GemmOpConversion : public OpConversionPattern { srcA = rewriter.create(op.getLoc(), i8PtrTy, srcA); srcB = rewriter.create(op.getLoc(), i8PtrTy, srcB); srcBias = rewriter.create(op.getLoc(), i8PtrTy, srcBias); + dst = rewriter.create(op.getLoc(), i8PtrTy, dst); psumAddr = rewriter.create(op.getLoc(), i8PtrTy, psumAddr); srcNegScale = @@ -795,25 +1010,16 @@ struct GemmOpConversion : public OpConversionPattern { auto dimsAttr = op.getDims(); SmallVector dimsValues; for (auto dimAttr : dimsAttr) - dimsValues.push_back(cast(dimAttr).getInt()); + dimsValues.push_back(mlir::cast(dimAttr).getInt()); // Allocate memory for the dims array Value dimsArraySize = rewriter.create( - op.getLoc(), i64Ty, - rewriter.getI64IntegerAttr(dimsValues.size() * sizeof(int64_t))); - - // Use malloc to allocate memory for dims array - Value mallocFunc = declareTx81Function(module, rewriter, op.getLoc(), - "malloc", i8PtrTy, {i64Ty}); - Value dimsArray = - rewriter - .create(op.getLoc(), i8PtrTy, "malloc", // mallocFunc, - ArrayRef{dimsArraySize}) - .getResult(); - - // Cast to i64* - Value dimsArrayI64Ptr = - rewriter.create(op.getLoc(), i64PtrTy, dimsArray); + op.getLoc(), i64Ty, rewriter.getI64IntegerAttr(dimsValues.size())); + + // Use alloc to allocate memory for dims array + auto dimsArrayI32Ptr = rewriter.create( + op.getLoc(), i32PtrTy, rewriter.getI32Type(), dimsArraySize, + /*alignment=*/0); // Store each dimension in the array for (size_t i = 0; i < dimsValues.size(); i++) { @@ -823,11 +1029,11 @@ struct GemmOpConversion : public OpConversionPattern { // Create GEP to get pointer to array element Value elemPtr = rewriter.create( - op.getLoc(), i64PtrTy, i64Ty, dimsArrayI64Ptr, ArrayRef{idx}); + op.getLoc(), i64PtrTy, i32Ty, dimsArrayI32Ptr, ArrayRef{idx}); // Create the dimension value Value dimValue = rewriter.create( - op.getLoc(), i64Ty, rewriter.getI64IntegerAttr(dimsValues[i])); + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(dimsValues[i])); // Store the value rewriter.create(op.getLoc(), dimValue, elemPtr); @@ -838,8 +1044,10 @@ struct GemmOpConversion : public OpConversionPattern { op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getTransSrcA())); Value transB = rewriter.create( op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getTransSrcB())); - Value enLeakyRelu = rewriter.create( - op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnLeakyRelu())); + Value enPSum = rewriter.create( + op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnPsum())); + Value reluMode = rewriter.create( + op.getLoc(), i32Ty, 
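An aside on the dims hunk just above: the dims array for `__Gemm` is now built with an `llvm.alloca` of i32 elements instead of a `malloc`, so it lives on the stack, needs no matching free, and matches the `int32_t *dims` parameter. In C terms, the lowered sequence amounts to something like the following sketch; the stub, the dimension names, and the values are illustrative, not taken from the patch:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in with a trimmed parameter list; the real __Gemm
// takes the full flag/scale/format set quoted in the conversion above.
static void *GemmStub(void *srcA, void *srcB, void *srcBias, void *dst,
                      const int32_t *dims) {
  std::printf("gemm dims: %d x %d x %d, batch %d\n", (int)dims[0],
              (int)dims[1], (int)dims[2], (int)dims[3]);
  return dst;
}

void gemmLoweringSketch(void *a, void *b, void *bias, void *dst) {
  // What the alloca plus per-element store loop emitted by
  // GemmOpConversion boils down to: a stack-resident i32 array, no free.
  int32_t dims[4] = {128, 256, 64, 1}; // illustrative extents only
  GemmStub(a, b, bias, dst, dims);
}
```

Note that once LLVM pointers are opaque, the i64PtrTy still named by the GEP in this hunk denotes the same type as i32PtrTy; the distinction survives only in the variable names.
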
rewriter.getI32IntegerAttr(op.getReluMode())); Value enBias = rewriter.create( op.getLoc(), i1Ty, rewriter.getBoolAttr(op.getEnBias())); Value enNegScale = rewriter.create( @@ -860,9 +1068,10 @@ struct GemmOpConversion : public OpConversionPattern { // Create the call to __Gemm auto call = rewriter.create( op.getLoc(), i8PtrTy, "__Gemm", // funcPtr, - ArrayRef{srcA, srcB, srcBias, dimsArrayI64Ptr, psumAddr, transA, - transB, batchA, batchB, enLeakyRelu, enBias, enNegScale, - srcNegScale, enPosScale, srcPosScale, srcFmt, dstFmt}); + ArrayRef{srcA, srcB, srcBias, dst, dimsArrayI32Ptr, enPSum, + psumAddr, transA, transB, batchA, batchB, reluMode, + enBias, enNegScale, srcNegScale, enPosScale, + srcPosScale, srcFmt, dstFmt}); // Replace the op with the result of the call rewriter.replaceOp(op, call.getResult()); @@ -891,8 +1100,18 @@ struct MemsetOpConversion : public OpConversionPattern { auto i16Ty = rewriter.getI16Type(); // Types for function declaration - SmallVector argTypes = {i8PtrTy, i32Ty, i32Ty, - i32PtrTy, i32PtrTy, i16Ty}; + SmallVector argTypes = { + i8PtrTy, // Spm addr + i32Ty, // value + i32Ty, // shape_n/iterator_2 + i32Ty, // shape_h/iterator_1 + i32Ty, // shape_w/iterator_0 + i32Ty, // shape_c/elem_count + i32Ty, // stride_n + i32Ty, // stride_h + i32Ty, // stride_w, + i16Ty // fmt + }; // Declare the function Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), @@ -900,15 +1119,21 @@ struct MemsetOpConversion : public OpConversionPattern { // Get operands Value src = adaptor.getSrc(); + src = rewriter.create(op.getLoc(), i8PtrTy, src); + Value value = adaptor.getValue(); - Value elemCount = adaptor.getElemCount(); - elemCount = castIndexToInt32(rewriter, op->getLoc(), elemCount); // Handle strides and iterations arrays - // For simplicity, we'll create null pointers - Value nullPtr = rewriter.create(op.getLoc(), i32PtrTy); + ValueRange shape = adaptor.getShape(); + Value iteration2 = castIndexToInt32(rewriter, op->getLoc(), shape[0]); + Value iteration1 = castIndexToInt32(rewriter, op->getLoc(), shape[1]); + Value iteration0 = castIndexToInt32(rewriter, op->getLoc(), shape[2]); + Value elemCount = castIndexToInt32(rewriter, op->getLoc(), shape[3]); - src = rewriter.create(op.getLoc(), i8PtrTy, src); + ValueRange strides = adaptor.getStrides(); + Value stride2 = castIndexToInt32(rewriter, op->getLoc(), strides[0]); + Value stride1 = castIndexToInt32(rewriter, op->getLoc(), strides[1]); + Value stride0 = castIndexToInt32(rewriter, op->getLoc(), strides[2]); // Convert fmt attribute to Value Value fmt = rewriter.create( @@ -917,7 +1142,8 @@ struct MemsetOpConversion : public OpConversionPattern { // Create the call to __Memset auto call = rewriter.create( op.getLoc(), i8PtrTy, "__Memset", // funcPtr, - ArrayRef{src, value, elemCount, nullPtr, nullPtr, fmt}); + ArrayRef{src, value, elemCount, stride0, iteration0, stride1, + iteration1, stride2, iteration2, fmt}); // Replace the op with the result of the call rewriter.replaceOp(op, call.getResult()); @@ -944,7 +1170,7 @@ struct LinalgFillOpConversion : public OpConversionPattern { // Check if the output is a tensor type Value outputTensor = op.getOutputs()[0]; - auto tensorType = dyn_cast(outputTensor.getType()); + auto tensorType = mlir::dyn_cast(outputTensor.getType()); if (!tensorType) { return rewriter.notifyMatchFailure(op, "expects ranked tensor type"); } @@ -963,8 +1189,8 @@ struct LinalgFillOpConversion : public OpConversionPattern { Type elemType = tensorType.getElementType(); // Convert the tensor type 
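The expanded `__Memset` contract above replaces the two placeholder null pointers with three stride/iteration pairs around an innermost element count. Read as a loop nest, the assumed semantics are as follows; this is a sketch of the contract, not the crt implementation, and it assumes byte strides and a 4-byte element inferred from the i32 value parameter:

```cpp
#include <cstdint>
#include <cstring>

// Assumed interpretation of the ten-parameter __Memset contract: three
// outer iteration levels with byte strides around a contiguous run of
// elem_count 32-bit stores of `value`.
void memsetContractSketch(uint8_t *base, uint32_t value, uint32_t elem_count,
                          uint32_t stride0, uint32_t iter0, uint32_t stride1,
                          uint32_t iter1, uint32_t stride2, uint32_t iter2) {
  for (uint32_t i2 = 0; i2 < iter2; ++i2)
    for (uint32_t i1 = 0; i1 < iter1; ++i1)
      for (uint32_t i0 = 0; i0 < iter0; ++i0) {
        uint8_t *p = base + i2 * stride2 + i1 * stride1 + i0 * stride0;
        for (uint32_t e = 0; e < elem_count; ++e)
          std::memcpy(p + e * sizeof(uint32_t), &value, sizeof(uint32_t));
      }
}
```
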
to the LLVM pointer type - auto llvmPtrType = - dyn_cast(typeConverter->convertType(tensorType)); + auto llvmPtrType = mlir::dyn_cast( + typeConverter->convertType(tensorType)); if (!llvmPtrType) { return rewriter.notifyMatchFailure( op, "failed to convert tensor type to LLVM pointer type"); @@ -986,10 +1212,10 @@ struct LinalgFillOpConversion : public OpConversionPattern { // Calculate element size in bytes int64_t elemSizeInBytes = 0; - if (auto intType = dyn_cast(elemType)) { + if (auto intType = mlir::dyn_cast(elemType)) { elemSizeInBytes = (intType.getWidth() + 7) / 8; // Round up to nearest byte - } else if (auto floatType = dyn_cast(elemType)) { + } else if (auto floatType = mlir::dyn_cast(elemType)) { elemSizeInBytes = (floatType.getWidth() + 7) / 8; // Round up to nearest byte } else { @@ -1020,7 +1246,7 @@ struct LinalgFillOpConversion : public OpConversionPattern { // For memset to work correctly, we need to have a consistent byte pattern if (auto constOp = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(constOp.getValue())) { + if (auto intAttr = mlir::dyn_cast(constOp.getValue())) { // For integer constants auto intVal = intAttr.getInt(); // Check if all bytes in the pattern are the same @@ -1039,7 +1265,8 @@ struct LinalgFillOpConversion : public OpConversionPattern { loc, rewriter.getIntegerType(8), rewriter.getIntegerAttr(rewriter.getIntegerType(8), firstByte)); } - } else if (auto floatAttr = dyn_cast(constOp.getValue())) { + } else if (auto floatAttr = + mlir::dyn_cast(constOp.getValue())) { // For floating point constants if (floatAttr.getValue().isZero()) { // Zero float can use memset with zero byte value @@ -1166,8 +1393,9 @@ class TensorEmptyOpConversion : public OpConversionPattern { } // Convert the tensor type to LLVM pointer type - auto llvmPtrType = dyn_cast( + auto llvmPtrType = mlir::dyn_cast( getTypeConverter()->convertType(resultType)); + if (!llvmPtrType) { return rewriter.notifyMatchFailure( op, "failed to convert tensor type to LLVM pointer type"); @@ -1208,9 +1436,9 @@ class TensorEmptyOpConversion : public OpConversionPattern { private: // Helper to get element type size in bytes int64_t getElementTypeSize(Type type) const { - if (auto floatType = dyn_cast(type)) { + if (auto floatType = mlir::dyn_cast(type)) { return floatType.getWidth() / 8; - } else if (auto intType = dyn_cast(type)) { + } else if (auto intType = mlir::dyn_cast(type)) { return intType.getWidth() / 8; } // Default for other types @@ -1239,6 +1467,57 @@ class TensorEmptyOpConversion : public OpConversionPattern { } }; +// Convert tt.get_program_id to LLVM call to __get_pid function +// Think this as Tx81 special action. 
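LinalgFillOpConversion can only lower a fill to a byte-wise memset when every byte of the element's bit pattern is identical, which is what the firstByte comparison above establishes: 0x00000000 and 0x2A2A2A2A qualify, 0x00000001 does not. The check is equivalent to this standalone helper:

```cpp
#include <cstdint>

// Equivalent of the byte-pattern test in LinalgFillOpConversion: a
// multi-byte constant is memset-able only if all of its bytes match.
bool allBytesSame(uint64_t value, unsigned sizeInBytes) {
  const uint8_t firstByte = static_cast<uint8_t>(value & 0xFF);
  for (unsigned i = 1; i < sizeInBytes; ++i)
    if (static_cast<uint8_t>((value >> (8 * i)) & 0xFF) != firstByte)
      return false;
  return true;
}
```
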
May can separate to a single pass or use +// tx81.get_program_id op +struct GetProgramIDConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + static uint32_t constexpr LAUNCH_GRID_RANK = + mlir::triton::getMaxEnumValForProgramIDDim() + 1; + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Memset runtime function if not already declared + // Signature: uint32_t __get_pid(uint32_t); + auto i32Ty = rewriter.getI32Type(); + + // Types for function declaration + SmallVector argTypes = { + i32Ty, // x: 0/y: 1/z: 2, + }; + + // Declare the function + Value funcPtr = declareTx81Function(module, rewriter, op.getLoc(), + "__get_pid", i32Ty, argTypes); + + // Get operands + auto axis = (uint32_t)op.getAxis(); + + assert(axis < LAUNCH_GRID_RANK && "program_id expects " + "axis to be either 0, " + "1, or 2"); + + // Convert fmt attribute to Value + Value src = rewriter.create( + op.getLoc(), i32Ty, rewriter.getI32IntegerAttr(axis)); + + // Create the call to __Memset + auto call = rewriter.create(op.getLoc(), i32Ty, + "__get_pid", // funcPtr, + ArrayRef{src}); + + // Replace the op with the result of the call + rewriter.replaceOp(op, call.getResult()); + + return success(); + } +}; + // The conversion pass class Tx81ToLLVMPass : public Tx81ToLLVMBase { public: @@ -1264,8 +1543,9 @@ class Tx81ToLLVMPass : public Tx81ToLLVMBase { func::FuncDialect, math::MathDialect>(); // Handle the tx81 op to llvm.call and support kcore load/store op's spm // offset - target.addIllegalDialect(); + target.addIllegalDialect(); // Setup rewrite patterns RewritePatternSet patterns(context); @@ -1289,22 +1569,46 @@ class Tx81ToLLVMPass : public Tx81ToLLVMBase { RoundConvertOpConversion, RoundConvertOpConversion, RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, + RoundConvertOpConversion, ReduceOpConversion, ReduceOpConversion, ReduceOpConversion, ElementWiseOpConversion, ElementWiseOpConversion, ElementWiseOpConversion, - ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + ElementWiseOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion, BinaryVSOpConversion, BinaryVSOpConversion, BinaryVSOpConversion, - BinaryVSOpConversion, + BinaryVSOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BoolRelationVVOpConversion, + BinaryLogicVVOpConversion, + BinaryLogicVVOpConversion, + BinaryLogicVVOpConversion, RdmaOpConversion, WdmaOpConversion, MaskMoveOpConversion, GemmOpConversion, - MemsetOpConversion>( + MemsetOpConversion, + GetProgramIDConversion>( context); // clang-format on diff --git a/third_party/tsingmicro/name.conf b/third_party/tsingmicro/name.conf index 8a593a129..1340763be 100644 --- a/third_party/tsingmicro/name.conf +++ b/third_party/tsingmicro/name.conf @@ -1 +1 @@ -ztc +tsingmicro diff --git a/third_party/tsingmicro/python/triton_tsingmicro.cc b/third_party/tsingmicro/python/triton_tsingmicro.cc index ff232d947..608918898 100644 --- a/third_party/tsingmicro/python/triton_tsingmicro.cc +++ b/third_party/tsingmicro/python/triton_tsingmicro.cc @@ -1,7 +1,45 @@ #include 
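Returning to GetProgramIDConversion above: `tt.get_program_id` becomes a call to a runtime query whose axis argument encodes x/y/z as 0/1/2 and is asserted to stay below the launch-grid rank. The prototype below is the one quoted in the pass; the harness function around it is hypothetical:

```cpp
#include <cstdint>

// Entry point declared by GetProgramIDConversion; axis is 0, 1, or 2.
extern "C" uint32_t __get_pid(uint32_t axis);

// Hypothetical harness showing how lowered kernel code recovers the 3-D
// program id, mirroring LAUNCH_GRID_RANK == 3.
void programIdSketch(uint32_t pid[3]) {
  for (uint32_t axis = 0; axis < 3; ++axis)
    pid[axis] = __get_pid(axis);
}
```
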
+#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" +#include "triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToCoreDialects/Passes.h" +#include "triton-shared/Conversion/TritonToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToStructured/Passes.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "tsingmicro-tx81/Dialect/IR/Tx81Dialect.h" + +#include "magic-kernel/Conversion/CoreDialectsToMK/Passes.h" +#include "magic-kernel/Conversion/LinalgToMK/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "triton-shared/Conversion/StructuredToMemref/Passes.h" +#include "tsingmicro-tx81/Conversion/MKToTx81/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81MemrefToLLVM/Passes.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/Passes.h" + +#include "magic-kernel/Transforms/BufferizableOpInterfaceImpl.h" + +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/InitAllPasses.h" + +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "tsingmicro-tx81/Conversion/Tx81ToLLVM/KernelArgBufferPass.h" + namespace py = pybind11; +using namespace mlir; -// The TsingMicro backend with ztc doesn't do compilation from within python -// but rather externally through ztc-opt, so we leave this function blank. void init_triton_tsingmicro(py::module &&m) {} diff --git a/third_party/tsingmicro/scripts/build_llvm.sh b/third_party/tsingmicro/scripts/build_llvm.sh new file mode 100755 index 000000000..d76a4ef1c --- /dev/null +++ b/third_party/tsingmicro/scripts/build_llvm.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +if [ -z "${LLVM_PROJECT+x}" ]; then + echo "Please set the environment variable “LLVM_PROJECT”." 1>&2 + exit 1 +fi + +if [ ! -d $LLVM_PROJECT ]; then + echo "Error: $LLVM_PROJECT not exist!" 1>&2 + exit 1 +fi + +BUILD_TYPE=Release + +build_llvm() { + mkdir $LLVM_PROJECT/build + cd $LLVM_PROJECT/build + cmake -G Ninja \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_PROJECTS="clang;mlir;llvm;lld" \ + -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU;RISCV" \ + -DLLVM_USE_LINKER=lld \ + -DMLIR_ENABLE_BINDINGS_PYTHON=1 \ + ../llvm + ninja +} + +build_llvm diff --git a/third_party/tsingmicro/scripts/build_tsingmicro.sh b/third_party/tsingmicro/scripts/build_tsingmicro.sh new file mode 100755 index 000000000..1e093cf75 --- /dev/null +++ b/third_party/tsingmicro/scripts/build_tsingmicro.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +script_path=$(realpath "$0") +script_dir=$(dirname "$script_path") +project_dir=$(realpath "$script_dir/../../..") + +if [ -z "${WORKSPACE+x}" ]; then + WORKSPACE=$(realpath "$project_dir/..") +fi + +TX8_HOME=$WORKSPACE/tx8_deps +LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 + +if [ ! -d $TX8_HOME ] || [ ! 
-d "$LLVM" ]; then + WORKSPACE="${HOME}/.flagtree/tsingmicro/" + TX8_HOME=$WORKSPACE/tx8_deps + LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 +fi + +if [ ! -d "$TX8_HOME" ]; then + echo "Error: $TX8_HOME does not exist!" 1>&2 + exit 1 +fi + +if [ ! -d "$LLVM" ]; then + echo "Error: $LLVM does not exist!" 1>&2 + exit 1 +fi + +BUILD_TYPE=Release + +export TX8_HOME=$TX8_HOME +export LLVM_SYSPATH=$LLVM +export FLAGTREE_BACKEND=tsingmicro + +export TRITON_OFFLINE_BUILD=ON +export TRITON_BUILD_WITH_CLANG_LLD=true +export TRITON_BUILD_WITH_CCACHE=true +export TRITON_BUILD_PROTON=OFF + +echo "export TX8_HOME=$TX8_HOME" +echo "export LLVM_SYSPATH=$LLVM_SYSPATH" +echo "export FLAGTREE_BACKEND=$FLAGTREE_BACKEND" + +echo "export TRITON_OFFLINE_BUILD=$TRITON_OFFLINE_BUILD" +echo "export TRITON_BUILD_WITH_CLANG_LLD=$TRITON_BUILD_WITH_CLANG_LLD" +echo "export TRITON_BUILD_WITH_CCACHE=$TRITON_BUILD_WITH_CCACHE" +echo "export TRITON_BUILD_PROTON=$TRITON_BUILD_PROTON" + +cd python +python3 -m pip install . --no-build-isolation -v diff --git a/third_party/tsingmicro/scripts/install.sh b/third_party/tsingmicro/scripts/install.sh new file mode 100755 index 000000000..b0d3346b4 --- /dev/null +++ b/third_party/tsingmicro/scripts/install.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +apt install -y git +apt install -y lld + +pip uninstall -y triton + +pip install gitpython +pip install torch==2.7.0 torchvision diff --git a/third_party/tsingmicro/scripts/run_tsingmicro.sh b/third_party/tsingmicro/scripts/run_tsingmicro.sh new file mode 100755 index 000000000..13e3ed38c --- /dev/null +++ b/third_party/tsingmicro/scripts/run_tsingmicro.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +script_path=$(realpath "$0") +script_dir=$(dirname "$script_path") +project_dir=$(realpath "$script_dir/../../..") + +if [ -z "${WORKSPACE+x}" ]; then + WORKSPACE=$(realpath "$project_dir/..") +fi + +TX8_HOME=$WORKSPACE/tx8_deps +LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 + +if [ ! -d "$TX8_HOME" ] || [ ! -d "$LLVM" ]; then + WORKSPACE="${HOME}/.flagtree/tsingmicro/" + TX8_HOME=$WORKSPACE/tx8_deps + LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 +fi + +if [ ! -d "$TX8_HOME" ]; then + echo "Error: $TX8_HOME does not exist!" 1>&2 + exit 1 +fi + +if [ ! -d "$LLVM" ]; then + echo "Error: $LLVM does not exist!" 
1>&2 + exit 1 +fi + +export TX8_HOME=$TX8_HOME +export LLVM_SYSPATH=$LLVM +export LD_LIBRARY_PATH=$TX8_HOME/lib:$LD_LIBRARY_PATH +export TRITON_ALWAYS_COMPILE=1 + +# export TRITON_DUMP_PATH=$project_dir/dump + +echo "export TX8_HOME=$TX8_HOME" +echo "export LLVM_SYSPATH=$LLVM_SYSPATH" +echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH" +echo "export TRITON_ALWAYS_COMPILE=$TRITON_ALWAYS_COMPILE" + +python3 $@ From 6d5de1cac439adc4f2007950f95930c5d773cd54 Mon Sep 17 00:00:00 2001 From: tsingmicro-public Date: Thu, 5 Jun 2025 21:08:10 +0800 Subject: [PATCH 06/12] [BUILD] Fix build tsingmicro * [BUILD] Fix build tsingmicro --------- Co-authored-by: tsingmicro-public Co-authored-by: zhengyang --- CMakeLists.txt | 5 +++- bin/RegisterTritonDialects.h | 2 +- python/setup_helper.py | 4 +++ third_party/tsingmicro/backend/driver.py | 38 ++++++++++++------------ 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f6380ecd0..1476127de 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,9 +26,12 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads") set(CMAKE_C_COMPILER clang) set(CMAKE_CXX_COMPILER clang++) set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND}) -elseif(FLAGTREE_BACKEND MATCHES "^(aipu|tsingmicro)$") +elseif(FLAGTREE_BACKEND STREQUAL "aipu") add_definitions(-D__NVIDIA__) add_definitions(-D__AMD__) +elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro") + set(CMAKE_C_COMPILER clang) + set(CMAKE_CXX_COMPILER clang++) endif() set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}") if(FLAGTREE_PLUGIN) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 04d24dff8..3329aada1 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -69,8 +69,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::gpu::registerAllocateSharedMemoryPass(); mlir::triton::gpu::registerTritonGPUAllocateWarpGroups(); mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass(); - mlir::triton::registerConvertWarpSpecializeToLLVM(); #ifdef __NVIDIA__ + mlir::triton::registerConvertWarpSpecializeToLLVM(); mlir::triton::registerConvertTritonGPUToLLVMPass(); mlir::triton::registerConvertNVGPUToLLVMPass(); #endif diff --git a/python/setup_helper.py b/python/setup_helper.py index 454c43b70..635aac656 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -373,3 +373,7 @@ def check_env(env_val): pre_hock=lambda: check_env('LLVM_BUILD_DIR'), post_hock=set_llvm_env, ) + +# tsingmicro +if flagtree_backend == "tsingmicro": + set_env({"TRITON_BUILD_PROTON": "OFF"}) diff --git a/third_party/tsingmicro/backend/driver.py b/third_party/tsingmicro/backend/driver.py index 05e637c04..dc646291b 100644 --- a/third_party/tsingmicro/backend/driver.py +++ b/third_party/tsingmicro/backend/driver.py @@ -17,25 +17,23 @@ from triton.backends.driver import GPUDriver from triton.backends.compiler import GPUTarget -import torch -from torch.utils import cpp_extension, rename_privateuse1_backend, generate_methods_for_privateuse1_backend -module = cpp_extension.load( - name="txda", - sources=[os.path.dirname(__file__) + "/txda_device.cpp"], - #runtime include path - extra_include_paths=[""], - #runtime *.so path - extra_ldflags=[""], - extra_cflags=["-g"], - verbose=True, -) - -torch.utils.rename_privateuse1_backend("txda") - -torch._register_device_module("txda", module) - -generate_methods_for_privateuse1_backend(for_storage=True) +def extend_torch(): + import torch + from torch.utils import cpp_extension, rename_privateuse1_backend, 
generate_methods_for_privateuse1_backend + module = cpp_extension.load( + name="txda", + sources=[os.path.dirname(__file__) + "/txda_device.cpp"], + #runtime include path + extra_include_paths=[""], + #runtime *.so path + extra_ldflags=[""], + extra_cflags=["-g"], + verbose=True, + ) + torch.utils.rename_privateuse1_backend("txda") + torch._register_device_module("txda", module) + generate_methods_for_privateuse1_backend(for_storage=True) def _get_tx8_path(bin_name: str) -> str: @@ -709,8 +707,10 @@ class TXDADriver(GPUDriver): def __init__(self): super().__init__() + extend_torch() self.utils = TXDAUtils() self.launcher_cls = TXDALauncher + import torch # Needs to overwrite GPUDriver base methods self.get_current_stream = torch.txda.current_stream self.get_current_device = torch.txda.current_device @@ -732,7 +732,7 @@ def get_current_target(self): return GPUTarget("txda", capability, warp_size) def get_active_torch_device(self): - # import torch + import torch # torch.txda.init_device() return torch.device("txda", self.get_current_device()) From a4dca036ce5b31cfaafc4cbf9c004d19eabd5a43 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 6 Jun 2025 10:04:16 +0800 Subject: [PATCH 07/12] [BACKEND] [BUILD] Fix tsingmicro build --- third_party/tsingmicro/CMakeLists.txt | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/third_party/tsingmicro/CMakeLists.txt b/third_party/tsingmicro/CMakeLists.txt index 8d0ee5de1..ef84b07e7 100644 --- a/third_party/tsingmicro/CMakeLists.txt +++ b/third_party/tsingmicro/CMakeLists.txt @@ -17,22 +17,6 @@ add_subdirectory(lib) add_subdirectory(bin) add_subdirectory(crt) if(TRITON_BUILD_PYTHON_MODULE) - # find_package(Python3 REQUIRED COMPONENTS Development Interpreter) # 添加查找 Python3 - # add_library(backendxxxTritonPlugin SHARED - # ${CMAKE_CURRENT_SOURCE_DIR}/triton_backendxxx.cc - # ) - # set_target_properties(backendxxxTritonPlugin PROPERTIES - # PREFIX "" - # LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib - # POSITION_INDEPENDENT_CODE ON - # ) - # target_link_libraries(backendxxxTritonPlugin PRIVATE # 链接闭源模块,此处为两个⽰例 - # BackendTritonGPUToLLVM - # BackendTritonTransforms - # # Py - # ${Python3_LIBRARIES} # 添加链接 Python3 - # ${PYTHON_LDFLAGS} - # ) # FIXME: Unify the libraries for TsingMicro into fewer ones add_triton_plugin(TritonTsingMicro ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_tsingmicro.cc LINK_LIBS ZTCAnalysis ZTCAnalysisStructured MagicKernelIR From 09e4d4fb71ea0778737a9a26424636ad144cf135 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 6 Jun 2025 11:34:51 +0800 Subject: [PATCH 08/12] [BUILD] Fix tsingmicro build --- python/setup_helper.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/setup_helper.py b/python/setup_helper.py index 635aac656..c90dea31a 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -10,13 +10,13 @@ import hashlib from dataclasses import dataclass +flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() +flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() use_triton_shared = False -necessary_third_party = ["flir"] +necessary_third_party = ["" if flagtree_backend == "tsingmicro" else "flir"] default_backends = ["nvidia", "amd"] extend_backends = [] ext_sourcedir = "triton/_C/" -flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() -flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() @dataclass @@ -284,7 +284,8 @@ def git_clone(lib, lib_path): "so we couldn't compile triton_shared\n") third_partys = [] - 
third_partys.append(flagtree_backend_info["flir"]) + if flagtree_backend != "tsingmicro": + third_partys.append(flagtree_backend_info["flir"]) if os.environ.get("USE_TRITON_SHARED", "ON") == "ON": third_partys.append(flagtree_backend_info["triton_shared"]) else: @@ -307,7 +308,8 @@ def handle_flagtree_backend(): extend_backends.append(flagtree_backend) if "editable_wheel" in sys.argv and flagtree_backend not in ("aipu", "tsingmicro"): ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" - default_backends.append("flir") + if flagtree_backend != "tsingmicro": + default_backends.append("flir") if use_triton_shared: default_backends.append("triton_shared") @@ -375,5 +377,5 @@ def check_env(env_val): ) # tsingmicro -if flagtree_backend == "tsingmicro": - set_env({"TRITON_BUILD_PROTON": "OFF"}) +#if flagtree_backend == "tsingmicro": +# set_env({"TRITON_BUILD_PROTON": "OFF"}) From 807167d1bdfd6b03204e76924abf6a72833e29b3 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 6 Jun 2025 06:19:26 +0000 Subject: [PATCH 09/12] [BUILD] Update build tsingmicro --- python/setup_helper.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/setup_helper.py b/python/setup_helper.py index c90dea31a..fc01e7627 100644 --- a/python/setup_helper.py +++ b/python/setup_helper.py @@ -39,7 +39,6 @@ class FlagTreeBackend: } set_llvm_env = lambda path: set_env({ - 'LLVM_BUILD_DIR': path, 'LLVM_INCLUDE_DIRS': Path(path) / "include", 'LLVM_LIBRARY_DIR': Path(path) / "lib", 'LLVM_SYSPATH': path, @@ -337,7 +336,7 @@ def check_env(env_val): file="iluvatar-llvm18-x86_64", condition=("iluvatar" == flagtree_backend), url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatar-llvm18-x86_64.tar.gz", - pre_hock=lambda: check_env('LLVM_BUILD_DIR'), + pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) @@ -346,7 +345,7 @@ def check_env(env_val): file="XTDK-llvm18-ubuntu2004_x86_64", condition=("xpu" == flagtree_backend), url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/XTDK-llvm18-ubuntu2004_x86_64.tar", - pre_hock=lambda: check_env('LLVM_BUILD_DIR'), + pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) @@ -357,10 +356,10 @@ def check_env(env_val): cache.store( files=("clang", "xpu-xxd", "xpu3-crt.xpu", "xpu-kernel.t", "ld.lld", "llvm-readelf", "llvm-objdump", "llvm-objcopy"), condition=("xpu" == flagtree_backend), - copy_src_path=f"{os.environ.get('LLVM_BUILD_DIR','')}/bin", copy_dst_path="third_party/xpu/backend/xpu3/bin") + copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/bin", copy_dst_path="third_party/xpu/backend/xpu3/bin") cache.store(files=("libclang_rt.builtins-xpu3.a", "libclang_rt.builtins-xpu3s.a"), - condition=("xpu" == flagtree_backend), copy_src_path=f"{os.environ.get('LLVM_BUILD_DIR','')}/lib/linux", + condition=("xpu" == flagtree_backend), copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/lib/linux", copy_dst_path="third_party/xpu/backend/xpu3/lib/linux") cache.store(files=("include", "so"), condition=("xpu" == flagtree_backend), @@ -372,10 +371,16 @@ def check_env(env_val): condition=("mthreads" == flagtree_backend), url= "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64.tar.gz", - pre_hock=lambda: check_env('LLVM_BUILD_DIR'), + pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) # tsingmicro -#if flagtree_backend == "tsingmicro": -# 
set_env({"TRITON_BUILD_PROTON": "OFF"}) +cache.store( + file="tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64", + condition=("tsingmicro" == flagtree_backend), + url= + "https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) From 8f0a08c3a065b036b51ac1dffdc9cde7d6df6672 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 6 Jun 2025 06:57:50 +0000 Subject: [PATCH 10/12] [CI/CD] Add tsingmicro build workflow --- .../workflows/tsingmicro-build-and-test.yml | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 .github/workflows/tsingmicro-build-and-test.yml diff --git a/.github/workflows/tsingmicro-build-and-test.yml b/.github/workflows/tsingmicro-build-and-test.yml new file mode 100644 index 000000000..0e7af1c86 --- /dev/null +++ b/.github/workflows/tsingmicro-build-and-test.yml @@ -0,0 +1,59 @@ +name: Tsingmicro-Build-And-Test + +on: + push: + branches: [ "triton_v3.3.x" ] + pull_request: + branches: [ "triton_v3.3.x" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + tsingmicro-build-and-test: + runs-on: tsingmicro + steps: + - name: Checkout code (attempt 1) + id: checkout1 + uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before checkout2 + if: steps.checkout1.outcome == 'failure' + run: | + echo "First checkout attempt failed. Sleeping for 120 seconds before retry..." + sleep 120 + + - name: Checkout code (attempt 2) + id: checkout2 + if: steps.checkout1.outcome == 'failure' + uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before final checkout + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + run: | + echo "Second checkout attempt failed. Sleeping for 180 seconds before final retry..." + sleep 180 + + - name: Checkout code (final attempt) + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + uses: actions/checkout@v4 + + - name: Verify checkout success + if: success() + run: echo "Checkout completed successfully" + + - name: FlagTree Build on Tsingmicro + shell: bash + run: | + source ~/env.sh + export FLAGTREE_BACKEND=tsingmicro + cd python + python3.10 -m pip install . --no-build-isolation -v + + - name: FlagTree Test on Tsingmicro + shell: bash + run: | + python3.10 -c 'import triton; print(triton.__path__)' From 80d5a3d010af5670188d1e3aa85e9affc88c907a Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 6 Jun 2025 15:11:55 +0800 Subject: [PATCH 11/12] [DOC] Update tsingmicro build readme --- README.md | 21 ++++++++++----------- README_cn.md | 21 ++++++++++----------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 88054b36b..68aa9d348 100644 --- a/README.md +++ b/README.md @@ -43,17 +43,6 @@ export FLAGTREE_BACKEND=xpu python3 -m pip install . 
--no-build-isolation -v ``` ```shell -# tsingmicro -# Recommended: Use the Docker image (xxGB) https://xxxx -mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro -wget https://github.com/FlagTree/flagtree/releases/download/xxxx -wget https://github.com/FlagTree/flagtree/releases/download/xxxx -cd ${YOUR_CODE_DIR}/flagtree/ -./third_party/tsingmicro/scripts/install.sh -./third_party/tsingmicro/scripts/build_tsingmicro.sh -./third_party/tsingmicro/scripts/run_tsingmicro.sh third_party/tsingmicro/examples/test_vec_add.py -``` -```shell # mthreads # Recommended: Use the Dockerfile flagtree/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads mkdir -p ~/.flagtree/mthreads; cd ~/.flagtree/mthreads @@ -62,6 +51,16 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=mthreads python3 -m pip install . --no-build-isolation -v ``` +```shell +# tsingmicro +# Recommended: Use Ubuntu 20.04 +mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro +wget https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/ +git checkout -b triton_v3.3.x origin/triton_v3.3.x +export FLAGTREE_BACKEND=tsingmicro +python3 -m pip install . --no-build-isolation -v +``` To build with default backends (nvidia, amd, triton_shared): ```shell diff --git a/README_cn.md b/README_cn.md index aa5c4575a..11d9b2358 100644 --- a/README_cn.md +++ b/README_cn.md @@ -43,17 +43,6 @@ export FLAGTREE_BACKEND=xpu python3 -m pip install . --no-build-isolation -v ``` ```shell -# tsingmicro -# Recommended: Use the Docker image (xxGB) https://xxxx -mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro -wget https://github.com/FlagTree/flagtree/releases/download/xxxx -wget https://github.com/FlagTree/flagtree/releases/download/xxxx -cd ${YOUR_CODE_DIR}/flagtree/ -./third_party/tsingmicro/scripts/install.sh -./third_party/tsingmicro/scripts/build_tsingmicro.sh -./third_party/tsingmicro/scripts/run_tsingmicro.sh third_party/tsingmicro/examples/test_vec_add.py -``` -```shell # mthreads # 推荐使用镜像 flagtree/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads mkdir -p ~/.flagtree/mthreads; cd ~/.flagtree/mthreads @@ -62,6 +51,16 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=mthreads python3 -m pip install . --no-build-isolation -v ``` +```shell +# tsingmicro +# 推荐使用镜像 Ubuntu 20.04 +mkdir -p ~/.flagtree/tsingmicro; cd ~/.flagtree/tsingmicro +wget https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.35-glibcxx3.4.30-x64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/ +git checkout -b triton_v3.3.x origin/triton_v3.3.x +export FLAGTREE_BACKEND=tsingmicro +python3 -m pip install . --no-build-isolation -v +``` 使用默认的编译命令,可以编译安装 nvidia、amd、triton_shared 后端: ```shell From e68157cb2fa866f56bd0bb4f46275f97113df511 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Fri, 6 Jun 2025 15:21:54 +0800 Subject: [PATCH 12/12] [CI] Fix tsingmicro workflow --- .github/workflows/tsingmicro-build-and-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tsingmicro-build-and-test.yml b/.github/workflows/tsingmicro-build-and-test.yml index 0e7af1c86..9323774a9 100644 --- a/.github/workflows/tsingmicro-build-and-test.yml +++ b/.github/workflows/tsingmicro-build-and-test.yml @@ -56,4 +56,5 @@ jobs: - name: FlagTree Test on Tsingmicro shell: bash run: | + source ~/env.sh python3.10 -c 'import triton; print(triton.__path__)'