@zhczhong (Contributor) commented:
Tracking issue: #301

The tiling-and-fusion passes already generate the outer parallel loop, but linalgToParallelLoopPass then creates parallel loops again for the ops inside it. The result is nested omp.parallel regions, which hurt performance significantly.
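The Before/After IR below shows the fix: inner loops that already execute inside an omp.parallel region are lowered to sequential scf.for nests instead of fresh omp.parallel regions. As a rough illustration of the idea, here is a minimal C++ sketch of such a guard (a hypothetical helper using the MLIR C++ API, not necessarily the exact code in this PR):

```cpp
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Operation.h"

// Hypothetical guard (an illustration, not the exact code in this PR):
// returns true when `op` already executes inside an omp.parallel region,
// in which case the conversion should lower it to a sequential scf.for
// nest rather than emit a second, nested omp.parallel.
static bool isInsideOmpParallel(mlir::Operation *op) {
  return op->getParentOfType<mlir::omp::ParallelOp>() != nullptr;
}
```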

Before this PR (note the omp.parallel regions nested inside the body of the outer omp.wsloop):

```mlir
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 32 : i32>>>} {
  func.func @entry(%arg0: memref<1024x1024xf32>, %arg1: memref<128x1024xf32>, %arg2: memref<128x1024xf32>, %arg3: memref<128x1024xf32>) {
    %c1024 = arith.constant 1024 : index
    %c128 = arith.constant 128 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = cpuruntime.alloc() : memref<128x1024xf32>
    memref.copy %arg3, %0 : memref<128x1024xf32> to memref<128x1024xf32>
    omp.parallel {
      %alloca = memref.alloca() {alignment = 64 : i64} : memref<16x16xf32>
      omp.wsloop {
        omp.loop_nest (%arg4, %arg5) : index = (%c0, %c0) to (%c128, %c1024) step (%c64, %c64) {
          %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg3 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
          %c1024_0 = arith.constant 1024 : index
          %1 = arith.muli %arg4, %c1024_0 : index
          %2 = arith.addi %1, %arg5 : index
          %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [%2], sizes: [64, 64], strides: [1024, 1] : memref<f32> to memref<64x64xf32, strided<[1024, 1], offset: ?>>
          %3 = cpuruntime.alloc thread_local() : memref<64x64xf32>
          memref.copy %reinterpret_cast, %3 : memref<64x64xf32, strided<[1024, 1], offset: ?>> to memref<64x64xf32>
          %4 = cpuruntime.alloc thread_local() : memref<64x64xf32>
          memref.copy %reinterpret_cast, %4 : memref<64x64xf32, strided<[1024, 1], offset: ?>> to memref<64x64xf32>
          %5 = cpuruntime.alloc thread_local() : memref<64x16x16xf32>
          scf.for %arg6 = %c0 to %c64 step %c16 {
            %base_buffer_1, %offset_2, %sizes_3:2, %strides_4:2 = memref.extract_strided_metadata %arg1 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
            %c1024_5 = arith.constant 1024 : index
            %6 = arith.muli %arg4, %c1024_5 : index
            %c1024_6 = arith.constant 1024 : index
            %7 = arith.muli %arg6, %c1024_6 : index
            %8 = arith.addi %6, %7 : index
            %reinterpret_cast_7 = memref.reinterpret_cast %base_buffer_1 to offset: [%8], sizes: [16, 64, 16], strides: [1024, 16, 1] : memref<f32> to memref<16x64x16xf32, strided<[1024, 16, 1], offset: ?>>
            scf.for %arg7 = %c0 to %c64 step %c16 {
              %base_buffer_8, %offset_9, %sizes_10:2, %strides_11:2 = memref.extract_strided_metadata %3 : memref<64x64xf32> -> memref<f32>, index, index, index, index, index
              %c64_12 = arith.constant 64 : index
              %9 = arith.muli %arg6, %c64_12 : index
              %10 = arith.addi %9, %arg7 : index
              %reinterpret_cast_13 = memref.reinterpret_cast %base_buffer_8 to offset: [%10], sizes: [16, 16], strides: [64, 1] : memref<f32> to memref<16x16xf32, strided<[64, 1], offset: ?>>
              omp.parallel {
                omp.wsloop {
                  omp.loop_nest (%arg8, %arg9, %arg10) : index = (%c0, %c0, %c0) to (%c64, %c16, %c16) step (%c1, %c1, %c1) {
                    %24 = memref.load %reinterpret_cast_7[%arg9, %arg8, %arg10] : memref<16x64x16xf32, strided<[1024, 16, 1], offset: ?>>
                    memref.store %24, %5[%arg8, %arg9, %arg10] : memref<64x16x16xf32>
                    omp.yield
                  }
                  omp.terminator
                }
                omp.terminator
              }
              omp.parallel {
                omp.wsloop {
                  omp.loop_nest (%arg8, %arg9) : index = (%c0, %c0) to (%c16, %c16) step (%c1, %c1) {
                    memref.store %cst, %reinterpret_cast_13[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                    omp.yield
                  }
                  omp.terminator
                }
                omp.terminator
              }
              %base_buffer_14, %offset_15, %sizes_16:2, %strides_17:2 = memref.extract_strided_metadata %arg0 : memref<1024x1024xf32> -> memref<f32>, index, index, index, index, index
              %11 = arith.addi %arg5, %arg7 : index
              %reinterpret_cast_18 = memref.reinterpret_cast %base_buffer_14 to offset: [%11], sizes: [64, 16, 16], strides: [16384, 1024, 1] : memref<f32> to memref<64x16x16xf32, strided<[16384, 1024, 1], offset: ?>>
              scf.for %arg8 = %c0 to %c64 step %c1 {
                omp.parallel {
                  omp.wsloop {
                    omp.loop_nest (%arg9, %arg10) : index = (%c0, %c0) to (%c16, %c16) step (%c1, %c1) {
                      scf.for %arg11 = %c0 to %c16 step %c1 {
                        %24 = memref.load %5[%arg8, %arg9, %arg11] : memref<64x16x16xf32>
                        %25 = memref.load %reinterpret_cast_18[%arg8, %arg11, %arg10] : memref<64x16x16xf32, strided<[16384, 1024, 1], offset: ?>>
                        %26 = memref.load %reinterpret_cast_13[%arg9, %arg10] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                        %27 = arith.mulf %24, %25 : f32
                        %28 = arith.addf %26, %27 : f32
                        memref.store %28, %reinterpret_cast_13[%arg9, %arg10] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                      }
                      omp.yield
                    }
                    omp.terminator
                  }
                  omp.terminator
                }
              }
              %base_buffer_19, %offset_20, %sizes_21:2, %strides_22:2 = memref.extract_strided_metadata %4 : memref<64x64xf32> -> memref<f32>, index, index, index, index, index
              %c64_23 = arith.constant 64 : index
              %12 = arith.muli %arg6, %c64_23 : index
              %13 = arith.addi %12, %arg7 : index
              %reinterpret_cast_24 = memref.reinterpret_cast %base_buffer_19 to offset: [%13], sizes: [16, 16], strides: [64, 1] : memref<f32> to memref<16x16xf32, strided<[64, 1], offset: ?>>
              %base_buffer_25, %offset_26, %sizes_27:2, %strides_28:2 = memref.extract_strided_metadata %arg2 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
              %c1024_29 = arith.constant 1024 : index
              %14 = arith.muli %arg4, %c1024_29 : index
              %15 = arith.addi %14, %arg5 : index
              %c1024_30 = arith.constant 1024 : index
              %16 = arith.muli %arg6, %c1024_30 : index
              %17 = arith.addi %15, %16 : index
              %18 = arith.addi %17, %arg7 : index
              %reinterpret_cast_31 = memref.reinterpret_cast %base_buffer_25 to offset: [%18], sizes: [16, 16], strides: [1024, 1] : memref<f32> to memref<16x16xf32, strided<[1024, 1], offset: ?>>
              %base_buffer_32, %offset_33, %sizes_34:2, %strides_35:2 = memref.extract_strided_metadata %0 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
              %c1024_36 = arith.constant 1024 : index
              %19 = arith.muli %arg4, %c1024_36 : index
              %20 = arith.addi %19, %arg5 : index
              %c1024_37 = arith.constant 1024 : index
              %21 = arith.muli %arg6, %c1024_37 : index
              %22 = arith.addi %20, %21 : index
              %23 = arith.addi %22, %arg7 : index
              %reinterpret_cast_38 = memref.reinterpret_cast %base_buffer_32 to offset: [%23], sizes: [16, 16], strides: [1024, 1] : memref<f32> to memref<16x16xf32, strided<[1024, 1], offset: ?>>
              omp.parallel {
                omp.wsloop {
                  omp.loop_nest (%arg8, %arg9) : index = (%c0, %c0) to (%c16, %c16) step (%c1, %c1) {
                    %24 = memref.load %reinterpret_cast_13[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                    %25 = memref.load %reinterpret_cast_31[%arg8, %arg9] : memref<16x16xf32, strided<[1024, 1], offset: ?>>
                    %26 = arith.addf %24, %25 : f32
                    memref.store %26, %reinterpret_cast_24[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                    memref.store %cst, %alloca[%arg8, %arg9] : memref<16x16xf32>
                    %27 = memref.load %reinterpret_cast_24[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                    %28 = memref.load %alloca[%arg8, %arg9] : memref<16x16xf32>
                    %29 = arith.cmpf ugt, %27, %28 : f32
                    %30 = arith.select %29, %27, %28 : f32
                    %31 = arith.cmpf uno, %28, %28 : f32
                    %32 = arith.select %31, %28, %30 : f32
                    memref.store %32, %reinterpret_cast_38[%arg8, %arg9] : memref<16x16xf32, strided<[1024, 1], offset: ?>>
                    omp.yield
                  }
                  omp.terminator
                }
                omp.terminator
              }
            }
          }
          cpuruntime.dealloc thread_local %3 : memref<64x64xf32>
          cpuruntime.dealloc thread_local %4 : memref<64x64xf32>
          cpuruntime.dealloc thread_local %5 : memref<64x16x16xf32>
          omp.yield
        }
        omp.terminator
      }
      omp.terminator
    }
    memref.copy %0, %arg3 : memref<128x1024xf32> to memref<128x1024xf32>
    cpuruntime.dealloc %0 : memref<128x1024xf32>
    return
  }
}
```

After this PR (the inner regions are now sequential scf.for nests):

```mlir
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 32 : i32>>>} {
  func.func @entry(%arg0: memref<1024x1024xf32>, %arg1: memref<128x1024xf32>, %arg2: memref<128x1024xf32>, %arg3: memref<128x1024xf32>) {
    %c1024 = arith.constant 1024 : index
    %c128 = arith.constant 128 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = cpuruntime.alloc() : memref<128x1024xf32>
    memref.copy %arg3, %0 : memref<128x1024xf32> to memref<128x1024xf32>
    omp.parallel {
      %alloca = memref.alloca() {alignment = 64 : i64} : memref<16x16xf32>
      omp.wsloop {
        omp.loop_nest (%arg4, %arg5) : index = (%c0, %c0) to (%c128, %c1024) step (%c64, %c64) {
          %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg3 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
          %c1024_0 = arith.constant 1024 : index
          %1 = arith.muli %arg4, %c1024_0 : index
          %2 = arith.addi %1, %arg5 : index
          %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [%2], sizes: [64, 64], strides: [1024, 1] : memref<f32> to memref<64x64xf32, strided<[1024, 1], offset: ?>>
          %3 = cpuruntime.alloc thread_local() : memref<64x64xf32>
          memref.copy %reinterpret_cast, %3 : memref<64x64xf32, strided<[1024, 1], offset: ?>> to memref<64x64xf32>
          %4 = cpuruntime.alloc thread_local() : memref<64x64xf32>
          memref.copy %reinterpret_cast, %4 : memref<64x64xf32, strided<[1024, 1], offset: ?>> to memref<64x64xf32>
          %5 = cpuruntime.alloc thread_local() : memref<64x16x16xf32>
          scf.for %arg6 = %c0 to %c64 step %c16 {
            %base_buffer_1, %offset_2, %sizes_3:2, %strides_4:2 = memref.extract_strided_metadata %arg1 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
            %c1024_5 = arith.constant 1024 : index
            %6 = arith.muli %arg4, %c1024_5 : index
            %c1024_6 = arith.constant 1024 : index
            %7 = arith.muli %arg6, %c1024_6 : index
            %8 = arith.addi %6, %7 : index
            %reinterpret_cast_7 = memref.reinterpret_cast %base_buffer_1 to offset: [%8], sizes: [16, 64, 16], strides: [1024, 16, 1] : memref<f32> to memref<16x64x16xf32, strided<[1024, 16, 1], offset: ?>>
            scf.for %arg7 = %c0 to %c64 step %c16 {
              %base_buffer_8, %offset_9, %sizes_10:2, %strides_11:2 = memref.extract_strided_metadata %3 : memref<64x64xf32> -> memref<f32>, index, index, index, index, index
              %c64_12 = arith.constant 64 : index
              %9 = arith.muli %arg6, %c64_12 : index
              %10 = arith.addi %9, %arg7 : index
              %reinterpret_cast_13 = memref.reinterpret_cast %base_buffer_8 to offset: [%10], sizes: [16, 16], strides: [64, 1] : memref<f32> to memref<16x16xf32, strided<[64, 1], offset: ?>>
              scf.for %arg8 = %c0 to %c64 step %c1 {
                scf.for %arg9 = %c0 to %c16 step %c1 {
                  scf.for %arg10 = %c0 to %c16 step %c1 {
                    %24 = memref.load %reinterpret_cast_7[%arg9, %arg8, %arg10] : memref<16x64x16xf32, strided<[1024, 16, 1], offset: ?>>
                    memref.store %24, %5[%arg8, %arg9, %arg10] : memref<64x16x16xf32>
                  }
                }
              }
              scf.for %arg8 = %c0 to %c16 step %c1 {
                scf.for %arg9 = %c0 to %c16 step %c1 {
                  memref.store %cst, %reinterpret_cast_13[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                }
              }
              %base_buffer_14, %offset_15, %sizes_16:2, %strides_17:2 = memref.extract_strided_metadata %arg0 : memref<1024x1024xf32> -> memref<f32>, index, index, index, index, index
              %11 = arith.addi %arg5, %arg7 : index
              %reinterpret_cast_18 = memref.reinterpret_cast %base_buffer_14 to offset: [%11], sizes: [64, 16, 16], strides: [16384, 1024, 1] : memref<f32> to memref<64x16x16xf32, strided<[16384, 1024, 1], offset: ?>>
              scf.for %arg8 = %c0 to %c64 step %c1 {
                scf.for %arg9 = %c0 to %c16 step %c1 {
                  scf.for %arg10 = %c0 to %c16 step %c1 {
                    scf.for %arg11 = %c0 to %c16 step %c1 {
                      %24 = memref.load %5[%arg8, %arg9, %arg11] : memref<64x16x16xf32>
                      %25 = memref.load %reinterpret_cast_18[%arg8, %arg11, %arg10] : memref<64x16x16xf32, strided<[16384, 1024, 1], offset: ?>>
                      %26 = memref.load %reinterpret_cast_13[%arg9, %arg10] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                      %27 = arith.mulf %24, %25 : f32
                      %28 = arith.addf %26, %27 : f32
                      memref.store %28, %reinterpret_cast_13[%arg9, %arg10] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                    }
                  }
                }
              }
              %base_buffer_19, %offset_20, %sizes_21:2, %strides_22:2 = memref.extract_strided_metadata %4 : memref<64x64xf32> -> memref<f32>, index, index, index, index, index
              %c64_23 = arith.constant 64 : index
              %12 = arith.muli %arg6, %c64_23 : index
              %13 = arith.addi %12, %arg7 : index
              %reinterpret_cast_24 = memref.reinterpret_cast %base_buffer_19 to offset: [%13], sizes: [16, 16], strides: [64, 1] : memref<f32> to memref<16x16xf32, strided<[64, 1], offset: ?>>
              %base_buffer_25, %offset_26, %sizes_27:2, %strides_28:2 = memref.extract_strided_metadata %arg2 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
              %c1024_29 = arith.constant 1024 : index
              %14 = arith.muli %arg4, %c1024_29 : index
              %15 = arith.addi %14, %arg5 : index
              %c1024_30 = arith.constant 1024 : index
              %16 = arith.muli %arg6, %c1024_30 : index
              %17 = arith.addi %15, %16 : index
              %18 = arith.addi %17, %arg7 : index
              %reinterpret_cast_31 = memref.reinterpret_cast %base_buffer_25 to offset: [%18], sizes: [16, 16], strides: [1024, 1] : memref<f32> to memref<16x16xf32, strided<[1024, 1], offset: ?>>
              scf.for %arg8 = %c0 to %c16 step %c1 {
                scf.for %arg9 = %c0 to %c16 step %c1 {
                  %24 = memref.load %reinterpret_cast_13[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                  %25 = memref.load %reinterpret_cast_31[%arg8, %arg9] : memref<16x16xf32, strided<[1024, 1], offset: ?>>
                  %26 = arith.addf %24, %25 : f32
                  memref.store %26, %reinterpret_cast_24[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                }
              }
              scf.for %arg8 = %c0 to %c16 step %c1 {
                scf.for %arg9 = %c0 to %c16 step %c1 {
                  memref.store %cst, %alloca[%arg8, %arg9] : memref<16x16xf32>
                }
              }
              %base_buffer_32, %offset_33, %sizes_34:2, %strides_35:2 = memref.extract_strided_metadata %0 : memref<128x1024xf32> -> memref<f32>, index, index, index, index, index
              %c1024_36 = arith.constant 1024 : index
              %19 = arith.muli %arg4, %c1024_36 : index
              %20 = arith.addi %19, %arg5 : index
              %c1024_37 = arith.constant 1024 : index
              %21 = arith.muli %arg6, %c1024_37 : index
              %22 = arith.addi %20, %21 : index
              %23 = arith.addi %22, %arg7 : index
              %reinterpret_cast_38 = memref.reinterpret_cast %base_buffer_32 to offset: [%23], sizes: [16, 16], strides: [1024, 1] : memref<f32> to memref<16x16xf32, strided<[1024, 1], offset: ?>>
              scf.for %arg8 = %c0 to %c16 step %c1 {
                scf.for %arg9 = %c0 to %c16 step %c1 {
                  %24 = memref.load %reinterpret_cast_24[%arg8, %arg9] : memref<16x16xf32, strided<[64, 1], offset: ?>>
                  %25 = memref.load %alloca[%arg8, %arg9] : memref<16x16xf32>
                  %26 = arith.cmpf ugt, %24, %25 : f32
                  %27 = arith.select %26, %24, %25 : f32
                  %28 = arith.cmpf uno, %25, %25 : f32
                  %29 = arith.select %28, %25, %27 : f32
                  memref.store %29, %reinterpret_cast_38[%arg8, %arg9] : memref<16x16xf32, strided<[1024, 1], offset: ?>>
                }
              }
            }
          }
          cpuruntime.dealloc thread_local %3 : memref<64x64xf32>
          cpuruntime.dealloc thread_local %4 : memref<64x64xf32>
          cpuruntime.dealloc thread_local %5 : memref<64x16x16xf32>
          omp.yield
        }
        omp.terminator
      }
      omp.terminator
    }
    memref.copy %0, %arg3 : memref<128x1024xf32> to memref<128x1024xf32>
    cpuruntime.dealloc %0 : memref<128x1024xf32>
    return
  }
}
```
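One way to confirm the property the After IR exhibits is to walk the module and check that no omp.parallel remains nested inside another. A minimal sketch (a hypothetical helper using the MLIR C++ API, not part of this PR):

```cpp
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinOps.h"

// Hypothetical check (not part of this PR): reports whether any omp.parallel
// in `module` is still nested inside another omp.parallel.
static bool hasNestedOmpParallel(mlir::ModuleOp module) {
  bool nested = false;
  module.walk([&](mlir::omp::ParallelOp op) {
    // getParentOfType starts the search at the parent, so `op` itself
    // does not count as its own ancestor.
    if (op->getParentOfType<mlir::omp::ParallelOp>())
      nested = true;
  });
  return nested;
}
```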

zhczhong self-assigned this on Aug 29, 2024.
zhczhong linked an issue (Nested omp parallel introduced by linalgToParallelLoopPass) on Aug 29, 2024 that may be closed by this pull request.
zhczhong merged commit 70e306b into main on Aug 29, 2024.
zhczhong deleted the zhicong/fix_nested_parallel branch on Aug 29, 2024 at 06:53.
zhczhong pushed a commit that referenced this pull request on Sep 2, 2024.