[Flow] Enable reshape propagation through tensor.pad #17492

Open · Max191 opened this issue May 23, 2024 · 0 comments

Max191 commented May 23, 2024

When trying to fuse tensor.pad with its producers, reshapes can unnecessarily block fusion. The following IR, taken from a VAE model, is an example:

```mlir
  %168 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%166, %167, %expanded_332, %expanded_333 : tensor<32x16x256x256xf32>, tensor<32xf32>, tensor<32x16xf32>, tensor<32x16xf32>) outs(%157 : tensor<32x16x256x256xf32>) {
  ^bb0(%in: f32, %in_618: f32, %in_619: f32, %in_620: f32, %out: f32):
    %384 = arith.divf %in_618, %cst_97 : f32
    %385 = arith.addf %384, %cst_93 : f32
    %386 = math.rsqrt %385 : f32
    %387 = arith.mulf %in, %386 : f32
    %388 = arith.mulf %387, %in_619 : f32
    %389 = arith.addf %388, %in_620 : f32
    %390 = arith.negf %389 : f32
    %391 = math.exp %390 : f32
    %392 = arith.addf %391, %cst_91 : f32
    %393 = arith.divf %cst_91, %392 : f32
    %394 = arith.mulf %393, %389 : f32
    linalg.yield %394 : f32
  } -> tensor<32x16x256x256xf32>
  %collapsed_334 = tensor.collapse_shape %168 [[0, 1], [2], [3]] : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
  %padded_335 = tensor.pad %collapsed_334 low[0, 1, 1] high[0, 1, 1] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %cst_64 : f32
  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
```

The tensor.pad does not pad the dimensions that the tensor.collapse_shape groups together, so if the reshape were propagated through the tensor.pad op, the two ops could fuse into a single dispatch.
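For illustration, the propagated form would pad the 4-D result of the linalg.generic first and collapse afterwards (a sketch reusing the values above; the `%padded`/`%collapsed` names are made up):

```mlir
%padded = tensor.pad %168 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst_64 : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]] : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
```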

One way to do this would be to add reshape propagation patterns for tensor.pad (similar to the existing linalg elementwise fusion patterns, e.g. https://github.com/llvm/llvm-project/blob/af31883341a122a7285e9b4f0a034470024021eb/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp#L922), but it may be tricky to manage when the propagations apply.
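A minimal sketch of what such a pattern could look like, written as an OpRewritePattern on tensor.pad. The pattern name and its restrictions (constant padding value, static pad sizes, no padding on collapsed dimension groups) are assumptions of this sketch, not a decided design:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Rewrites `tensor.collapse_shape -> tensor.pad` into
// `tensor.pad -> tensor.collapse_shape` when none of the collapsed
// dimension groups are padded.
struct PropagateCollapseShapeThroughPad
    : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return rewriter.notifyMatchFailure(padOp, "source is not collapse_shape");
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return rewriter.notifyMatchFailure(padOp, "non-constant padding value");

    SmallVector<ReassociationIndices> reassoc =
        collapseOp.getReassociationIndices();
    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();

    // Expand the pad amounts to the uncollapsed rank. Collapsed groups
    // (more than one source dim) must have zero padding; single-dim groups
    // keep their padding, which must be static in this sketch.
    SmallVector<int64_t> newLow, newHigh;
    for (auto [group, lo, hi] : llvm::zip(reassoc, low, high)) {
      std::optional<int64_t> loVal = getConstantIntValue(lo);
      std::optional<int64_t> hiVal = getConstantIntValue(hi);
      if (!loVal || !hiVal)
        return rewriter.notifyMatchFailure(padOp, "dynamic pad size");
      if (group.size() > 1 && (*loVal != 0 || *hiVal != 0))
        return rewriter.notifyMatchFailure(padOp, "padded collapsed group");
      newLow.push_back(*loVal);
      newHigh.push_back(*hiVal);
      for (unsigned i = 1, e = group.size(); i < e; ++i) {
        newLow.push_back(0);
        newHigh.push_back(0);
      }
    }

    // Pad the uncollapsed producer result, then collapse the padded tensor.
    Location loc = padOp.getLoc();
    RankedTensorType newPadType = tensor::PadOp::inferResultType(
        collapseOp.getSrcType(), newLow, newHigh);
    SmallVector<OpFoldResult> newLowOfr =
        getAsIndexOpFoldResult(rewriter.getContext(), newLow);
    SmallVector<OpFoldResult> newHighOfr =
        getAsIndexOpFoldResult(rewriter.getContext(), newHigh);
    Value newPad = rewriter.create<tensor::PadOp>(
        loc, newPadType, collapseOp.getSrc(), newLowOfr, newHighOfr, padValue);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPad, reassoc);
    return success();
  }
};
} // namespace
```

Such a pattern would presumably need to be coordinated with the existing reshape bubbling/sinking decisions so the propagation is applied consistently rather than fighting other patterns, which is the "tricky to manage" part mentioned above.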

Max191 self-assigned this May 23, 2024