diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d4d4b08..428cdb0c1425a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -16,10 +16,12 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include @@ -288,6 +290,70 @@ struct ConvertStore final : public OpConversionPattern { return success(); } }; + +struct ConvertExtractStridedMetadata final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, + OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = extractStridedMetadataOp.getLoc(); + Value source = extractStridedMetadataOp.getSource(); + + MemRefType memrefType = cast(source.getType()); + if (!isMemRefTypeLegalForEmitC(memrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + + emitc::ConstantOp zeroIndex = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + TypedValue srcArrayValue = + cast>(operands.getSource()); + auto createPointerFromEmitcArray = [loc, &rewriter, &zeroIndex, + srcArrayValue]() -> emitc::ApplyOp { + int64_t rank = srcArrayValue.getType().getRank(); + llvm::SmallVector indices; + for (int i = 0; i < rank; ++i) { + indices.push_back(zeroIndex); + } + + emitc::SubscriptOp subPtr = rewriter.create( + loc, srcArrayValue, mlir::ValueRange(indices)); + emitc::ApplyOp ptr = rewriter.create( + loc, + emitc::PointerType::get(srcArrayValue.getType().getElementType()), + rewriter.getStringAttr("&"), subPtr); + + return ptr; + }; + + emitc::ApplyOp srcPtr = createPointerFromEmitcArray(); + auto [strides, offset] = memrefType.getStridesAndOffset(); + Value offsetValue = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset)); + + SmallVector results; + results.push_back(srcPtr); + results.push_back(offsetValue); + + for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) { + Value sizeValue = rewriter.create( + loc, rewriter.getIndexType(), + rewriter.getIndexAttr(memrefType.getDimSize(i))); + results.push_back(sizeValue); + + Value strideValue = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(strides[i])); + results.push_back(strideValue); + } + + rewriter.replaceOp(extractStridedMetadataOp, results); + return success(); + } +}; + } // namespace void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { @@ -320,6 +386,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 2b4eda37903d4..d36eaf3c2673a 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -58,3 +58,19 @@ module @globals { return } } + +// ----- + +// CHECK-LABEL: reinterpret_cast +func.func @reinterpret_cast(%arg18: memref<1xi32>) { + // CHECK: %0 = builtin.unrealized_conversion_cast %arg0 : memref<1xi32> to !emitc.array<1xi32> + // CHECK: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index + // CHECK: %2 = emitc.subscript %0[%1] : (!emitc.array<1xi32>, index) -> !emitc.lvalue + // CHECK: %3 = emitc.apply "&"(%2) : (!emitc.lvalue) -> !emitc.ptr + // CHECK: %4 = "emitc.constant"() <{value = 0 : index}> : () -> index + // CHECK: %5 = "emitc.constant"() <{value = 1 : index}> : () -> index + // CHECK: %6 = "emitc.constant"() <{value = 1 : index}> : () -> index + %base_buffer_485, %offset_486, %sizes_487, %strides_488 = memref.extract_strided_metadata %arg18 : memref<1xi32> -> memref, index, index, index + return +} +