diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 5571b5866a01cb..c59cbdab8fb746 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -212,7 +212,6 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   return LLVM::LLVMPointerType::get(converted);
 }
 
-
 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types. If MLIR Function has zero results, the LLVM
 // Function has one VoidType result. If MLIR Function has more than one result,
@@ -525,10 +524,11 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
   auto result = getStridesAndOffset(type, strides, offset);
   (void)result;
   assert(succeeded(result) && "unexpected failure in stride computation");
-  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+  assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
          "expected static offset");
-  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
-         "expected static strides");
+  assert(!llvm::any_of(strides, [](int64_t stride) {
+           return MemRefType::isDynamicStrideOrOffset(stride);
+         }) && "expected static strides");
 
   auto convertedType = typeConverter.convertType(type);
   assert(convertedType && "unexpected failure in memref type conversion");
@@ -1044,14 +1044,14 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
 
   Value index;
   if (offset != 0) // Skip if offset is zero.
-    index = offset == MemRefType::getDynamicStrideOrOffset()
+    index = MemRefType::isDynamicStrideOrOffset(offset)
                 ? memRefDescriptor.offset(rewriter, loc)
                 : createIndexConstant(rewriter, loc, offset);
 
   for (int i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
     if (strides[i] != 1) { // Skip if stride is 1.
-      Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+      Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
                          ? memRefDescriptor.stride(rewriter, loc, i)
                          : createIndexConstant(rewriter, loc, strides[i]);
       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
@@ -3308,7 +3308,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
                              extracted);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
-    // Copy the buffer pointer from the old descriptor to the new one.
+    // Copy the aligned pointer from the old descriptor to the new one.
     extracted = sourceMemRef.alignedPtr(rewriter, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
@@ -3487,7 +3487,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
                   ArrayRef<int64_t> strides, Value nextSize,
                   Value runningStride, unsigned idx) const {
     assert(idx < strides.size());
-    if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
       return createIndexConstant(rewriter, loc, strides[idx]);
     if (nextSize)
       return runningStride
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index d5dac489ab51ff..7226d89f835f6a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3078,29 +3078,32 @@ enum SubViewVerificationResult {
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
 static SubViewVerificationResult isRankReducedType(Type originalType,
-                                                   Type reducedType) {
-  if (originalType == reducedType)
+                                                   Type candidateReducedType) {
+  if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
     return SubViewVerificationResult::Success;
   if (originalType.isa<RankedTensorType>() &&
-      !reducedType.isa<RankedTensorType>())
+      !candidateReducedType.isa<RankedTensorType>())
     return SubViewVerificationResult::Success;
-  if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
+  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
     return SubViewVerificationResult::Success;
 
   ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType reducedShapedType = reducedType.cast<ShapedType>();
+  ShapedType candidateReducedShapedType =
+      candidateReducedType.cast<ShapedType>();
 
   // Rank and size logic is valid for all ShapedTypes.
   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
+  ArrayRef<int64_t> candidateReducedShape =
+      candidateReducedShapedType.getShape();
   unsigned originalRank = originalShape.size(),
-           reducedRank = reducedShape.size();
-  if (reducedRank > originalRank)
+           candidateReducedRank = candidateReducedShape.size();
+  if (candidateReducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;
 
-  auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
+  auto optionalMask =
+      computeRankReductionMask(originalShape, candidateReducedShape);
 
   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalMask.hasValue())
@@ -3112,34 +3115,43 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
 
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
-  MemRefType reduced = reducedType.cast<MemRefType>();
+  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
   MLIRContext *c = original.getContext();
-  int64_t originalOffset, reducedOffset;
-  SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
+  int64_t originalOffset, candidateReducedOffset;
+  SmallVector<int64_t, 4> originalStrides, candidateReducedStrides, keepStrides;
   SmallVector<bool, 4> keepMask = optionalMask.getValue();
   getStridesAndOffset(original, originalStrides, originalOffset);
-  getStridesAndOffset(reduced, reducedStrides, reducedOffset);
+  getStridesAndOffset(candidateReduced, candidateReducedStrides,
+                      candidateReducedOffset);
 
   // Filter strides based on the mask and check that they are the same
-  // as reduced ones.
-  unsigned reducedIdx = 0;
+  // as candidateReduced ones.
+  unsigned candidateReducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     if (keepMask[originalIdx]) {
-      if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
+      if (originalStrides[originalIdx] !=
+          candidateReducedStrides[candidateReducedIdx++])
         return SubViewVerificationResult::StrideMismatch;
       keepStrides.push_back(originalStrides[originalIdx]);
     }
   }
 
-  if (original.getElementType() != reduced.getElementType())
+  if (original.getElementType() != candidateReduced.getElementType())
     return SubViewVerificationResult::ElemTypeMismatch;
 
-  if (original.getMemorySpace() != reduced.getMemorySpace())
+  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;
 
+  // reducedMap is obtained by projecting away the dimensions inferred from
+  // matching the 1's positions in candidateReducedType.
   auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
-  if (!reduced.getAffineMaps().empty() &&
-      reducedMap != reduced.getAffineMaps().front())
+
+  MemRefType expectedReducedType = MemRefType::get(
+      candidateReduced.getShape(), candidateReduced.getElementType(),
+      reducedMap, candidateReduced.getMemorySpace());
+  expectedReducedType = canonicalizeStridedLayout(expectedReducedType);
+
+  if (expectedReducedType != canonicalizeStridedLayout(candidateReduced))
     return SubViewVerificationResult::AffineMapMismatch;
 
   return SubViewVerificationResult::Success;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 23553a8483de42..14144077ef4ca3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -745,12 +745,20 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
     return t;
 
+  // Corner-case for 0-D affine maps.
+  auto m = affineMaps[0];
+  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
+    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
+      if (cst.getValue() == 0)
+        return MemRefType::Builder(t).setAffineMaps({});
+    return t;
+  }
+
   // If the canonical strided layout for the sizes of `t` is equal to the
   // simplified layout of `t` we can just return an empty layout. Otherwise,
   // just simplify the existing layout.
   AffineExpr expr =
       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-  auto m = affineMaps[0];
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index a12610722972e4..595a950781b102 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1011,6 +1011,16 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 
 // -----
 
+func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<8x16x4xf32>
+  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  %1 = subview %0[0, 2, 0][8, 16, 4][1, 1, 1]
+    : memref<8x16x4xf32> to memref<16x4xf32>
+  return
+}
+
+// -----
+
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
   // expected-error@+1 {{expected result type to be 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
   %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
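
To see what the canonicalized-type comparison in isRankReducedType buys, consider the following sketch. It is illustrative only, not part of the patch: the function name is invented and the exact diagnostic wording is not claimed here. Under the old check, a result type with no affine maps skipped the `reducedMap != reduced.getAffineMaps().front()` comparison entirely, so a nonzero subview offset went unverified; comparing canonicalized types catches it.

func @subview_offset_now_verified() {
  %0 = alloc() : memref<8x16x4xf32>
  // The inferred subview type is 'memref<8x16x4xf32, affine_map<(d0, d1, d2)
  // -> (d0 * 64 + d1 * 4 + d2 + 8)>>': the [0, 2, 0] offsets contribute
  // 2 * 4 = 8. The plain result type below matches in sizes and strides, but
  // it canonicalizes without the offset, so with this patch verification is
  // expected to report an affine-map mismatch instead of passing silently.
  %1 = subview %0[0, 2, 0][8, 16, 4][1, 1, 1]
    : memref<8x16x4xf32> to memref<8x16x4xf32>
  return
}

Comparing canonicalized types also works in the other direction: layouts that are spelled differently but canonicalize to the same strided form (including the 0-D `() -> (0)` layout handled by the BuiltinTypes.cpp corner case) now compare equal rather than tripping a spurious mismatch.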