From 7986481d38213fb52506f7ecb223621db34a5660 Mon Sep 17 00:00:00 2001 From: Ryan Kim Date: Sat, 13 Sep 2025 20:16:58 +0900 Subject: [PATCH] [mlir] Fix correct memset range in OwningMemRef zero-init `OwningMemRef` previously called `memset()` on `descriptor.data` with `size + desiredAlignment`, which could write past the allocated region since `data != alignedData`. --- .../include/mlir/ExecutionEngine/MemRefUtils.h | 10 ++++------ mlir/unittests/ExecutionEngine/Invoke.cpp | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h index d66d757cb7a8e..e9471731afe13 100644 --- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h +++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h @@ -164,19 +164,17 @@ class OwningMemRef { int64_t nElements = 1; for (int64_t s : shapeAlloc) nElements *= s; - auto [data, alignedData] = + auto [allocatedPtr, alignedData] = detail::allocAligned(nElements, allocFun, alignment); - descriptor = detail::makeStridedMemRefDescriptor(data, alignedData, - shape, shapeAlloc); + descriptor = detail::makeStridedMemRefDescriptor( + allocatedPtr, alignedData, shape, shapeAlloc); if (init) { for (StridedMemrefIterator it = descriptor.begin(), end = descriptor.end(); it != end; ++it) init(*it, it.getIndices()); } else { - memset(descriptor.data, 0, - nElements * sizeof(T) + - alignment.value_or(detail::nextPowerOf2(sizeof(T)))); + memset(alignedData, 0, nElements * sizeof(T)); } } /// Take ownership of an existing descriptor with a custom deleter. diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp index cdeeca20610f0..3161c7053f7a4 100644 --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -251,6 +251,24 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) { EXPECT_EQ((a[{2, 1}]), 42.); } +TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(OwningMemrefZeroInit)) { + constexpr int k = 3; + constexpr int m = 7; + int64_t shape[] = {k, m}; + // Use a large alignment to stress the case where the memref data/basePtr are + // disjoint. + int alignment = 8192; + OwningMemRef a(shape, {}, {}, alignment); + ASSERT_EQ( + (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)), + a->data); + for (int i = 0; i < k; ++i) { + for (int j = 0; j < m; ++j) { + EXPECT_EQ((a[{i, j}]), 0.); + } + } +} + // A helper function that will be called from the JIT static void memrefMultiply(::StridedMemRefType *memref, int32_t coefficient) {