From c70dab3c84d3f8115e76b7876bf8a93bade521d0 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy
Date: Tue, 16 Jan 2024 01:22:19 +0800
Subject: [PATCH] Add conversion for tt.reshape to tensor.reshape (#83)

---
 .../TritonToLinalg/TritonToLinalg.cpp         | 27 ++++++++++
 .../convert_tensor_reshape.mlir               | 49 +++++++++++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir

diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
index 3cfd89e4..7afcbec0 100644
--- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
+++ b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -1662,6 +1662,32 @@ class CumSumConverter : public OpConversionPattern {
   }
 };
 
+class ReshapeConverter : public OpConversionPattern<triton::ReshapeOp> {
+  using OpConversionPattern<triton::ReshapeOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto input = op.getSrc();
+    auto output = op.getResult();
+
+    auto outputType = dyn_cast<TensorType>(output.getType());
+    if (!outputType) {
+      return failure();
+    }
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+
+    auto shape = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI64TensorAttr(outputShape));
+    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, outputType, input,
+                                                   shape);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::triton::populateTritonToLinalgCanonicalizationPatterns(
@@ -1697,6 +1723,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
+  patterns.add<ReshapeConverter>(patterns.getContext());
 
   // Reduce converters
   // Triton's reduce op is identical to linalg.reduce op, so we can clone
diff --git a/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir b/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir
new file mode 100644
index 00000000..57d8efc0
--- /dev/null
+++ b/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir
@@ -0,0 +1,49 @@
+// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
+module {
+  tt.func public @bcast_kernel_01(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
+    %c32_i32 = arith.constant 32 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c32_i32 : i32
+    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
+    %3 = tt.splat %1 : (i32) -> tensor<32xi32>
+    %4 = arith.addi %3, %2 : tensor<32xi32>
+    %5 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32>
+    %6 = tt.splat %1 : (i32) -> tensor<2048xi32>
+    %7 = arith.addi %6, %5 : tensor<2048xi32>
+    %8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x!tt.ptr<f32>>
+    %9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
+    %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32xf32>
+    %11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
+    %12 = tt.broadcast %11 : (tensor<1x32xf32>) -> tensor<64x32xf32>
+    %13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
+    %14 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<2048x!tt.ptr<f32>>
+    %15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
+    tt.store %15, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<2048xf32>
+    tt.return
+  }
+}
+
+
+// CHECK-LABEL: func.func @bcast_kernel_01(
+// CHECK: %[[C2048_I64:.*]] = arith.constant 2048 : i64
+// CHECK: %[[CST:.*]] = arith.constant dense<[1, 32]> : tensor<2xi64>
+// CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+// CHECK: %[[VAR_0:.*]] = arith.muli %arg5, %[[C32_I32]] : i32
+// CHECK: %[[VAR_1:.*]] = arith.index_cast %[[VAR_0]] : i32 to index
+// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[VAR_1]]], sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32>
+// CHECK: memref.copy %[[REINTERPRET_CAST]], %[[ALLOC]] : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32>
+// CHECK: %[[VAR_2:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<32xf32>
+// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[VAR_2]](%[[CST]]) : (tensor<32xf32>, tensor<2xi64>) -> tensor<1x32xf32>
+// CHECK: %[[VAR_3:.*]] = tensor.empty() : tensor<64x32xf32>
+// CHECK: %[[VAR_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[RESHAPE]] : tensor<1x32xf32>) outs(%[[VAR_3]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK:   linalg.yield %in : f32
+// CHECK: } -> tensor<64x32xf32>
+// CHECK: %[[VAR_5:.*]] = tensor.empty() : tensor<1xi64>
+// CHECK: %[[VAR_6:.*]] = linalg.fill ins(%[[C2048_I64]] : i64) outs(%[[VAR_5]] : tensor<1xi64>) -> tensor<1xi64>
+// CHECK: %[[RESHAPE_0:.*]] = tensor.reshape %[[VAR_4]](%[[VAR_6]]) : (tensor<64x32xf32>, tensor<1xi64>) -> tensor<2048xf32>
+// CHECK: %[[VAR_7:.*]] = arith.index_cast %[[VAR_0]] : i32 to index
+// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %arg1 to offset: [%[[VAR_7]]], sizes: [2048], strides: [1] : memref<*xf32> to memref<2048xf32, strided<[1], offset: ?>>
+// CHECK: bufferization.materialize_in_destination %[[RESHAPE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: return
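
Note (editor's addendum, not part of the patch): ReshapeConverter materializes
the target shape as an i64 constant tensor via
rewriter.getI64TensorAttr(outputShape), so the rewrite only makes sense when the
output tensor type is fully static; a dynamic dimension would bake the
ShapedType::kDynamic sentinel into the constant. Below is a minimal sketch of a
defensively guarded variant. The name GuardedReshapeConverter, the static-shape
bailout, and the exact include set are assumptions for illustration, not code
from this commit.

// Sketch only: the patch's ReshapeConverter plus a hypothetical static-shape
// guard. Assumes the headers/namespaces already used by TritonToLinalg.cpp.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace {
using namespace mlir;

class GuardedReshapeConverter
    : public OpConversionPattern<triton::ReshapeOp> {
  using OpConversionPattern<triton::ReshapeOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto outputType = dyn_cast<TensorType>(op.getResult().getType());
    // getI64TensorAttr bakes every dimension into an arith.constant, so bail
    // out on dynamic shapes instead of emitting a constant full of sentinels.
    if (!outputType || !outputType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "expected static output shape");

    // Same lowering as the patch: the shape becomes a 1-D i64 constant tensor
    // feeding tensor.reshape.
    auto shape = rewriter.create<arith::ConstantOp>(
        op.getLoc(), rewriter.getI64TensorAttr(outputType.getShape()));
    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, outputType,
                                                   op.getSrc(), shape);
    return success();
  }
};
} // namespace

Either way, tensor.reshape takes its target shape as an SSA operand (the
tensor<2xi64> and tensor<1xi64> values visible in the CHECK lines above), which
is why the pattern builds an arith.constant rather than computing reassociation
indices for tensor.expand_shape/tensor.collapse_shape.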