Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1830,14 +1830,14 @@ static void getMultiLevelStrides(const MemRefRegion &region,
}
}

/// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
/// returns the outermost AffineForOp of the copy loop nest. `lbMaps` and
/// `ubMaps` along with `lbOperands` and `ubOperands` hold the lower and upper
/// bound information for the copy loop nest. `fastBufOffsets` contain the
/// expressions to be subtracted out from the respective copy loop iterators in
/// order to index the fast buffer. If `copyOut' is true, generates a copy-out;
/// otherwise a copy-in. Builder `b` should be set to the point the copy nest is
/// inserted.
/// Generates a point-wise copy from/to a non-zero ranked `memref' to/from
/// `fastMemRef' and returns the outermost AffineForOp of the copy loop nest.
/// `lbMaps` and `ubMaps` along with `lbOperands` and `ubOperands` hold the
/// lower and upper bound information for the copy loop nest. `fastBufOffsets`
/// contain the expressions to be subtracted out from the respective copy loop
/// iterators in order to index the fast buffer. If `copyOut' is true, generates
/// a copy-out; otherwise a copy-in. Builder `b` should be set to the point the
/// copy nest is inserted.
///
/// The copy-in nest is generated as follows as an example for a 2-d region:
/// for x = ...
Expand All @@ -1858,6 +1858,8 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
}));

unsigned rank = cast<MemRefType>(memref.getType()).getRank();
// A copy nest can't be generated for 0-ranked memrefs.
assert(rank != 0 && "non-zero rank memref expected");
assert(lbMaps.size() == rank && "wrong number of lb maps");
assert(ubMaps.size() == rank && "wrong number of ub maps");

Expand Down Expand Up @@ -1921,19 +1923,20 @@ emitRemarkForBlock(Block &block) {
return block.getParentOp()->emitRemark();
}

