
mlir PadOp TilingInterface implementation does not include generateResultTileValue #64092

@hesse-x

Description


We found that the PadOp implementation of the TilingInterface (PadOpTiling) does not implement the generateResultTileValue method. This leads to unexpected results when we attempt to fuse a pad operation into a containing loop using transform.structured.fuse_into_containing_op. Is this intended behavior or a defect? Additionally, what is the correct approach for fusing a pad operation into a loop?
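
For context, here is a minimal sketch of what the missing hook could look like. Since tensor.pad has a single result whose iteration domain coincides with the result shape, generateResultTileValue could plausibly just forward to the existing getTiledImplementation. The signatures below are our assumption based on the LLVM-15-era interface (they have changed across releases), not the upstream code:

// Hypothetical addition to PadOpTiling in
// mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp.
// Signatures are assumed and may differ in other LLVM releases.
FailureOr<Value>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes) const {
  // tensor.pad has exactly one result, and its iteration domain is the
  // result shape, so a result tile maps 1:1 onto an iteration-space tile.
  SmallVector<Operation *> tiledOps =
      getTiledImplementation(op, b, offsets, sizes);
  if (tiledOps.empty())
    return failure();
  // The last generated op produces the tiled pad value.
  return tiledOps.back()->getResult(resultNumber);
}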

Here are our test results on llvm-15.x. First, the input IR and transform script:

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @pad(%arg0: tensor<58x1xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> attributes {pad} {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %zero = arith.constant 0.0 : f32
 
  %pad = tensor.pad %arg0 low[4, 60] high[2, 67] {
  ^bb0(%arg4: index, %arg5: index):
    tensor.yield %zero : f32
  } : tensor<58x1xf32> to tensor<64x128xf32>

  %1 = linalg.init_tensor [64, 128] : tensor<64x128xf32>
  %2 = scf.foreach_thread (%arg2, %arg3) in (%c1, %c1) -> (tensor<64x128xf32>) {
    %3 = tensor.extract_slice %arg1[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
    %4 = tensor.extract_slice %pad[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
    %5 = linalg.init_tensor [8, 128] : tensor<8x128xf32>
    %6 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
        ins(%3, %4 : tensor<8x128xf32>, tensor<8x128xf32>)
        outs(%5 : tensor<8x128xf32>) {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
        %7 = arith.addf %arg4, %arg5 : f32
        linalg.yield %7 : f32
    } -> tensor<8x128xf32>
    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %6 into %1[%arg2, %arg3] [8, 128] [1, 1] : tensor<8x128xf32> into tensor<64x128xf32>
    }  
  } {thread_dim_mapping = [2, 4]} 
  return %2 : tensor<64x128xf32>
}

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.genesis.canonicalized_sequence %arg0 failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %device_func = transform.genesis.match ops{["func.func"]} attributes {pad} in %arg1
    %foreach_thread_op = transform.genesis.match ops{["scf.foreach_thread"]} in %device_func

    // fuse and tile
    %expand_shape = transform.genesis.match ops{["tensor.pad"]} in %device_func
    transform.structured.fuse_into_containing_op %expand_shape into %foreach_thread_op
  }  
}
The output after applying the transform is shown below. Note that the pad has been cloned into the loop at its full 64x128 size and then re-sliced, rather than being tiled down to the 8x128 slice:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @pad(%arg0: tensor<58x1xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> attributes {pad} {
    %c1 = arith.constant 1 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = linalg.init_tensor [64, 128] : tensor<64x128xf32>
    %1 = scf.foreach_thread (%arg2, %arg3) in (%c1, %c1) -> (tensor<64x128xf32>) {
      %2 = tensor.extract_slice %arg1[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
      %3 = tensor.pad %arg0 low[4, 60] high[2, 67] {
      ^bb0(%arg4: index, %arg5: index):
        tensor.yield %cst : f32
      } : tensor<58x1xf32> to tensor<64x128xf32>
      %4 = tensor.extract_slice %3[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
      %5 = linalg.init_tensor [8, 128] : tensor<8x128xf32>
      %6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %4 : tensor<8x128xf32>, tensor<8x128xf32>) outs(%5 : tensor<8x128xf32>) {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
        %7 = arith.addf %arg4, %arg5 : f32
        linalg.yield %7 : f32
      } -> tensor<8x128xf32>
      scf.foreach_thread.perform_concurrently {
        tensor.parallel_insert_slice %6 into %0[%arg2, %arg3] [8, 128] [1, 1] : tensor<8x128xf32> into tensor<64x128xf32>
      }
    } {thread_dim_mapping = [2, 4]}
    return %1 : tensor<64x128xf32>
  }
  transform.with_pdl_patterns {
  ^bb0(%arg0: !pdl.operation):
    transform.genesis.canonicalized_sequence %arg0 failures(propagate) {
    ^bb0(%arg1: !pdl.operation):
      %0 = transform.genesis.match ops{["func.func"]} attributes {pad} in %arg1
      %1 = transform.genesis.match ops{["scf.foreach_thread"]} in %0
      %2 = transform.genesis.match ops{["tensor.pad"]} in %0
      %3 = transform.structured.fuse_into_containing_op %2 into %1
    }
  }
}
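
For what it's worth, this output matches the fallback path we would expect. Our reading (which may be wrong) is that transform.structured.fuse_into_containing_op first tries to tile the producer through generateResultTileValue, and only clones the producer into the loop when that fails. A rough sketch of that call site, with illustrative names rather than the verbatim upstream code:

// Illustrative sketch of the producer-fusion path; `sliceOp` is the
// tensor.extract_slice inside the loop that reads the producer's result.
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
  return failure();

// Ask the producer for just the tile that the slice reads.
FailureOr<Value> tiledValue = tileableProducer.generateResultTileValue(
    rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
    sliceOp.getMixedSizes());

if (failed(tiledValue)) {
  // PadOpTiling does not override generateResultTileValue, so the
  // interface default appears to fail here, and the transform falls back
  // to cloning the full-size pad into the loop and re-slicing it, which
  // is exactly the IR shown above.
  fallbackCloneProducer(rewriter, producerOp);  // hypothetical helper
}

If that reading is correct, implementing generateResultTileValue for PadOp (for example, by forwarding to getTiledImplementation as sketched above) should let the transform produce a properly tiled pad instead of a full-size clone.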
