diff --git a/tools/mlir-clang/Lib/clang-mlir.cc b/tools/mlir-clang/Lib/clang-mlir.cc index 7f721c5d7957..6f73faebdf45 100644 --- a/tools/mlir-clang/Lib/clang-mlir.cc +++ b/tools/mlir-clang/Lib/clang-mlir.cc @@ -2556,8 +2556,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { prev.getType().dyn_cast()) { result = builder.create( loc, pt, prev, std::vector({rhs.getValue(builder)})); - } else { - auto postTy = prev.getType().dyn_cast(); + } else if (auto postTy = prev.getType().dyn_cast()) { mlir::Value rhsV = rhs.getValue(builder); auto prevTy = rhsV.getType().cast(); if (prevTy == postTy) { @@ -2572,6 +2571,18 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { } assert(rhsV.getType() == prev.getType()); result = builder.create(loc, prev, rhsV); + } else if (auto postTy = prev.getType().dyn_cast()) { + mlir::Value rhsV = rhs.getValue(builder); + auto shape = std::vector(postTy.getShape()); + shape[0] = -1; + postTy = mlir::MemRefType::get(shape, postTy.getElementType(), + MemRefLayoutAttrInterface(), + postTy.getMemorySpace()); + auto ptradd = rhsV; + ptradd = castToIndex(loc, ptradd); + result = builder.create(loc, postTy, prev, ptradd); + } else { + assert(false && "Unsupported add assign type"); } lhs.store(builder, result); return lhs; diff --git a/tools/mlir-clang/Test/Verification/memrefaddassign.cpp b/tools/mlir-clang/Test/Verification/memrefaddassign.cpp new file mode 100644 index 000000000000..18b091e42396 --- /dev/null +++ b/tools/mlir-clang/Test/Verification/memrefaddassign.cpp @@ -0,0 +1,12 @@ +// RUN: mlir-clang %s --function=* -c -S | FileCheck %s + +float *foo(float *a) { + a += 32; + return a; +} +// CHECK: func @_Z3fooPf(%arg0: memref) +// CHECK-NEXT %c32 = arith.constant 32 : index +// CHECK-NEXT %0 = "polygeist.subindex"(%arg0, %c32) : (memref, index) -> memref +// CHECK-NEXT return %0 : memref +// CHECK-NEXT } +