diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index cd28bd6cf73a5..b1656304fb117 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -184,6 +184,23 @@ def OptimizeAllocationLivenessPass let dependentDialects = ["mlir::memref::MemRefDialect"]; } +def PrintBufferLifetimeStatsPass + : Pass<"print-buffer-lifetime-stats", "func::FuncOp"> { + let summary = "Print buffer lifetime statistics for allocations in a " + "function"; + let description = [{ + This analysis-only pass walks a function, collects memref.alloc / + memref.dealloc pairs via MemoryEffectOpInterface, computes lifetime + intervals based on operation ordering within a block, and prints + statistics: number of tracked allocations, total allocated bytes, peak + live bytes, and the number of non-overlapping allocation pairs that + could potentially share memory. + + The pass does not modify the IR. + }]; + let dependentDialects = ["mlir::memref::MemRefDialect"]; +} + def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> { let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`" "operations"; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp new file mode 100644 index 0000000000000..86f037634ce43 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp @@ -0,0 +1,193 @@ +//===- BufferLifetimeStats.cpp - Buffer lifetime statistics pass ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace bufferization { +#define GEN_PASS_DEF_PRINTBUFFERLIFETIMESTATSPASS +#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" +} // namespace bufferization +} // namespace mlir + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +/// Assign a sequential index to each operation in the block. +static DenseMap buildOperationIndex(Block &block) { + DenseMap opIndex; + unsigned idx = 0; + for (Operation &op : block) + opIndex[&op] = idx++; + return opIndex; +} + +/// Find the unique dealloc for `allocResult` in `block`, or nullptr. +static Operation *findDeallocInSameBlock(Value allocResult, Block *block) { + Operation *deallocOp = nullptr; + for (Operation *user : allocResult.getUsers()) { + auto memEffectOp = dyn_cast(user); + if (!memEffectOp) + continue; + SmallVector effects; + memEffectOp.getEffects(effects); + for (const auto &effect : effects) { + if (isa(effect.getEffect()) && + user->getBlock() == block) { + if (deallocOp) + return nullptr; + deallocOp = user; + } + } + } + return deallocOp; +} + +/// Compute the size in bytes for a statically-shaped memref type. +static int64_t getMemRefSizeInBytes(MemRefType type) { + if (!type.hasStaticShape()) + return 0; + int64_t numElements = type.getNumElements(); + unsigned bitsPerElement = type.getElementTypeBitWidth(); + return (numElements * bitsPerElement + 7) / 8; +} + +/// A buffer lifetime interval: [allocIndex, deallocIndex). +struct LifetimeInterval { + Value allocResult; + unsigned allocIndex; + unsigned deallocIndex; + int64_t sizeInBytes; +}; + +/// Check whether two lifetime intervals are non-overlapping. +static bool areNonOverlapping(const LifetimeInterval &a, + const LifetimeInterval &b) { + return a.deallocIndex <= b.allocIndex || b.deallocIndex <= a.allocIndex; +} + +//===----------------------------------------------------------------------===// +// Pass implementation +//===----------------------------------------------------------------------===// + +struct PrintBufferLifetimeStats + : public bufferization::impl::PrintBufferLifetimeStatsPassBase< + PrintBufferLifetimeStats> { +public: + PrintBufferLifetimeStats() = default; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + if (func.isExternal()) + return; + + // We only handle single-block functions for now. + if (!func.getBody().hasOneBlock()) + return; + + Block &entryBlock = func.getBody().front(); + DenseMap opIndex = buildOperationIndex(entryBlock); + SmallVector intervals; + + entryBlock.walk([&](MemoryEffectOpInterface memEffectOp) { + SmallVector effects; + memEffectOp.getEffects(effects); + + for (const MemoryEffects::EffectInstance &effect : effects) { + if (!isa(effect.getEffect())) + continue; + + Value val = effect.getValue(); + if (!val || val.getDefiningOp() != memEffectOp.getOperation()) + continue; + + auto memrefType = dyn_cast(val.getType()); + if (!memrefType) + continue; + + Operation *deallocOp = + findDeallocInSameBlock(val, memEffectOp->getBlock()); + if (!deallocOp) + continue; + + auto allocIt = opIndex.find(memEffectOp.getOperation()); + auto deallocIt = opIndex.find(deallocOp); + if (allocIt == opIndex.end() || deallocIt == opIndex.end()) + continue; + + int64_t sizeBytes = getMemRefSizeInBytes(memrefType); + intervals.push_back( + {val, allocIt->second, deallocIt->second, sizeBytes}); + } + }); + + // Compute statistics. + int64_t totalBytes = 0; + for (const auto &interval : intervals) + totalBytes += interval.sizeInBytes; + + // Compute peak live bytes by sweeping through all time points. + int64_t peakLiveBytes = 0; + if (!intervals.empty()) { + // Collect all unique time points. + SmallVector timePoints; + for (const auto &interval : intervals) { + timePoints.push_back(interval.allocIndex); + timePoints.push_back(interval.deallocIndex); + } + llvm::sort(timePoints); + timePoints.erase(llvm::unique(timePoints), timePoints.end()); + + for (unsigned t : timePoints) { + int64_t liveBytes = 0; + for (const auto &interval : intervals) { + if (interval.allocIndex <= t && t < interval.deallocIndex) + liveBytes += interval.sizeInBytes; + } + peakLiveBytes = std::max(peakLiveBytes, liveBytes); + } + } + + // Count non-overlapping pairs (reuse opportunities). + unsigned nonOverlappingPairs = 0; + for (unsigned i = 0; i < intervals.size(); ++i) + for (unsigned j = i + 1; j < intervals.size(); ++j) + if (areNonOverlapping(intervals[i], intervals[j])) + ++nonOverlappingPairs; + + llvm::outs() << "--- Buffer Lifetime Statistics for '" << func.getSymName() + << "' ---\n"; + llvm::outs() << " Tracked allocations : " << intervals.size() << "\n"; + llvm::outs() << " Total allocated bytes : " << totalBytes << "\n"; + llvm::outs() << " Peak live bytes : " << peakLiveBytes << "\n"; + llvm::outs() << " Non-overlapping pairs : " << nonOverlappingPairs + << "\n"; + + for (const auto &interval : intervals) { + llvm::outs() << " Buffer: " << interval.allocResult.getType() + << " | size=" << interval.sizeInBytes << " | lifetime=[" + << interval.allocIndex << ", " << interval.deallocIndex + << ")\n"; + } + llvm::outs() << "---\n"; + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt index 7c38621be1bb5..27c2bf564f3dd 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms Bufferize.cpp BufferDeallocationSimplification.cpp + BufferLifetimeStats.cpp BufferOptimizations.cpp BufferResultsToOutParams.cpp BufferUtils.cpp diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir new file mode 100644 index 0000000000000..da4d44bcac2b8 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt %s --print-buffer-lifetime-stats --split-input-file 2>&1 | FileCheck %s + +// CHECK-LABEL: --- Buffer Lifetime Statistics for 'sequential_non_overlapping' --- +// CHECK: Tracked allocations : 2 +// CHECK: Total allocated bytes : 8192 +// CHECK: Peak live bytes : 4096 +// CHECK: Non-overlapping pairs : 1 +// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3) +// CHECK: Buffer: memref<512xf64> | size=4096 | lifetime=[5, 7) +// CHECK: --- + +func.func @sequential_non_overlapping(%arg0: memref<1024xf32>, + %arg1: memref<512xf64>) { + %cst = arith.constant 0.0 : f32 + %a = memref.alloc() : memref<1024xf32> // 1024 * 4 = 4096 bytes + linalg.fill ins(%cst : f32) outs(%a : memref<1024xf32>) + memref.dealloc %a : memref<1024xf32> + + %cst2 = arith.constant 0.0 : f64 + %b = memref.alloc() : memref<512xf64> // 512 * 8 = 4096 bytes + linalg.fill ins(%cst2 : f64) outs(%b : memref<512xf64>) + memref.dealloc %b : memref<512xf64> + return +} + +// ----- + +// CHECK-LABEL: --- Buffer Lifetime Statistics for 'overlapping_lifetimes' --- +// CHECK: Tracked allocations : 2 +// CHECK: Total allocated bytes : 6144 +// CHECK: Peak live bytes : 6144 +// CHECK: Non-overlapping pairs : 0 +// CHECK: Buffer: memref<512xf32> | size=2048 | lifetime=[0, 4) +// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3) +// CHECK: --- + +func.func @overlapping_lifetimes() { + %a = memref.alloc() : memref<512xf32> // 2048 bytes + %b = memref.alloc() : memref<1024xf32> // 4096 bytes + %cst = arith.constant 0.0 : f32 + memref.dealloc %b : memref<1024xf32> + memref.dealloc %a : memref<512xf32> + return +} + +// ----- + +// CHECK-LABEL: --- Buffer Lifetime Statistics for 'three_buffers_mixed' --- +// CHECK: Tracked allocations : 3 +// CHECK: Total allocated bytes : 10240 +// CHECK: Peak live bytes : 8192 +// CHECK: Non-overlapping pairs : 1 +// CHECK: Buffer: memref<512xf32> | size=2048 | lifetime=[0, 2) +// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 5) +// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[3, 6) +// CHECK: --- + +// %a and %b overlap (a=[0,2), b=[1,5)) +// %a and %c don't overlap (a=[0,2), c=[3,6)) +// %b and %c overlap (b=[1,5), c=[3,6)) +// So 1 non-overlapping pair: (%a, %c) +func.func @three_buffers_mixed() { + %a = memref.alloc() : memref<512xf32> // 2048 bytes + %b = memref.alloc() : memref<1024xf32> // 4096 bytes + memref.dealloc %a : memref<512xf32> + %c = memref.alloc() : memref<1024xf32> // 4096 bytes + %cst = arith.constant 0.0 : f32 + memref.dealloc %b : memref<1024xf32> + memref.dealloc %c : memref<1024xf32> + return +} + +// ----- + +// CHECK-LABEL: --- Buffer Lifetime Statistics for 'single_alloc' --- +// CHECK: Tracked allocations : 1 +// CHECK: Total allocated bytes : 256 +// CHECK: Peak live bytes : 256 +// CHECK: Non-overlapping pairs : 0 +// CHECK: Buffer: memref<64xf32> | size=256 | lifetime=[0, 1) +// CHECK: --- + +func.func @single_alloc() { + %a = memref.alloc() : memref<64xf32> // 64 * 4 = 256 bytes + memref.dealloc %a : memref<64xf32> + return +} + +// ----- + +// CHECK-LABEL: --- Buffer Lifetime Statistics for 'no_allocs' --- +// CHECK: Tracked allocations : 0 +// CHECK: Total allocated bytes : 0 +// CHECK: Peak live bytes : 0 +// CHECK: Non-overlapping pairs : 0 +// CHECK: --- + +func.func @no_allocs(%arg0: memref<1024xf32>) { + %cst = arith.constant 0.0 : f32 + linalg.fill ins(%cst : f32) outs(%arg0 : memref<1024xf32>) + return +}