[mlir][vector] Narrow-type emulation of vector.store: do not assume a zero
front-padding offset.

Summary of the hunks below (free-form text; patch tools skip everything before
the first `diff --git` header):
  * VectorEmulateNarrowType.cpp: remove the `trailingDimsMatch` shortcut that
    forced the number of front-padding elements to 0, and always fold
    `linearizedInfo.intraDataOffset` instead. When the offset cannot be folded
    to a constant, the existing `notifyMatchFailure` path is taken.
  * Existing vector.store tests now index the trailing dimension with a
    constant 0; their CHECK lines are updated accordingly.
  * The `@vector_store_i4_dynamic` test is removed and replaced by negative
    tests (new file) that expect legalization of the store to fail.
  * A regression test for issue #131528 (constant non-zero column index) is
    added to the unaligned test file.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 583cda7ac2810..d693a2c6e5782 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -635,19 +635,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // Do the trailing dim for source and destination match? If yes, then the
-    // corresponding index must be 0.
-    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
-    // However, that makes some tests fail, so we need to audit first.
-    auto trailingDim = op.getBase().getType().getShape().back();
-    bool trailingDimsMatch =
-        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
-
     auto stridedMetadata =
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
 
-    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
-    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -658,10 +648,12 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
+    // Use the exact intraDataOffset when it can be folded. Dynamic values are
+    // rejected in this path because a dynamic offset is not necessarily aligned
+    // to a container element boundary. Callers that can guarantee alignment
+    // should use assumeAligned.
     std::optional<int64_t> foldedNumFrontPadElems =
-        (isDivisibleInSize && trailingDimsMatch)
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
index 222e613f5c18a..a359ee68b82c6 100644
--- a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
@@ -45,7 +45,8 @@ func.func @vector_maskedstore_2d_i4(%arg0: index, %value: vector<8xi4>) {
 
 func.func @vector_store_2d_i4(%arg0: index, %value: vector<8xi4>) {
   %0 = memref.alloc() : memref<4x8xi4>
-  vector.store %value, %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
+  %c0 = arith.constant 0 : index
+  vector.store %value, %0[%arg0, %c0] : memref<4x8xi4>, vector<8xi4>
   return
 }
 // CHECK-LABEL: func @vector_store_2d_i4(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic-store.mlir
new file mode 100644
index 0000000000000..518d825234502
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic-store.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --verify-diagnostics --split-input-file %s
+
+// Dynamic sub-byte vector.store offsets cannot be treated as byte-aligned unless
+// the caller explicitly opts into the alignment contract.
+func.func @vector_store_i2_2d_dynamic_col(%arg0: vector<4xi2>, %idx0: index,
+                                          %idx1: index) {
+  %src = memref.alloc() : memref<3x4xi2>
+  // expected-error @below {{failed to legalize operation 'vector.store' that was explicitly marked illegal}}
+  vector.store %arg0, %src[%idx0, %idx1] : memref<3x4xi2>, vector<4xi2>
+  return
+}
+
+// -----
+
+func.func @vector_store_i4_dynamic_memref(%arg0: vector<8xi4>, %dim0: index,
+                                          %dim1: index, %idx0: index,
+                                          %idx1: index) {
+  %src = memref.alloc(%dim0, %dim1) : memref<?x?xi4>
+  // expected-error @below {{failed to legalize operation 'vector.store' that was explicitly marked illegal}}
+  vector.store %arg0, %src[%idx0, %idx1] : memref<?x?xi4>, vector<8xi4>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 21f073efc49b2..bec7736b90973 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -562,3 +562,27 @@ func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
 // CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
 // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
 // CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/131528.
+// A vector.store with a non-zero constant column index on a 2D memref where the
+// trailing dimension matches the vector size must NOT be treated as
+// byte-aligned. Instead it must emit two partial (RMW) stores since the
+// 4-element i2 vector starting at column 1 crosses a byte boundary.
+func.func @vector_store_i2_2d_const_nonzero_col(%arg0: vector<4xi2>) {
+  %src = memref.alloc() : memref<3x4xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %src[%c0, %c1] : memref<3x4xi2>, vector<4xi2>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_i2_2d_const_nonzero_col(
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// Emits two partial atomic RMWs: one for byte 0 (elements at positions [1..3])
+// and one for byte 1 (element at position [0]).
+// CHECK: memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]]
+// CHECK: memref.generic_atomic_rmw %[[ALLOC]][{{.+}}]
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 98b1f07ef5fb0..72b355e0fed65 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -455,19 +455,25 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 
 func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
   %0 = memref.alloc() : memref<4x8xi8>
-  vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
+  %c0 = arith.constant 0 : index
+  vector.store %arg0, %0[%arg1, %c0] :memref<4x8xi8>, vector<8xi8>
   return
 }
 
 // Expect no conversions, i8 is supported.
 // CHECK: func @vector_store_i8
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi8>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4x8xi8>
-// CHECK: vector.store %[[ARG0]], %[[ALLOC:.+]][%[[ARG1]], %[[ARG2]]] : memref<4x8xi8>, vector<8xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: vector.store %[[ARG0]], %[[ALLOC:.+]][%[[ARG1]], %[[C0]]] : memref<4x8xi8>, vector<8xi8>
 
-// CHECK32-DAG: affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
 // CHECK32: func @vector_store_i8
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi8>
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<8xi32>
-// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]]]
 // CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi8> to vector<2xi32>
 // CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<8xi32>, vector<2xi32
 
@@ -475,86 +481,51 @@ func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
 
 func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
   %0 = memref.alloc() : memref<4x8xi4>
-  vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi4>, vector<8xi4>
+  %c0 = arith.constant 0 : index
+  vector.store %arg0, %0[%arg1, %c0] :memref<4x8xi4>, vector<8xi4>
   return
 }
 
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
 // CHECK: func @vector_store_i4
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
-// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]]]
 // CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
 // CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<16xi8>, vector<4xi8>
 
-// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 // CHECK32: func @vector_store_i4
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
-// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
 // CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
-// CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<4xi32>, vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[ARG1]]] : memref<4xi32>, vector<1xi32>
 
 // -----
 
 func.func @vector_store_f4(%arg0: vector<8xf4E2M1FN>, %arg1: index, %arg2: index) {
   %0 = memref.alloc() : memref<4x8xf4E2M1FN>
-  vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
+  %c0 = arith.constant 0 : index
+  vector.store %arg0, %0[%arg1, %c0] :memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
   return
 }
 
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
 // CHECK: func @vector_store_f4
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xf4E2M1FN>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
-// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]]]
 // CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<4xi8>
 // CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<16xi8>, vector<4xi8>
 
-// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 // CHECK32: func @vector_store_f4
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xf4E2M1FN>
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
-// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
 // CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<1xi32>
-// CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<4xi32>, vector<1xi32>
-
-// -----
-
-// FIXME: This example assumes that the store happens at a byte boundary, but
-// that's not guaranteed. Below is a counter-example with specific dimensions:
-//   vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
-// TODO: Revisit post #136797
-
-func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
-  %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
-  vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
-  return
-}
-
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
-// CHECK: func @vector_store_i4_dynamic
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
-// CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
-// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
-// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
-// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
-// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
-
-// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
-// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
-// CHECK32: func @vector_store_i4_dynamic
-// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
-// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
-// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
-// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
-// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
-// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
-// CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
-// CHECK32: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi32>, vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[ARG1]]] : memref<4xi32>, vector<1xi32>
 
 // -----
 