Add conversion for tt.reshape to tensor.reshape (#83)
aobolensk committed Jan 15, 2024
1 parent 3230526 commit c70dab3
Showing 2 changed files with 76 additions and 0 deletions.
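
In short, the new ReshapeConverter pattern lowers tt.reshape by materializing the statically known target shape as a constant i64 tensor and replacing the op with tensor.reshape. A minimal before/after sketch of that lowering, using the 32 -> 1x32 reshape from the test added below (the SSA names %x, %shape, and %r are illustrative, not taken from the pass output):

// Before --triton-to-linalg: a Triton reshape of a 1-D tensor into a 2-D tensor.
%r = tt.reshape %x {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>

// After: the target shape [1, 32] is materialized as a constant i64 tensor
// and fed to tensor.reshape.
%shape = arith.constant dense<[1, 32]> : tensor<2xi64>
%r = tensor.reshape %x(%shape) : (tensor<32xf32>, tensor<2xi64>) -> tensor<1x32xf32>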
27 changes: 27 additions & 0 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -1662,6 +1662,32 @@ class CumSumConverter : public OpConversionPattern<triton::ScanOp> {
  }
};

class ReshapeConverter : public OpConversionPattern<triton::ReshapeOp> {
  using OpConversionPattern<triton::ReshapeOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getSrc();
    auto output = op.getResult();

    auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType) {
      return failure();
    }
    ArrayRef<int64_t> outputShape = outputType.getShape();

    auto shape = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI64TensorAttr(outputShape));
    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, outputType, input,
                                                   shape);

    return success();
  }
};

} // namespace

void mlir::triton::populateTritonToLinalgCanonicalizationPatterns(
@@ -1697,6 +1723,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
  patterns.add<DenseConstantConverter>(patterns.getContext());
  patterns.add<UnrealizedCastConverter>(patterns.getContext());
  patterns.add<CumSumConverter>(patterns.getContext());
  patterns.add<ReshapeConverter>(patterns.getContext());

  // Reduce converters
  // Triton's reduce op is identical to linalg.reduce op, so we can clone
49 changes: 49 additions & 0 deletions test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir
@@ -0,0 +1,49 @@
// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
module {
  tt.func public @bcast_kernel_01(%arg0: !tt.ptr<f32, 1>, %arg1: !tt.ptr<f32, 1>) attributes {noinline = false} {
    %c32_i32 = arith.constant 32 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c32_i32 : i32
    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %3 = tt.splat %1 : (i32) -> tensor<32xi32>
    %4 = arith.addi %3, %2 : tensor<32xi32>
    %5 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32>
    %6 = tt.splat %1 : (i32) -> tensor<2048xi32>
    %7 = arith.addi %6, %5 : tensor<2048xi32>
    %8 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<32x!tt.ptr<f32, 1>>
    %9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32, 1>>, tensor<32xi32>
    %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32xf32>
    %11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
    %12 = tt.broadcast %11 : (tensor<1x32xf32>) -> tensor<64x32xf32>
    %13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
    %14 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<2048x!tt.ptr<f32, 1>>
    %15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32, 1>>, tensor<2048xi32>
    tt.store %15, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<2048xf32>
    tt.return
  }
}


// CHECK-LABEL: func.func @bcast_kernel_01(
// CHECK: %[[C2048_I64:.*]] = arith.constant 2048 : i64
// CHECK: %[[CST:.*]] = arith.constant dense<[1, 32]> : tensor<2xi64>
// CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
// CHECK: %[[VAR_0:.*]] = arith.muli %arg5, %[[C32_I32]] : i32
// CHECK: %[[VAR_1:.*]] = arith.index_cast %[[VAR_0]] : i32 to index
// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[VAR_1]]], sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32>
// CHECK: memref.copy %[[REINTERPRET_CAST:.*]], %[[ALLOC]] : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32>
// CHECK: %[[VAR_2:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<32xf32>
// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[VAR_2]](%[[CST]]) : (tensor<32xf32>, tensor<2xi64>) -> tensor<1x32xf32>
// CHECK: %[[VAR_3:.*]] = tensor.empty() : tensor<64x32xf32>
// CHECK: %[[VAR_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[RESHAPE]] : tensor<1x32xf32>) outs(%[[VAR_3:.*]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
// CHECK: ^bb0(%in: f32, %out: f32):
// CHECK: linalg.yield %in : f32
// CHECK: } -> tensor<64x32xf32>
// CHECK: %[[VAR_5:.*]] = tensor.empty() : tensor<1xi64>
// CHECK: %[[VAR_6:.*]] = linalg.fill ins(%[[C2048_I64]] : i64) outs(%[[VAR_5]] : tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPE_0:.*]] = tensor.reshape %[[VAR_4]](%[[VAR_6]]) : (tensor<64x32xf32>, tensor<1xi64>) -> tensor<2048xf32>
// CHECK: %[[VAR_7:.*]] = arith.index_cast %[[VAR_0]] : i32 to index
// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %arg1 to offset: [%[[VAR_7]]], sizes: [2048], strides: [1] : memref<*xf32> to memref<2048xf32, strided<[1], offset: ?>>
// CHECK: bufferization.materialize_in_destination %[[RESHAPE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
