We found that the TilingInterface implementation for tensor.pad (PadOpTiling) does not implement the generateResultTileValue method. This leads to unexpected results when we attempt to fuse the pad operation into a containing loop using transform.structured.fuse_into_containing_op. Is this intended behavior or a defect? Additionally, what is the correct approach to fuse a pad operation into a loop?
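For context, the sketch below shows roughly what we expected such a hook to look like. This is only an illustration of the idea, not upstream code: the method signatures are approximated for the llvm-15-era TilingInterface external model and may differ from the actual headers.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;
using namespace mlir::tensor;

// Hypothetical sketch of a generateResultTileValue hook for PadOpTiling.
// tensor.pad's iteration space coincides with its result shape, so the
// requested result tile could be produced by forwarding the offsets/sizes to
// getTiledImplementation and returning the tiled op's corresponding result.
struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
  // ... existing methods (getIterationDomain, getLoopIteratorTypes,
  // getTiledImplementation, ...) unchanged ...

  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
                                           unsigned resultNumber,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes) const {
    // Delegate to the existing tiling implementation.
    SmallVector<Operation *> tiledOps =
        cast<TilingInterface>(op).getTiledImplementation(b, offsets, sizes);
    if (tiledOps.size() != 1)
      return failure();
    return tiledOps.front()->getResult(resultNumber);
  }
};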
Here are our test results on llvm-15.x. The input IR and the transform script:
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @pad(%arg0: tensor<58x1xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> attributes {pad} {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %zero = arith.constant 0.0 : f32
  %pad = tensor.pad %arg0 low[4, 60] high[2, 67] {
  ^bb0(%arg4: index, %arg5: index):
    tensor.yield %zero : f32
  } : tensor<58x1xf32> to tensor<64x128xf32>
  %1 = linalg.init_tensor [64, 128] : tensor<64x128xf32>
  %2 = scf.foreach_thread (%arg2, %arg3) in (%c1, %c1) -> (tensor<64x128xf32>) {
    %3 = tensor.extract_slice %arg1[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
    %4 = tensor.extract_slice %pad[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
    %5 = linalg.init_tensor [8, 128] : tensor<8x128xf32>
    %6 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
        ins(%3, %4 : tensor<8x128xf32>, tensor<8x128xf32>)
        outs(%5 : tensor<8x128xf32>) {
    ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
      %7 = arith.addf %arg4, %arg5 : f32
      linalg.yield %7 : f32
    } -> tensor<8x128xf32>
    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %6 into %1[%arg2, %arg3] [8, 128] [1, 1] : tensor<8x128xf32> into tensor<64x128xf32>
    }
  } {thread_dim_mapping = [2, 4]}
  return %2 : tensor<64x128xf32>
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.genesis.canonicalized_sequence %arg0 failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %device_func = transform.genesis.match ops{["func.func"]} attributes {pad} in %arg1
    %foreach_thread_op = transform.genesis.match ops{["scf.foreach_thread"]} in %device_func
    // fuse and tile
    %expand_shape = transform.genesis.match ops{["tensor.pad"]} in %device_func
    transform.structured.fuse_into_containing_op %expand_shape into %foreach_thread_op
  }
}
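After applying the transform, the tensor.pad is cloned into the scf.foreach_thread body, but it is not tiled to the [8, 128] slice: the full 64x128 padded tensor is still materialized inside the loop and then sliced, as shown in the output below.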
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @pad(%arg0: tensor<58x1xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> attributes {pad} {
    %c1 = arith.constant 1 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = linalg.init_tensor [64, 128] : tensor<64x128xf32>
    %1 = scf.foreach_thread (%arg2, %arg3) in (%c1, %c1) -> (tensor<64x128xf32>) {
      %2 = tensor.extract_slice %arg1[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
      %3 = tensor.pad %arg0 low[4, 60] high[2, 67] {
      ^bb0(%arg4: index, %arg5: index):
        tensor.yield %cst : f32
      } : tensor<58x1xf32> to tensor<64x128xf32>
      %4 = tensor.extract_slice %3[%arg2, %arg3] [8, 128] [1, 1] : tensor<64x128xf32> to tensor<8x128xf32>
      %5 = linalg.init_tensor [8, 128] : tensor<8x128xf32>
      %6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %4 : tensor<8x128xf32>, tensor<8x128xf32>) outs(%5 : tensor<8x128xf32>) {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
        %7 = arith.addf %arg4, %arg5 : f32
        linalg.yield %7 : f32
      } -> tensor<8x128xf32>
      scf.foreach_thread.perform_concurrently {
        tensor.parallel_insert_slice %6 into %0[%arg2, %arg3] [8, 128] [1, 1] : tensor<8x128xf32> into tensor<64x128xf32>
      }
    } {thread_dim_mapping = [2, 4]}
    return %1 : tensor<64x128xf32>
  }
  transform.with_pdl_patterns {
  ^bb0(%arg0: !pdl.operation):
    transform.genesis.canonicalized_sequence %arg0 failures(propagate) {
    ^bb0(%arg1: !pdl.operation):
      %0 = transform.genesis.match ops{["func.func"]} attributes {pad} in %arg1
      %1 = transform.genesis.match ops{["scf.foreach_thread"]} in %0
      %2 = transform.genesis.match ops{["tensor.pad"]} in %0
      %3 = transform.structured.fuse_into_containing_op %2 into %1
    }
  }
}