/// Creates a buffer in the faster memory space for the specified memref region;
/// generates a copy from the lower memory space to this one, and replaces all
/// loads/stores in the block range [`begin', `end') of `block' to load/store
/// from that buffer. Returns failure if copies could not be generated due to
/// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
/// in copyPlacementBlock specify the insertion points where the incoming copies
/// and outgoing copies, respectively, should be inserted (the insertion happens
/// right before the insertion point). Since `begin` can itself be invalidated
/// due to the memref rewriting done from this method, the output argument
/// `nBegin` is set to its replacement (set to `begin` if no invalidation
/// happens). Since outgoing copies could have been inserted at `end`, the
/// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
/// size of the fast buffer allocated.
/// Creates a buffer in the faster memory space for the specified memref region
/// (memref has to be non-zero ranked); generates a copy from the lower memory
/// space to this one, and replaces all loads/stores in the block range
/// [`begin', `end') of `block' to load/store from that buffer. Returns failure
/// if copies could not be generated due to yet unimplemented cases.
/// `copyInPlacementStart` and `copyOutPlacementStart` in copyPlacementBlock
/// specify the insertion points where the incoming copies and outgoing copies,
/// respectively, should be inserted (the insertion happens right before the
/// insertion point). Since `begin` can itself be invalidated due to the memref
/// rewriting done from this method, the output argument `nBegin` is set to its
/// replacement (set to `begin` if no invalidation happens). Since outgoing
/// copies could have been inserted at `end`, the output argument `nEnd` is set
/// to the new end. `sizeInBytes` is set to the size of the fast buffer
/// allocated.
static LogicalResult generateCopy(
const MemRefRegion &region, Block *block, Block::iterator begin,
Block::iterator end, Block *copyPlacementBlock,
Expand Down Expand Up @@ -1984,6 +1987,11 @@ static LogicalResult generateCopy(
SmallVector<Value, 4> bufIndices;

unsigned rank = memRefType.getRank();
if (rank == 0) {
LLVM_DEBUG(llvm::dbgs() << "Zero-ranked memrefs are not supported\n");
return failure();
}

SmallVector<int64_t, 4> fastBufferShape;

// Compute the extents of the buffer.
Expand Down
65 changes: 65 additions & 0 deletions mlir/test/Dialect/Affine/affine-data-copy.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,68 @@ func.func @arbitrary_memory_space() {
}
return
}

// Test that the affine data-copy pass handles a loop nest that mixes a
// ranked memref (%3) with accesses to a 0-ranked memref (%5, memref<i1>).
// No CHECK lines beyond the label: the test's purpose appears to be that the
// pass completes without crashing on the zero-ranked accesses — TODO confirm.
// CHECK-LABEL: zero_ranked
func.func @zero_ranked(%3:memref<480xi1>) {
%false = arith.constant false
// Zero-ranked (scalar) buffers, stored to / loaded from with empty index
// lists: %4[] and %5[].
%4 = memref.alloc() {alignment = 128 : i64} : memref<i1>
affine.store %false, %4[] : memref<i1>
%5 = memref.alloc() {alignment = 128 : i64} : memref<i1>
memref.copy %4, %5 : memref<i1> to memref<i1>
// Reduction-style loop: each iteration reads the scalar %5, compares it
// with an element of %3, and stores the minimum back into %5.
affine.for %arg0 = 0 to 480 {
%11 = affine.load %3[%arg0] : memref<480xi1>
%12 = affine.load %5[] : memref<i1>
%13 = arith.cmpi slt, %11, %12 : i1
%14 = arith.select %13, %11, %12 : i1
affine.store %14, %5[] : memref<i1>
}
return
}

// Test that a function consisting only of a 0-ranked (scalar) memref
// allocation and store passes through the pass unchanged: the CHECK lines
// match the original three operations verbatim, i.e. no copy nest or fast
// buffer is generated for the zero-ranked memref.
// CHECK-LABEL: func @scalar_memref_copy_without_dma
func.func @scalar_memref_copy_without_dma() {
%false = arith.constant false
%4 = memref.alloc() {alignment = 128 : i64} : memref<i1>
affine.store %false, %4[] : memref<i1>

// CHECK: %[[FALSE:.*]] = arith.constant false
// CHECK: %[[MEMREF:.*]] = memref.alloc() {alignment = 128 : i64} : memref<i1>
// CHECK: affine.store %[[FALSE]], %[[MEMREF]][] : memref<i1>
return
}

// Test copy generation for a loop that mixes a ranked memref with a
// 0-ranked one. Per the CHECK lines: a fast buffer (%[[FAST_MEMREF]],
// memref<480xi1>) and a copy-in loop are generated for the ranked operand
// %arg0, while the zero-ranked memrefs (%[[MEMREF]], %[[TARGET]]) are left
// untouched — their loads/stores in the main loop still address the
// original memref<i1> buffers directly.
// CHECK-LABEL: func @scalar_memref_copy_in_loop
func.func @scalar_memref_copy_in_loop(%3:memref<480xi1>) {
%false = arith.constant false
%4 = memref.alloc() {alignment = 128 : i64} : memref<i1>
affine.store %false, %4[] : memref<i1>
%5 = memref.alloc() {alignment = 128 : i64} : memref<i1>
memref.copy %4, %5 : memref<i1> to memref<i1>
// Min-reduction into the scalar %5 over the 480 elements of %3.
affine.for %arg0 = 0 to 480 {
%11 = affine.load %3[%arg0] : memref<480xi1>
%12 = affine.load %5[] : memref<i1>
%13 = arith.cmpi slt, %11, %12 : i1
%14 = arith.select %13, %11, %12 : i1
affine.store %14, %5[] : memref<i1>
}

// CHECK: %[[FALSE:.*]] = arith.constant false
// CHECK: %[[MEMREF:.*]] = memref.alloc() {alignment = 128 : i64} : memref<i1>
// CHECK: affine.store %[[FALSE]], %[[MEMREF]][] : memref<i1>
// CHECK: %[[TARGET:.*]] = memref.alloc() {alignment = 128 : i64} : memref<i1>
// CHECK: memref.copy %alloc, %[[TARGET]] : memref<i1> to memref<i1>
// CHECK: %[[FAST_MEMREF:.*]] = memref.alloc() : memref<480xi1>
// CHECK: affine.for %{{.*}} = 0 to 480 {
// CHECK: %{{.*}} = affine.load %arg0[%{{.*}}] : memref<480xi1>
// CHECK: affine.store %{{.*}}, %[[FAST_MEMREF]][%{{.*}}] : memref<480xi1>
// CHECK: }
// CHECK: affine.for %arg1 = 0 to 480 {
// CHECK: %[[L0:.*]] = affine.load %[[FAST_MEMREF]][%arg1] : memref<480xi1>
// CHECK: %[[L1:.*]] = affine.load %[[TARGET]][] : memref<i1>
// CHECK: %[[CMPI:.*]] = arith.cmpi slt, %[[L0]], %[[L1]] : i1
// CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[L0]], %[[L1]] : i1
// CHECK: affine.store %[[SELECT]], %[[TARGET]][] : memref<i1>
// CHECK: }
// CHECK: memref.dealloc %[[FAST_MEMREF]] : memref<480xi1>
return
}