diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5180b5614a43f..44dc1bc923a6b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1457,7 +1457,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "operations instead of the alignment of the element type of the "
            "memref. This flag is intended for use with hardware which requires"
            "vector alignment, or in application contexts where it is known all "
-           "vector access are naturally aligned. ">,
+           "vector accesses are naturally aligned. If an operation has an "
+           "explicit alignment attribute set, that attribute takes priority "
+           "over this option.">,
     Option<"amx", "enable-amx", "bool", /*default=*/"false",
            "Enables the use of AMX dialect while lowering the vector "
            "dialect.">,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1ff7d5dad378e..d9f65ed2cee05 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -247,6 +247,7 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
 
     // Resolve alignment.
+    // Explicit alignment takes priority over use-vector-alignment.
     unsigned align = loadOrStoreOp.getAlignment().value_or(0);
     if (!align &&
         failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
@@ -299,8 +300,10 @@ class VectorGatherOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    // Explicit alignment takes priority over use-vector-alignment.
+    unsigned align = gather.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
@@ -354,8 +357,10 @@ class VectorScatterOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    // Explicit alignment takes priority over use-vector-alignment.
+    unsigned align = scatter.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(scatter,
                                          "could not resolve alignment");
@@ -399,8 +404,14 @@ class VectorExpandLoadOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+    // The pointer alignment defaults to 1.
+    uint64_t alignment = expand.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
+        alignment);
     return success();
   }
 };
@@ -421,8 +432,13 @@ class VectorCompressStoreOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+    // The pointer alignment defaults to 1.
+    uint64_t alignment = compress.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
-        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+        compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index 3fa248656cf3a..12fe3552ce1b7 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -18,6 +18,18 @@ func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8
 
 // -----
 
+func.func @load_with_alignment_attribute(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// ALL-LABEL: func @load_with_alignment_attribute
+
+// VEC-ALIGN: llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
+// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.store
 //===----------------------------------------------------------------------===//
@@ -35,6 +47,19 @@ func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
 
 // -----
 
+func.func @store_with_alignment_attribute(%base : memref<200x100xf32>, %i : index, %j : index) {
+  %val = arith.constant dense<11.0> : vector<4xf32>
+  vector.store %val, %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// ALL-LABEL: func @store_with_alignment_attribute
+
+// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
+// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.maskedload
 //===----------------------------------------------------------------------===//
@@ -52,6 +77,19 @@ func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: ve
 
 // -----
 
+func.func @masked_load_with_alignment_attribute(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0: index
+  %0 = vector.maskedload %base[%c0], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// ALL-LABEL: func @masked_load_with_alignment_attribute
+
+// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.maskedstore
 //===----------------------------------------------------------------------===//
@@ -69,6 +107,19 @@ func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: v
 
 // -----
 
+func.func @masked_store_with_alignment_attribute(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+  %c0 = arith.constant 0: index
+  vector.maskedstore %base[%c0], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// ALL-LABEL: func @masked_store_with_alignment_attribute
+
+// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -86,6 +137,19 @@ func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3x
 
 // -----
 
+func.func @scatter_with_alignment_attribute(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
+  %0 = arith.constant 0: index
+  vector.scatter %base[%0][%index], %mask, %value {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// ALL-LABEL: func @scatter_with_alignment_attribute
+
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.gather
 //===----------------------------------------------------------------------===//
@@ -100,3 +164,16 @@ func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi
 
 // VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
+func.func @gather_with_alignment_attribute(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %base[%0][%index], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// ALL-LABEL: func @gather_with_alignment_attribute
+
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 9b57b1b6fb4c7..5973c2ba2cbd0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2042,6 +2042,16 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
 
 // -----
 
+func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) -> vector<3xf32> {
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_with_alignment
+// CHECK: llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -2118,6 +2128,17 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
 // CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
 
+// -----
+
+func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) {
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 { alignment = 8 } : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_alignment
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -2149,6 +2170,15 @@ func.func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %a
 
 // -----
 
+func.func @expand_load_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) -> vector<11xindex> {
+  %0 = vector.expandload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
+  return %0 : vector<11xindex>
+}
+// CHECK-LABEL: func @expand_load_op_with_alignment
+// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.compressstore
 //===----------------------------------------------------------------------===//
@@ -2177,6 +2207,15 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,
 
 // -----
 
+func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) {
+  vector.compressstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex>
+  return
+}
+// CHECK-LABEL: func @compress_store_op_with_alignment
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.splat
 //===----------------------------------------------------------------------===//
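
Note: a minimal sketch of the precedence rule this patch implements, distilled from the
load_with_alignment_attribute test above. The function name is hypothetical, and the RUN
invocation is assumed to mirror the existing RUN lines of use-vector-alignment.mlir:

// RUN (assumed): mlir-opt --convert-vector-to-llvm='use-vector-alignment=1' example.mlir
func.func @alignment_priority_example(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
  // With use-vector-alignment=1 and no attribute, the lowering would use the
  // vector type's preferred alignment; the explicit attribute takes priority,
  // so the resulting llvm.load carries {alignment = 8 : i64}.
  %0 = vector.load %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<8xf32>
  return %0 : vector<8xf32>
}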