Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,35 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return success();
}

template <typename OpAdaptor>
static FailureOr<SmallVector<Value>>
extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) {
// At present we only support linear "tiling" as specified in Vulkan, this
// means that texels are assumed to be laid out in memory in a row-major
// order. This allows us to support any memref layout that is a permutation of
// the dimensions. Future work will pass an optional image layout to the
// rewrite pattern so that we can support optimized target specific tilings.
SmallVector<Value> indices = adaptor.getIndices();
AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
if (!map.isPermutation())
return rewriter.notifyMatchFailure(
loadOp,
"Cannot lower memrefs with memory layout which is not a permutation");

// The memrefs layout determines the dimension ordering so we need to follow
// the map to get the ordering of the dimensions/indices.
const unsigned dimCount = map.getNumDims();
SmallVector<Value, 3> coords(dimCount);
for (unsigned dim = 0; dim < dimCount; ++dim)
coords[map.getDimPosition(dim)] = indices[dim];

// We need to reverse the coordinates because the memref layout is slowest to
// fastest moving and the vector coordinates for the image op is fastest to
// slowest moving.
return llvm::to_vector(llvm::reverse(coords));
}

LogicalResult
ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -755,13 +784,17 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,

// Build a vector of coordinates or just a scalar index if we have a 1D image.
Value coords;
if (memrefType.getRank() != 1) {
if (memrefType.getRank() == 1) {
coords = adaptor.getIndices()[0];
} else {
FailureOr<SmallVector<Value>> maybeCoords =
extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
if (failed(maybeCoords))
return failure();
auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
adaptor.getIndices().getType()[0]);
coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
adaptor.getIndices());
} else {
coords = adaptor.getIndices()[0];
maybeCoords.value());
}

// Fetch the value out of the image.
Expand Down
Loading