Skip to content

Commit

Permalink
[mlir][ArmSME] Fix get_tile_id type in zero lowering
Browse files Browse the repository at this point in the history
The arm_sme.get_tile_id op returns a scalar integer but the arm_sme.zero
op lowering incorrectly uses the element type, which could be
floating-point.

Reviewed By: awarzynski, benmxwl-arm

Differential Revision: https://reviews.llvm.org/D159080
  • Loading branch information
c-rhodes committed Aug 30, 2023
1 parent a4fbc09 commit 834cdc8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();

unsigned tileElementWidth =
zero.getVectorType().getElementType().getIntOrFloatBitWidth();

// Get Tile ID for the `zero` intrinsic.
auto tileId = rewriter.create<arm_sme::GetTileID>(
loc, zero.getVectorType().getElementType());

auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth();
loc, rewriter.getIntegerType(tileElementWidth));

// Get the base mask for tile based on the element size.
// The base mask is just the mask to zero the first tile (of a size).
Expand Down
21 changes: 12 additions & 9 deletions mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// RUN: -allow-unregistered-dialect \
// RUN: | FileCheck %s

// This test verifies the tile mask operand of the zero intrinsic zeroes
// the correct tiles. Both integer and floating-point datatypes are checked.

// -----

// CHECK-LABEL: zero_za_b
Expand Down Expand Up @@ -32,9 +35,9 @@ func.func @zero_za_h() {
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
"prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
// CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16>
%zero_za1h = arm_sme.zero : vector<[8]x[8]xi16>
"prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> ()
// CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
%zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
"prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
return
}

Expand Down Expand Up @@ -65,9 +68,9 @@ func.func @zero_za_s() {
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
"prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
// CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32>
%zero_za3s = arm_sme.zero : vector<[4]x[4]xi32>
"prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> ()
// CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
"prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
return
}

Expand Down Expand Up @@ -122,8 +125,8 @@ func.func @zero_za_d() {
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
"prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
// CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64>
%zero_za7d = arm_sme.zero : vector<[2]x[2]xi64>
"prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> ()
// CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
%zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
"prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}

0 comments on commit 834cdc8

Please sign in to comment.