Skip to content

Eliminate redundant computation when fusing a producer with multiple uses #298

@Yun-Fly

Description

@Yun-Fly
// Reproducer: the producer %0 (linalg.powf) has two uses. It is returned
// directly by the function, and it is also the input of the linalg.reduce
// that produces %1.
func.func @test(%arg0: tensor<16x32x32xf32>) -> (tensor<16x32x32xf32>, tensor<16x32xf32>) {
    // Exponent operand for linalg.powf: a splat tensor where every element is 2.0.
    %cst_0 = arith.constant dense<2.000000e+00> : tensor<16x32x32xf32>
    %dest0 = tensor.empty() : tensor<16x32x32xf32>
    // Producer: first use is the return below, second use is the reduce.
    %0 = linalg.powf ins(%arg0, %cst_0 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%dest0 : tensor<16x32x32xf32>) -> tensor<16x32x32xf32>
    %dest1 = tensor.empty() : tensor<16x32xf32>
    // Consumer: sums %0 over dimension 2 (16x32x32 -> 16x32).
    %1 = linalg.reduce { arith.addf } ins(%0 : tensor<16x32x32xf32>) outs(%dest1 : tensor<16x32xf32>) dimensions = [2]
    // %0 escapes the function, so fusing powf into the reduce loop alone
    // does not remove the need to materialize %0.
    return %0, %1 : tensor<16x32x32xf32>, tensor<16x32xf32>
  }

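For the input MLIR above, the result of `linalg.powf` (`%0`) has two uses: it is returned directly, and it is also consumed by `linalg.reduce`. The current main branch generates redundant computation of `powf` for this input. `powf` is computed once in the first `scf.forall` to produce the returned tensor. It is then recomputed, on the same slices, inside the second `scf.forall` that performs the fused reduction, as shown below: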

// Current output on main: powf is fused into the reduce loop by recomputing
// it, while a separate loop still computes powf for the returned tensor.
// Both scf.forall loops iterate over the same (16) space and read identical
// slices of %arg0 and %cst, so each powf tile is computed twice.
func.func @test(%arg0: tensor<16x32x32xf32>) -> (tensor<16x32x32xf32>, tensor<16x32xf32>) {
    %cst = arith.constant dense<2.000000e+00> : tensor<16x32x32xf32>
    %0 = tensor.empty() : tensor<16x32x32xf32>
    // Loop 1: computes powf tile by tile to materialize the full tensor
    // that is returned as the first result.
    %1 = scf.forall (%arg1) in (16) shared_outs(%arg2 = %0) -> (tensor<16x32x32xf32>) {
      %extracted_slice = tensor.extract_slice %arg0[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : tensor<16x32x32xf32> to tensor<1x32x32xf32>
      %extracted_slice_0 = tensor.extract_slice %cst[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : tensor<16x32x32xf32> to tensor<1x32x32xf32>
      %extracted_slice_1 = tensor.extract_slice %arg2[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : tensor<16x32x32xf32> to tensor<1x32x32xf32>
      %4 = linalg.powf ins(%extracted_slice, %extracted_slice_0 : tensor<1x32x32xf32>, tensor<1x32x32xf32>) outs(%extracted_slice_1 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg2[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : tensor<1x32x32xf32> into tensor<16x32x32xf32>
      }
    }
    %2 = tensor.empty() : tensor<16x32xf32>
    // Loop 2: the reduce, with powf fused in as its producer.
    %3 = scf.forall (%arg1) in (16) shared_outs(%arg2 = %2) -> (tensor<16x32xf32>) {
      %extracted_slice = tensor.extract_slice %arg0[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : tensor<16x32x32xf32> to tensor<1x32x32xf32>
      %extracted_slice_0 = tensor.extract_slice %cst[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : tensor<16x32x32xf32> to tensor<1x32x32xf32>
      %extracted_slice_1 = tensor.extract_slice %0[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : tensor<16x32x32xf32> to tensor<1x32x32xf32>
      // REDUNDANT: same powf on the same slices as in loop 1.
      // NOTE(review): presumably the desired result is a single scf.forall
      // whose shared_outs carry both the powf tile and the reduced tile,
      // so powf is computed once per iteration — to be confirmed.
      %4 = linalg.powf ins(%extracted_slice, %extracted_slice_0 : tensor<1x32x32xf32>, tensor<1x32x32xf32>) outs(%extracted_slice_1 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>
      %extracted_slice_2 = tensor.extract_slice %arg2[%arg1, 0] [1, 32] [1, 1] : tensor<16x32xf32> to tensor<1x32xf32>
      %reduced = linalg.reduce { arith.addf } ins(%4 : tensor<1x32x32xf32>) outs(%extracted_slice_2 : tensor<1x32xf32>) dimensions = [2] 
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %reduced into %arg2[%arg1, 0] [1, 32] [1, 1] : tensor<1x32xf32> into tensor<16x32xf32>
      }
    }
    return %1, %3 : tensor<16x32x32xf32>, tensor<16x32xf32>
  }

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Relationships

None yet

Development

No branches or pull requests

Issue actions