OPT dispatch tracker #96 (Open)

nirvedhmeshram opened this issue Jan 23, 2024 · 26 comments

@nirvedhmeshram (Contributor) commented Jan 23, 2024:

| Dispatch Type | Element Type | Shapes | Link | Running on AIE |
| --- | --- | --- | --- | --- |
| Elementwise | i64/f32 | (8,8) | link | No |
| Scan | i64 | (1,8) | link | No |
| Elementwise w/ tensor.extract | i64 | (1,8,2048) | link | No |
| Reduction + Elementwise | f32 | (8,2048) | link | No |
| Reduction + Elementwise | f32 | (8,2048) | link | No |
| Matmul transpose b | f32 | (8,2048,2048) | link | No |
| Elementwise | f32 | (8,32,64) | link | No |
| Elementwise | f32 | (8,32,64) | link | No |
| Batchmatmul transpose b + Elementwise | f32 | (32,8,8,64) | link | No |
| Softmax + Elementwise | f32 | (32,8,8) | link | No |
| Batchmatmul + Elementwise (transpose) | f32 | (32,8,64,8) | link | No |
| Matmul transpose b + Elementwise | f32 | (8,2048,2048) | link | No |
| Matmul transpose b + Elementwise | f32 | (8,8192,2048) | link | No |
| Matmul transpose b + Elementwise | f32 | (8,2048,8192) | link | No |
| Matmul transpose b | f32 | (8,50272,2048) | link | No |
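
For reference, here is a minimal sketch of what the "Matmul transpose b" dispatch for (M, N, K) = (8, 2048, 2048) looks like at the linalg level. This is my own reconstruction from the shapes in the table (the actual dispatch IR sits behind the elided "link" entries), so treat it as illustrative only:

```mlir
// Hypothetical reconstruction of the f32 matmul-transpose-b dispatch with
// (M, N, K) = (8, 2048, 2048): C[m, n] = sum_k A[m, k] * B[n, k].
func.func @matmul_transpose_b_8x2048x2048(%lhs: tensor<8x2048xf32>,
                                          %rhs: tensor<2048x2048xf32>) -> tensor<8x2048xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %init = tensor.empty() : tensor<8x2048xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8x2048xf32>) -> tensor<8x2048xf32>
  %res = linalg.matmul_transpose_b
      ins(%lhs, %rhs : tensor<8x2048xf32>, tensor<2048x2048xf32>)
      outs(%fill : tensor<8x2048xf32>) -> tensor<8x2048xf32>
  return %res : tensor<8x2048xf32>
}
```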
nirvedhmeshram changed the title from "OPT dispatch tracker (WIP)" to "OPT dispatch tracker" on Jan 23, 2024.
@MaheshRavishankar (Collaborator) commented:

Nice... the next step, then, is to take the matmul_transpose_b and batchmatmul operations of those sizes and run them through the pack-based pipeline. @yzhang93 can we start using the C++ pass pipeline for this now?

@yzhang93 (Contributor) commented:

> Nice... the next step, then, is to take the matmul_transpose_b and batchmatmul operations of those sizes and run them through the pack-based pipeline. @yzhang93 can we start using the C++ pass pipeline for this now?

Yes, we can start putting the pieces together for the pack-based pipeline. As a first step, once @erwei-xilinx's pack/unpack lowering pass is merged, we can add an e2e example with smaller input sizes and only parallel loops. Once this works, we'll add reduction loops and peeling if needed. I have some additional comments in this PR.
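
For readers following along, here is a minimal, self-contained sketch of the tensor.pack / tensor.unpack pair that the lowering pass mentioned above deals with. The 8x16 shape and 8x8 inner tiles are chosen purely for illustration and are not taken from the pipeline's actual configuration:

```mlir
// Round-trip a small matrix through pack/unpack: pack tiles the 8x16 input
// into a 1x2 grid of 8x8 inner tiles, unpack restores the original layout.
func.func @pack_unpack_roundtrip(%src: tensor<8x16xi32>) -> tensor<8x16xi32> {
  %packed_init = tensor.empty() : tensor<1x2x8x8xi32>
  %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %packed_init : tensor<8x16xi32> -> tensor<1x2x8x8xi32>
  %unpacked_init = tensor.empty() : tensor<8x16xi32>
  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %unpacked_init : tensor<1x2x8x8xi32> -> tensor<8x16xi32>
  return %unpacked : tensor<8x16xi32>
}
```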

@Abhishek-Varma (Contributor) commented Jan 31, 2024:

CC: @MaheshRavishankar @nirvedhmeshram @yzhang93

I worked on creating a utility that should hopefully help each of us individually test various matmul sizes (it can later be extended to other individual ops): Testing utilities.
NOTE: You will have to modify the paths based on your working setup.

The following are my observations with the recently added pack-based pipeline, without changing any config and using pack_pipeline_e2e.mlir as the base input IR structure to test the matmul sizes:

NOTE: For each failure case I have attached the logs as well. Also, the size combinations below are a superset of the ones found in OPT.

  1. Element types:
    Supported: i32
    Not supported:

  2. Matmul sizes (M x N x K, i32):
    a. Correct results:
    - 8 x 16 x 16
    - 8 x 32 x 16
    - 8 x 64 x 16

    b. Failure at air-copy-to-dma - mlir-print-ir-before-all:

    • 8 x 16 x (8, 32, 64, 2048, 8192)
    • 8 x 32 x (8, 32, 64, 2048, 8192)
    • 8 x 64 x (8, 32, 64, 2048, 8192)
    • 8 x 2048 x (8, 32, 64, 2048, 8192)
    • 8 x 8192 x (8, 32, 64, 2048, 8192)
    • 8 x 50272 x (8, 32, 64, 2048, 8192)

    c. Failure at air-par-to-launch - mlir-print-ir-before-all:

    • 8 x 8 x 8
    • 8 x 8 x 16

    d. Failure at iree-amdaie-decompose-pack-unpack-to-air - mlir-print-ir-before-all:

    • 8 x 8 x (32, 64, 2048, 8192)

    e. Failure at airrt-to-ipu - mlir-print-ir-before-all :

    • 8 x 2048 x 16
    • 8 x 8192 x 16
    • 8 x 50272 x 16

Hope we can start making progress on this front now.

EDIT:
I didn't add an equivalent exhaustive test for the fp32/i64 element types (although, as you'd note, switching them on via the utility is fairly simple).
The reason for not doing so: the basic shapes we already support at ToM are themselves failing for these element types, as you'll notice in the attached logs. So that probably needs to be handled first (and added to the lit tests) in tandem with the other classes of errors above.
The main intention was to first exhaustively search the domain and see which "tensor shapes" we can support, independent of the element type.

@yzhang93 (Contributor) commented:

Thanks @Abhishek-Varma for testing the different input sizes. CC @erwei-xilinx in case he has thoughts on the failures in AIR.

@nirvedhmeshram (Contributor, Author) commented Jan 31, 2024:

Thanks @Abhishek-Varma. I wanted to add a relatively similar test that I also see a failure for; the error doesn't seem to be in the passes you mention either. Did you encounter this one?

func.func @matmul_64x32_16xi32_(%lhs: tensor<64x16xi32>, %rhs: tensor<16x32xi32>) -> tensor<64x32xi32> {
  %init_acc = tensor.empty() : tensor<64x32xi32>
  %c0_acc_type = arith.constant 0: i32
  %acc = linalg.fill ins(%c0_acc_type : i32) outs(%init_acc : tensor<64x32xi32>) -> tensor<64x32xi32>
  %result = linalg.matmul ins(%lhs, %rhs: tensor<64x16xi32>, tensor<16x32xi32>) outs(%acc: tensor<64x32xi32>) -> tensor<64x32xi32>
  return %result: tensor<64x32xi32>
}

Error:

 error: 'aiex.ipu.writebd_shimtile' op BD ID exceeds the maximum ID.
loc("/tmp/amdaie_xclbin_fb-dc07cb/module_matmul_64x32_16xi32__dispatch_0_amdaie_xclbin_fb.aiecc.mlir":1:1): error: 'builtin.module' op IPU Instruction pipeline failed
/proj/xcohdstaff6/nmeshram/iree-build-debug/compiler/plugins/external/tests/matmul/amdaie_e2e_matmul_dt_i32_i32_small_amd-aie_xrt_matmuls.mlir:6:13: error: 'builtin.module' op Failed to produce an xclbIN with external tool

@erwei-xilinx (Contributor) commented:

Has the IREE Transform Dialect tiling script been updated accordingly?

Here's an example of 128x256x128xi32 GEMM. https://gist.gitenterprise.xilinx.com/erweiw/9cdc1d6ead27e40d551963e2328c50b0

Note how tiling factors in the transform dialect are changed accordingly.

@yzhang93 (Contributor) commented Jan 31, 2024:

> Has the IREE Transform Dialect tiling script been updated accordingly?
>
> Here's an example of 128x256x128xi32 GEMM. https://gist.gitenterprise.xilinx.com/erweiw/9cdc1d6ead27e40d551963e2328c50b0
>
> Note how tiling factors in the transform dialect are changed accordingly.

Why does the tile size need to be changed according to the input sizes? I know it would affect the performance with different tile sizes and data layout, but it shouldn't cause a failure? Also, we are using the pack pipeline for the tests.

@erwei-xilinx (Contributor) commented Jan 31, 2024:

> Why does the tile size need to be changed according to the input sizes? I know it would affect the performance with different tile sizes and data layout, but it shouldn't cause a failure? Also, we are using the pack pipeline for the tests.

AIR infers the herd size (number of tiles) and the L1/L2 memory buffer sizes from the loop structure generated by the script. At the moment AIR strictly follows the scf loop structure generated by the Transform script.

@Abhishek-Varma (Contributor) commented:

> Thanks @Abhishek-Varma. I wanted to add a relatively similar test that I also see a failure for; the error doesn't seem to be in the passes you mention either. Did you encounter this one? [matmul_64x32_16xi32 test and BD ID error quoted in full above]

Hi Nirvedh.

No, I haven't encountered this error.
I was mostly testing with M=8 to target the OPT shapes; perhaps the above error appears with M=64.

@Abhishek-Varma (Contributor) commented:

>> Why does the tile size need to be changed according to the input sizes? I know it would affect the performance with different tile sizes and data layout, but it shouldn't cause a failure? Also, we are using the pack pipeline for the tests.
>
> AIR infers the herd size (number of tiles) and the L1/L2 memory buffer sizes from the loop structure generated by the script. At the moment AIR strictly follows the scf loop structure generated by the Transform script.

Hi @erwei-xilinx

The Transform dialect script wasn't used for the above tests; we now have an equivalent C++ pack-based pipeline, which is what was used.

@yzhang93 (Contributor) commented Feb 1, 2024:

> AIR infers the herd size (number of tiles) and the L1/L2 memory buffer sizes from the loop structure generated by the script. At the moment AIR strictly follows the scf loop structure generated by the Transform script.

Thanks @erwei-xilinx! Now it makes sense to me. So the tests are actually sensitive to the tile sizes and the pack sizes. For these particular cases, 8 x 8 x 8 and 8 x 8 x 16, we should reduce the tile size and pack size for the N dimension to 8.

@MaheshRavishankar (Collaborator) commented:

>> AIR infers the herd size (number of tiles) and the L1/L2 memory buffer sizes from the loop structure generated by the script. At the moment AIR strictly follows the scf loop structure generated by the Transform script.
>
> Thanks @erwei-xilinx! Now it makes sense to me. So the tests are actually sensitive to the tile sizes and the pack sizes. For these particular cases, 8 x 8 x 8 and 8 x 8 x 16, we should reduce the tile size and pack size for the N dimension to 8.

Let's discuss this further... Ideally the tile sizes are fully derived from the architecture. Let's restrict ourselves to multiples of the tile sizes for now (because we have to see what happens with the padding semantics of packing), but apart from that they should not matter. Maybe setting a tile size of min(problem_size, <current_tile_size>) would be acceptable.
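
To make the min(problem_size, <current_tile_size>) suggestion concrete, here is a rough illustration for an 8 x 8 x 16 matmul. The attribute values and the number of tiling levels are placeholders of my own and not the actual output of KernelDispatch.cpp:

```mlir
// With a hypothetical default first-level tile of [8, 64], the N entry would
// overshoot N = 8, so it is clamped to min(8, 64) = 8 in the lowering config.
func.func @matmul_8x8x16(%lhs: tensor<8x16xi32>, %rhs: tensor<16x8xi32>,
                         %acc: tensor<8x8xi32>) -> tensor<8x8xi32> {
  %res = linalg.matmul
      {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 8], [1, 1], [0, 0, 1]]>}
      ins(%lhs, %rhs : tensor<8x16xi32>, tensor<16x8xi32>)
      outs(%acc : tensor<8x8xi32>) -> tensor<8x8xi32>
  return %res : tensor<8x8xi32>
}
```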

@erwei-xilinx (Contributor) commented:

This example should help with explaining some of the issues being discussed here: #124 (comment)

@Abhishek-Varma (Contributor) commented:

Hi @MaheshRavishankar @yzhang93 @nirvedhmeshram

This is based on Vivian's patch: Update tile and pack.

This looks like a better thread for this update, since Vivian rightly mentioned in this thread that we're able to generate vmfbs for 8 x (8, 16, 32, 64, 2048) x (8, 16, 32, 64) - thanks for this @yzhang93!

I have confirmed that each of the generated vmfbs yields correct results.

I'm therefore now working on linalg.batch_matmul (one of the OPT dispatches). It bailed out quite early since KernelDispatch.cpp wasn't even handling it. I've added a case for ContractionOp and, for now, use [1, tileM, tileN, 0], [1, 1, 1], [0, 0, 0, 1] as the tile sizes to adhere to the linalg.matmul construct; with that I was able to make it through the first level of --iree-amdaie-tile-and-fuse.
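
For context, this is a minimal sketch of the linalg.batch_matmul input being pushed through here, reconstructed by me from the 32x8x64x8 OPT dispatch shape (element type assumed to be f32, matching the dispatch table):

```mlir
// Hypothetical batch_matmul with (B, M, N, K) = (32, 8, 64, 8):
// C[b, m, n] = sum_k A[b, m, k] * B[b, k, n].
func.func @batch_matmul_32x8x64x8(%lhs: tensor<32x8x8xf32>,
                                  %rhs: tensor<32x8x64xf32>) -> tensor<32x8x64xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %init = tensor.empty() : tensor<32x8x64xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
  %res = linalg.batch_matmul
      ins(%lhs, %rhs : tensor<32x8x8xf32>, tensor<32x8x64xf32>)
      outs(%fill : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
  return %res : tensor<32x8x64xf32>
}
```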

@Abhishek-Varma (Contributor) commented:

Hi @MaheshRavishankar @yzhang93 @nirvedhmeshram

Continuing yesterday's work on linalg.batch_matmul - I was able to make progress and push through this.

The IR after all of tile + fuse + pack + transpose for the batch_matmul case looks like this:

func.func @matmul_static_dispatch_0_batch_matmul_32x8x64x8_f32() {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x8x8xf32>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x8x64xf32>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<32x8x64xf32>>
  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [32, 8, 8], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x8x8xf32>> -> tensor<32x8x8xf32>
  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [32, 8, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x8x64xf32>> -> tensor<32x8x64xf32>
  %5 = tensor.empty() : tensor<32x8x64xf32>
  %6 = scf.forall (%arg0, %arg1, %arg2) = (0, 0, 0) to (32, 8, 64) step (1, 8, 64) shared_outs(%arg3 = %5) -> (tensor<32x8x64xf32>) {
    %extracted_slice = tensor.extract_slice %3[%arg0, %arg1, 0] [1, 8, 8] [1, 1, 1] : tensor<32x8x8xf32> to tensor<1x8x8xf32>
    %extracted_slice_0 = tensor.extract_slice %4[%arg0, 0, %arg2] [1, 8, 64] [1, 1, 1] : tensor<32x8x64xf32> to tensor<1x8x64xf32>
    %extracted_slice_1 = tensor.extract_slice %arg3[%arg0, %arg1, %arg2] [1, 8, 64] [1, 1, 1] : tensor<32x8x64xf32> to tensor<1x8x64xf32>
    %alloc = memref.alloc() : memref<1x1x1x8x8xf32, 1 : i32>
    %7 = bufferization.to_tensor %alloc restrict writable : memref<1x1x1x8x8xf32, 1 : i32>
    %pack = tensor.pack %extracted_slice inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %7 : tensor<1x8x8xf32> -> tensor<1x1x1x8x8xf32>
    %alloc_2 = memref.alloc() : memref<1x1x1x8x64xf32, 1 : i32>
    %8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x1x8x64xf32, 1 : i32>
    %pack_3 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 64] into %8 : tensor<1x8x64xf32> -> tensor<1x1x1x8x64xf32>
    %alloc_4 = memref.alloc() : memref<1x1x1x8x64xf32, 1 : i32>
    %9 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x1x1x8x64xf32, 1 : i32>
    %10 = scf.forall (%arg4, %arg5, %arg6) in (1, 1, 1) shared_outs(%arg7 = %9) -> (tensor<1x1x1x8x64xf32>) {
      %extracted_slice_5 = tensor.extract_slice %pack[%arg4, %arg5, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<1x1x1x8x8xf32>
      %extracted_slice_6 = tensor.extract_slice %pack_3[%arg4, 0, %arg6, 0, 0] [1, 1, 1, 8, 64] [1, 1, 1, 1, 1] : tensor<1x1x1x8x64xf32> to tensor<1x1x1x8x64xf32>
      %extracted_slice_7 = tensor.extract_slice %arg7[%arg4, %arg5, %arg6, 0, 0] [1, 1, 1, 8, 64] [1, 1, 1, 1, 1] : tensor<1x1x1x8x64xf32> to tensor<1x1x1x8x64xf32>
      %alloc_8 = memref.alloc() : memref<1x1x1x1x2x4x8xf32, 2 : i32>
      %11 = bufferization.to_tensor %alloc_8 restrict writable : memref<1x1x1x1x2x4x8xf32, 2 : i32>
      %pack_9 = tensor.pack %extracted_slice_5 outer_dims_perm = [0, 1, 2, 4, 3] inner_dims_pos = [3, 4] inner_tiles = [4, 8] into %11 : tensor<1x1x1x8x8xf32> -> tensor<1x1x1x1x2x4x8xf32>
      %alloc_10 = memref.alloc() : memref<1x1x1x8x1x8x8xf32, 2 : i32>
      %12 = bufferization.to_tensor %alloc_10 restrict writable : memref<1x1x1x8x1x8x8xf32, 2 : i32>
      %pack_11 = tensor.pack %extracted_slice_6 outer_dims_perm = [0, 1, 2, 4, 3] inner_dims_pos = [3, 4] inner_tiles = [8, 8] into %12 : tensor<1x1x1x8x64xf32> -> tensor<1x1x1x8x1x8x8xf32>
      %alloc_12 = memref.alloc() : memref<1x1x1x8x2x4x8xf32, 2 : i32>
      %13 = bufferization.to_tensor %alloc_12 restrict writable : memref<1x1x1x8x2x4x8xf32, 2 : i32>
      %14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<1x1x1x8x2x4x8xf32>) -> tensor<1x1x1x8x2x4x8xf32>
      %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d0, d1, d3, d6, d4, d7, d9)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d0, d3, d2, d5, d6, d9, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d0, d1, d2, d5, d4, d7, d8)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_9, %pack_11 : tensor<1x1x1x1x2x4x8xf32>, tensor<1x1x1x8x1x8x8xf32>) outs(%14 : tensor<1x1x1x8x2x4x8xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 8, 64], [1, 1, 1], [0, 0, 0, 1]]>} {
      ^bb0(%in: f32, %in_14: f32, %out: f32):
        %16 = arith.mulf %in, %in_14 : f32
        %17 = arith.addf %out, %16 : f32
        linalg.yield %17 : f32
      } -> tensor<1x1x1x8x2x4x8xf32>
      %unpack_13 = tensor.unpack %15 outer_dims_perm = [0, 1, 2, 4, 3] inner_dims_pos = [3, 4] inner_tiles = [4, 8] into %extracted_slice_7 : tensor<1x1x1x8x2x4x8xf32> -> tensor<1x1x1x8x64xf32>
      memref.dealloc %alloc_8 : memref<1x1x1x1x2x4x8xf32, 2 : i32>
      memref.dealloc %alloc_10 : memref<1x1x1x8x1x8x8xf32, 2 : i32>
      memref.dealloc %alloc_12 : memref<1x1x1x8x2x4x8xf32, 2 : i32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %unpack_13 into %arg7[%arg4, %arg5, %arg6, 0, 0] [1, 1, 1, 8, 64] [1, 1, 1, 1, 1] : tensor<1x1x1x8x64xf32> into tensor<1x1x1x8x64xf32>
      }
    } {mapping = [#gpu.block<y>, #gpu.block<x>, #gpu.block<z>]}
    %unpack = tensor.unpack %10 inner_dims_pos = [1, 2] inner_tiles = [8, 64] into %extracted_slice_1 : tensor<1x1x1x8x64xf32> -> tensor<1x8x64xf32>
    memref.dealloc %alloc : memref<1x1x1x8x8xf32, 1 : i32>
    memref.dealloc %alloc_2 : memref<1x1x1x8x64xf32, 1 : i32>
    memref.dealloc %alloc_4 : memref<1x1x1x8x64xf32, 1 : i32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %unpack into %arg3[%arg0, %arg1, %arg2] [1, 8, 64] [1, 1, 1] : tensor<1x8x64xf32> into tensor<32x8x64xf32>
    }
  } {mapping = [#gpu.block<y>, #gpu.block<x>, #gpu.block<z>]}
  flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [32, 8, 64], strides = [1, 1, 1] : tensor<32x8x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<32x8x64xf32>>
  return
}

NOTE: You won't see scf.for here; this is expected since it has only 1 iteration and gets canonicalized away.
Here is the mlir-print-ir-before-all, where you can see that the scf.for is present right before canonicalization.

This is now failing at iree-amdaie-decompose-pack-unpack-to-air (CC: @erwei-xilinx).

Please let me know the following:

  1. Is the above way of dealing with the "batch dimension" correct? I'm setting the first-level tile sizes as [1, Vivian's_TileM, Vivian's_TileN] (where Vivian's_Tile* comes from Vivian's patch, which got us through compilation for the linalg.matmul shapes).
  2. If yes, is the same supported (in line with the "expected" scf.for* structures) further down the pipeline?

I'll accordingly look into iree-amdaie-decompose-pack-unpack-to-air for a fix; otherwise we need to discuss and spend time on the tile sizes + pack config from the perspective of the batch dimension.

@erwei-xilinx (Contributor) commented Feb 7, 2024:

Hi @Abhishek-Varma, can you try this patch and see if it fixes the iree-amdaie-decompose-pack-unpack-to-air issue you saw? Thanks. #131

@erwei-xilinx (Contributor) commented:

> %10 = scf.forall (%arg4, %arg5, %arg6) in (1, 1, 1) shared_outs(%arg7 = %9) -> (tensor<1x1x1x8x64xf32>) {

This is perhaps a feature we may need to discuss: currently AIR will try to convert this 3D forall loop into a physical 2D herd, and AIR is not yet able to automatically figure out a mapping from the 3D space to the 2D hardware...

@Abhishek-Varma (Contributor) commented:

Hi @erwei-xilinx - your fix patch seems to work. Thanks!
I'm no longer getting the error at --iree-amdaie-decompose-pack-unpack-to-air, and the IR now makes progress until it reaches the --air-to-std (AIRLowering) pass.

This is the current state of the lowering.

I've raised the current branch as a WIP PR for review: Enable linalg.batch_matmul PR.

@erwei-xilinx (Contributor) commented:

Thanks for putting together the WIP PR. Just browsing through the IR, I noticed there seems to be an issue at air-par-to-herd, where it fails to completely convert the inner (3D) scf.parallel into an air.herd; the same thing happens with air-par-to-launch. After the two passes, it generates a "launch-segment-launch-segment-herd" loop structure, which is guaranteed to fail in downstream AIR passes...

I'll take a closer look tomorrow.

@yzhang93 (Contributor) commented Feb 7, 2024:

As discussed earlier today, we should also try running different matmul sizes through the pad pipeline. The pad pipeline could serve as a baseline, and @erwei-xilinx is pushing the pad pipeline forward this week to work with a large K dimension. @Abhishek-Varma, would you like to run the same experiments with the pad pipeline?
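
For anyone comparing the two approaches, here is a purely illustrative sketch of the core idea behind the pad-based pipeline: instead of re-laying-out operands with tensor.pack, each operand tile is padded up to a fixed, hardware-friendly size. The shapes and padding amounts below are my own and not the pipeline's actual output:

```mlir
// Zero-pad an 8x8 tile up to 8x16 so it matches a fixed tile size.
func.func @pad_tile_to_8x16(%src: tensor<8x8xf32>) -> tensor<8x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %src low[0, 0] high[0, 8] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<8x8xf32> to tensor<8x16xf32>
  return %padded : tensor<8x16xf32>
}
```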

@Abhishek-Varma (Contributor) commented:

Hi @yzhang93 - I'm currently taking a look at matmul transpose, trying to flesh out everything that would be needed to enable it (just like batch matmul) with the current pack pipeline, and accordingly clean up the batch_matmul work for convergence.
I'll see if I can fit the pad pipeline tests into today's bandwidth.

@Abhishek-Varma (Contributor) commented:

CC: @MaheshRavishankar @yzhang93 @nirvedhmeshram @erwei-xilinx

My updates for the day (besides the WIP PR for linalg.batch_matmul that I raised earlier):

  1. Worked on adding the initial infra for matmul transpose: Matmul transpose PR.
  2. Tested matmul transpose on a similar set of sizes and confirmed that it works for all the sizes for which vanilla matmul was working via Vivian's patch,
    i.e., 8 x (8, 16, 32, 64, 2048) x (8, 16, 32, 64) (M x N x K) works for matmul transpose!
    Currently you'd see that the lit test in the above PR is failing - that is because both functions, for some reason, are getting absorbed into the same ModuleOp. Should be a trivial fix - I'll look into it.
  3. I also experimented with the pad pipeline on top of Vivian's patch and confirmed that 8 x (8, 16, 32, 64) x (8, 16, 32, 64) works.
    NOTE: 8 x 2048 x * doesn't seem to work with the pad-based approach, although it works with the pack-based one.
    Also, 8 x 8 x 8 works with the pad-based approach, whereas it requires a hand-tuned config (as added by Vivian here) for the pack-based one.

@yzhang93 (Contributor) commented Feb 7, 2024:

>   3. I also experimented with the pad pipeline on top of Vivian's patch and confirmed that 8 x (8, 16, 32, 64) x (8, 16, 32, 64) works.
>     NOTE: 8 x 2048 x * doesn't seem to work with the pad-based approach, although it works with the pack-based one.
>     Also, 8 x 8 x 8 works with the pad-based approach, whereas it requires a hand-tuned config (as added by Vivian here) for the pack-based one.

Thanks @Abhishek-Varma for adding the infra support for the other required ops. For larger M/N sizes like 2048, you should probably test based on this PR.

@Abhishek-Varma (Contributor) commented:

Hi @MaheshRavishankar @yzhang93 @nirvedhmeshram @erwei-xilinx

My updates concerning this thread:

  1. Updated the mlir-air submodule to pull in the latest changes for higher M/N sizes.
  2. Tested with just step 1 - no improvement in test case coverage, but no regression either.
  3. Used Erwei's branch as-is on top of step 1 - there was a regression in test case coverage.
  4. Used Vivian's patch on top of step 3 - we have an improvement in coverage: 8 x (8, 16, 32, 64, 2048, 8192) x (8, 16, 32, 64) (8 x 8 x 8 still requires the hand-tuned config by Vivian). As you can see, the improved coverage includes 8 x 8192 x *.
    We still need to handle 8 x 50272 x 16 - mlir-ir-print-before-all
  5. For the pad-based pipeline using step 4, I see a lot of regression compared to what I observed yesterday - @yzhang93 can you please run the same experiment (ToM + Erwei's branch + your commit) and confirm this?

@yzhang93 (Contributor) commented Feb 8, 2024:

@Abhishek-Varma For the pack-based pipeline to work with large GEMMs, it still needs some passes and new config sizes, which I'm currently working on. I would suggest holding off on testing this until I have something running with reduction loops. For the pad-based pipeline it's a similar story; we should change the config sizes accordingly. Note there is another variable that needs to be changed for smaller GEMM sizes, which I haven't fully explored: https://github.com/nod-ai/iree-amd-aie/pull/129/files#diff-42f0d0bb098689f25ee68e8f05ec6c2eaa89ce41a6394d7d99c1d1c912943b38R390

@nirvedhmeshram (Contributor, Author) commented:

@kumardeepakamd this may be useful for you, so sharing it here: this is the "flat" linalg graph that IREE sees as input:
https://gist.github.com/nirvedhmeshram/4797d0505788b2e19bcc3fd7b9d0f1a5
