diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index c2cd2333f531bd..1c9ac5e84754a4 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1411,22 +1411,24 @@ static LogicalResult generateCopy( auto numElementsSSA = top.create(loc, numElements.getValue()); - SmallVector strideInfos; - getMultiLevelStrides(region, fastBufferShape, &strideInfos); - - // TODO(bondhugula): use all stride levels once DmaStartOp is extended for - // multi-level strides. - if (strideInfos.size() > 1) { - LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n"); - return failure(); - } + Value dmaStride = nullptr; + Value numEltPerDmaStride = nullptr; + if (copyOptions.generateDma) { + SmallVector dmaStrideInfos; + getMultiLevelStrides(region, fastBufferShape, &dmaStrideInfos); + + // TODO(bondhugula): use all stride levels once DmaStartOp is extended for + // multi-level strides. + if (dmaStrideInfos.size() > 1) { + LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n"); + return failure(); + } - Value stride = nullptr; - Value numEltPerStride = nullptr; - if (!strideInfos.empty()) { - stride = top.create(loc, strideInfos[0].stride); - numEltPerStride = - top.create(loc, strideInfos[0].numEltPerStride); + if (!dmaStrideInfos.empty()) { + dmaStride = top.create(loc, dmaStrideInfos[0].stride); + numEltPerDmaStride = + top.create(loc, dmaStrideInfos[0].numEltPerStride); + } } // Record the last operation where we want the memref replacement to end. We @@ -1469,13 +1471,13 @@ static LogicalResult generateCopy( b.create(loc, memref, memAffineMap, memIndices, fastMemRef, bufAffineMap, bufIndices, tagMemRef, tagAffineMap, tagIndices, - numElementsSSA, stride, numEltPerStride); + numElementsSSA, dmaStride, numEltPerDmaStride); } else { // DMA non-blocking write from fast buffer to the original memref. auto op = b.create( loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap, memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA, - stride, numEltPerStride); + dmaStride, numEltPerDmaStride); // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the // end to mark end of block range being processed. if (isCopyOutAtEndOfBlock)