[mlir][ArmSME] Add support for lowering masked tile_load ops #70915

Merged 4 commits on Nov 8, 2023

Commits on Nov 7, 2023

  1. [mlir][ArmSME] Add support for lowering masked tile_load ops

    This patch extends ArmSMEToSCF to support lowering of masked tile_load
    ops. Only masks created by 'vector.create_mask' are currently supported.
    
    There are two lowerings, one for a constant-zero pad and another for a
    non-zero pad. For the following example:
    
      %pad = arith.constant 0 : i32
      %num_rows = arith.constant 2 : index
      %num_cols = arith.constant 4 : index
      %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
      %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
                                                              vector<[4]x[4]xi32>
    
    The former (constant zero pad) is lowered as follows:
    -----------------------------------------------------
    
      %tile = arm_sme.zero : vector<[4]x[4]xi32>
      %num_cols = vector.create_mask %c4 : vector<[4]xi1>
      scf.for %slice_idx = %c0 to %num_rows step %c1 {
        %tile_update = arm_sme.load_tile_slice
          %src[%slice_idx, %c0], %num_cols, %tile, %slice_idx :
          memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
      }
    
    The tile is zeroed to satisfy the padding and only active rows are
    loaded.
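
As a rough model of the zero-pad lowering above, the following plain-Python sketch mimics the same steps on a list-of-lists tile. It is not the MLIR pass itself: `lower_zero_pad` is a hypothetical name, `svl` stands in for the runtime tile dimension (4 x vscale here), and clamping the row count to the tile size is an assumption of this sketch.

```python
def lower_zero_pad(src, num_rows, num_cols, svl):
    """Model of the zero-pad lowering: zero the tile, then load active rows only."""
    # arm_sme.zero: every element of the tile starts at 0.
    tile = [[0] * svl for _ in range(svl)]
    # Only active rows are visited (clamped to the tile size in this sketch).
    for row in range(min(num_rows, svl)):
        # arm_sme.load_tile_slice with a 1-D column mask:
        # lanes outside the mask keep their zeroed value.
        for col in range(min(num_cols, svl)):
            tile[row][col] = src[row][col]
    return tile


# Hypothetical 4x4 source: row r holds 10*r + 1 .. 10*r + 4.
src = [[10 * r + c + 1 for c in range(4)] for r in range(4)]
tile = lower_zero_pad(src, num_rows=2, num_cols=4, svl=4)
```

With two active rows, only the first two rows of `src` are read and the remaining rows of the tile keep the zero written by the initial zeroing step, which is what makes the zero pad free.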
    
    The latter (non-zero pad) is lowered as follows:
    ------------------------------------------------
    
      scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
        %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
        %slice = scf.if %row_is_active -> vector<[4]xi32> {
          %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
            memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
          scf.yield %slice : vector<[4]xi32>
        } else {
          scf.yield %pad_1d : vector<[4]xi32>
        }
        arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
          : vector<[4]xi32> into vector<[4]x[4]xi32>
      }
    
    The scalar pad is broadcast to a 1-D vector and a regular
    'vector.maskedload' (which is lowered to SVE, not SME) loads each slice
    for active rows, with the padding specified as a passthru. For non-active
    rows the slice is the 1-D pad. The resulting slice is inserted into the
    tile with 'arm_sme.move_vector_to_tile_slice'.
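
The non-zero-pad lowering can be modelled the same way. This is a minimal plain-Python sketch under stated assumptions, not the actual pass: `lower_nonzero_pad` is a hypothetical name, `svl` stands in for the number of tile slices, and Python lists stand in for scalable vectors.

```python
def lower_nonzero_pad(src, pad, num_rows, num_cols, svl):
    """Model of the non-zero-pad lowering: build each slice, then move it into the tile."""
    tile = [[None] * svl for _ in range(svl)]
    pad_1d = [pad] * svl  # scalar pad broadcast to a 1-D vector
    for row in range(svl):  # every tile slice is visited, active or not
        if row < num_rows:
            # vector.maskedload: active lanes come from memory,
            # inactive lanes take the passthru (the 1-D pad).
            slice_ = [src[row][col] if col < num_cols else pad_1d[col]
                      for col in range(svl)]
        else:
            # Inactive row: the slice is just the 1-D pad.
            slice_ = pad_1d
        # arm_sme.move_vector_to_tile_slice: insert the slice into the tile.
        tile[row] = list(slice_)
    return tile


# Hypothetical 4x4 source: row r holds 10*r + 1 .. 10*r + 4.
src = [[10 * r + c + 1 for c in range(4)] for r in range(4)]
tile = lower_nonzero_pad(src, pad=7, num_rows=2, num_cols=3, svl=4)
```

Unlike the zero-pad case, the loop runs over all tile slices, because inactive rows must still be filled with the pad value rather than left as they are.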
    c-rhodes committed Nov 7, 2023
    c07bc6a
  2. run clang-format

    c-rhodes committed Nov 7, 2023
    a376455
  3. ec829d7
  4. address comments

    c-rhodes committed Nov 7, 2023
    9a5cdf0