diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h index 6e72f7c23bdcf..d66d757cb7a8e 100644 --- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h +++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h @@ -151,7 +151,7 @@ class OwningMemRef { AllocFunType allocFun = &::malloc, std::function)> freeFun = [](StridedMemRefType descriptor) { - ::free(descriptor.data); + ::free(descriptor.basePtr); }) : freeFunc(freeFun) { if (shapeAlloc.empty()) diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp index 887db227cfc4b..312b10f28143f 100644 --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -205,7 +205,13 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) { }; int64_t shape[] = {k, m}; int64_t shapeAlloc[] = {k + 1, m + 1}; - OwningMemRef a(shape, shapeAlloc, init); + // Use a large alignment to stress the case where the memref data/basePtr are + // disjoint. + int alignment = 8192; + OwningMemRef a(shape, shapeAlloc, init, alignment); + ASSERT_EQ( + (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)), + a->data); ASSERT_EQ(a->sizes[0], k); ASSERT_EQ(a->sizes[1], m); ASSERT_EQ(a->strides[0], m + 1);