[mlir][Vector] Introduce 'vector.load' and 'vector.store' ops
This patch adds the 'vector.load' and 'vector.store' ops to the Vector
dialect [1]. These operations model *contiguous* vector loads and stores
from/to memory. Their semantics are similar to the 'affine.vector_load' and
'affine.vector_store' counterparts but without the affine constraints. The
most relevant feature is that these new vector operations may perform a vector
load/store on memrefs with a non-vector element type, unlike 'std.load' and
'std.store' ops. This opens the representation to model more generic vector
load/store scenarios: unaligned vector loads/stores, mixed scalar and vector
memory accesses on the same memref, decoupling of memory allocation constraints
from memory accesses, etc. [1]. These operations will also facilitate the progressive
lowering of both Affine vector loads/stores and Vector transfer reads/writes
for those that read/write contiguous slices from/to memory.
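
As a small illustration of that flexibility, a scalar and a vector access can
now target the same plain scalar memref (the memref '%A' and the indices below
are hypothetical, not taken from this patch):

```mlir
// Scalar access with std.load.
%s = load %A[%i, %j] : memref<100x100xf32>
// Contiguous 8-element vector access on the same memref, starting at an
// arbitrary (possibly unaligned) column %j.
%v = vector.load %A[%i, %j] : memref<100x100xf32>, vector<8xf32>
```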

In particular, this patch adds the 'vector.load' and 'vector.store' ops to the
Vector dialect, implements their lowering to the LLVM dialect, and changes the
lowering of 'affine.vector_load' and 'affine.vector_store' ops to the new vector
ops. The lowering of Vector transfer reads/writes will be implemented in the
future, probably as an independent pass. The API of 'vector.maskedload' and
'vector.maskedstore' has also been changed slightly to align it with the
transfer read/write ops and the new vector ops. This will improve reusability
among all these operations. For example, the lowering of 'vector.load',
'vector.store', 'vector.maskedload' and 'vector.maskedstore' to the LLVM dialect
is implemented with a single template conversion pattern.

[1] https://llvm.discourse.group/t/memref-type-and-data-layout/

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96185
dcaballe committed Feb 12, 2021
1 parent 98754e2 commit ee66e43
Showing 8 changed files with 409 additions and 111 deletions.
163 changes: 157 additions & 6 deletions mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1320,6 +1320,156 @@ def Vector_TransferWriteOp :
let hasFolder = 1;
}

def Vector_LoadOp : Vector_Op<"load"> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
The 'vector.load' operation reads an n-D slice of memory into an n-D
vector. It takes a 'base' memref, an index for each memref dimension and a
result vector type as arguments. It returns a value of the result vector
type. The 'base' memref and indices determine the start memory address from
which to read. Each index provides an offset for each memref dimension
based on the element type of the memref. The shape of the result vector
type determines the shape of the slice read from the start memory address.
The elements along each dimension of the slice are strided by the memref
strides. Only memrefs with default strides are allowed. These constraints
guarantee that elements read along the first dimension of the slice are
contiguous in memory.

The memref element type can be a scalar or a vector type. If the memref
element type is a scalar, it should match the element type of the result
vector. If the memref element type is a vector, it should match the result
vector type.

Example 1: 1-D vector load on a scalar memref.
```mlir
%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
```

Example 2: 1-D vector load on a vector memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
```

Example 3: 2-D vector load on a scalar memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
```

Example 4: 2-D vector load on a vector memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
```

Representation-wise, the 'vector.load' operation permits out-of-bounds
reads. Support and implementation of out-of-bounds vector loads are
target-specific. No assumptions should be made on the value of elements
loaded out of bounds. Not all targets may support out-of-bounds vector
loads.

Example 5: Potential out-of-bounds vector load.
```mlir
%result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
```

Example 6: Explicit out-of-bounds vector load.
```mlir
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
```
}];

let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs AnyVector:$result);

let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}

VectorType getVectorType() {
return result().getType().cast<VectorType>();
}
}];

let assemblyFormat =
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
}

def Vector_StoreOp : Vector_Op<"store"> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
The 'vector.store' operation writes an n-D vector to an n-D slice of memory.
It takes the vector value to be stored, a 'base' memref and an index for
each memref dimension. The 'base' memref and indices determine the start
memory address at which to write. Each index provides an offset for each
memref dimension based on the element type of the memref. The shape of the
vector value to store determines the shape of the slice written starting at
the start memory address. The elements along each dimension of the slice are
strided by the memref strides. Only memrefs with default strides are allowed.
These constraints guarantee that elements written along the first dimension
of the slice are contiguous in memory.

The memref element type can be a scalar or a vector type. If the memref
element type is a scalar, it should match the element type of the value
to store. If the memref element type is a vector, it should match the type
of the value to store.

Example 1: 1-D vector store on a scalar memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
```

Example 2: 1-D vector store on a vector memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
```

Example 3: 2-D vector store on a scalar memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
```

Example 4: 2-D vector store on a vector memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
```

Representation-wise, the 'vector.store' operation permits out-of-bounds
writes. Support and implementation of out-of-bounds vector stores are
target-specific. No assumptions should be made on the memory written out of
bounds. Not all targets may support out-of-bounds vector stores.

Example 5: Potential out-of-bounds vector store.
```mlir
vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
```

Example 6: Explicit out-of-bounds vector store.
```mlir
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
```
}];

let arguments = (ins AnyVector:$valueToStore,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$base,
Variadic<Index>:$indices);

let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}

VectorType getVectorType() {
return valueToStore().getType().cast<VectorType>();
}
}];

let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
"`:` type($base) `,` type($valueToStore)";
}

def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
@@ -1363,7 +1513,7 @@ def Vector_MaskedLoadOp :
VectorType getPassThruVectorType() {
return pass_thru().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
VectorType getVectorType() {
return result().getType().cast<VectorType>();
}
}];
@@ -1377,7 +1527,7 @@ def Vector_MaskedStoreOp :
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
VectorOfRank<[1]>:$valueToStore)> {

let summary = "stores elements from a vector into memory as defined by a mask vector";

@@ -1411,12 +1561,13 @@ def Vector_MaskedStoreOp :
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
VectorType getValueVectorType() {
return value().getType().cast<VectorType>();
VectorType getVectorType() {
return valueToStore().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
"type($base) `,` type($mask) `,` type($value)";
let assemblyFormat =
"$base `[` $indices `]` `,` $mask `,` $valueToStore "
"attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
}

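For reference, a masked store in the updated syntax (with the operand renamed
to 'valueToStore') prints as follows; the dynamic memref and the 16-element
shapes are illustrative only:

```mlir
vector.maskedstore %base[%i], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
```
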
23 changes: 12 additions & 11 deletions mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -578,8 +578,9 @@ class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
if (!resultOperands)
return failure();

// Build std.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
// Build std.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, op.getMemRef(),
*resultOperands);
return success();
}
};
@@ -625,8 +626,8 @@ class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
return failure();

// Build std.store valueToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
op.getMemRef(), *maybeExpandedMap);
rewriter.replaceOpWithNewOp<mlir::StoreOp>(
op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
return success();
}
};
@@ -695,8 +696,8 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
};

/// Apply the affine map from an 'affine.vector_load' operation to its operands,
/// and feed the results to a newly created 'vector.transfer_read' operation
/// (which replaces the original 'affine.vector_load').
/// and feed the results to a newly created 'vector.load' operation (which
/// replaces the original 'affine.vector_load').
class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
public:
using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
@@ -710,16 +711,16 @@ class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
if (!resultOperands)
return failure();

// Build vector.transfer_read memref[expandedMap.results].
rewriter.replaceOpWithNewOp<TransferReadOp>(
// Build vector.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getVectorType(), op.getMemRef(), *resultOperands);
return success();
}
};

/// Apply the affine map from an 'affine.vector_store' operation to its
/// operands, and feed the results to a newly created 'vector.transfer_write'
/// operation (which replaces the original 'affine.vector_store').
/// operands, and feed the results to a newly created 'vector.store' operation
/// (which replaces the original 'affine.vector_store').
class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
public:
using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
@@ -733,7 +734,7 @@ class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
if (!maybeExpandedMap)
return failure();

rewriter.replaceOpWithNewOp<TransferWriteOp>(
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
return success();
}
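
At the MLIR level, the effect of the two vector lowering patterns above is
roughly the following (memref, value names, and indices are hypothetical):

```mlir
// Input affine ops:
%v0 = affine.vector_load %buf[%i, %j] : memref<200x100xf32>, vector<8xf32>
affine.vector_store %v1, %buf[%i, %j] : memref<200x100xf32>, vector<8xf32>

// After lowering, contiguous vector loads/stores are emitted where
// 'vector.transfer_read'/'vector.transfer_write' used to be created:
%v0 = vector.load %buf[%i, %j] : memref<200x100xf32>, vector<8xf32>
vector.store %v1, %buf[%i, %j] : memref<200x100xf32>, vector<8xf32>
```
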
102 changes: 58 additions & 44 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -357,64 +357,72 @@ class VectorFlatTransposeOpConversion
}
};

/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion
: public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = load->getLoc();
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
MemRefType memRefType = load.getMemRefType();
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload or vector.maskedstore with its respective LLVM
/// counterpart.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
vector::LoadOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
vector::MaskedLoadOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

// Resolve address.
auto vtype = typeConverter->convertType(load.getResultVectorType());
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
vector::StoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
ptr, align);
}

rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
rewriter.getI32IntegerAttr(align));
return success();
}
};
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
vector::MaskedStoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion
: public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = store->getLoc();
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
MemRefType memRefType = store.getMemRefType();
// Only 1-D vectors can be lowered to LLVM.
VectorType vectorTy = loadOrStoreOp.getVectorType();
if (vectorTy.getRank() > 1)
return failure();

auto loc = loadOrStoreOp->getLoc();
auto adaptor = LoadOrStoreOpAdaptor(operands);
MemRefType memRefTy = loadOrStoreOp.getMemRefType();

// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
return failure();

// Resolve address.
auto vtype = typeConverter->convertType(store.getValueVectorType());
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
.template cast<VectorType>();
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
store, adaptor.value(), ptr, adaptor.mask(),
rewriter.getI32IntegerAttr(align));
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
return success();
}
};
@@ -1511,8 +1519,14 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion,
VectorPrintOpConversion,
VectorTypeCastOpConversion,
VectorMaskedLoadOpConversion,
VectorMaskedStoreOpConversion,
VectorLoadStoreConversion<vector::LoadOp,
vector::LoadOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedLoadOp,
vector::MaskedLoadOpAdaptor>,
VectorLoadStoreConversion<vector::StoreOp,
vector::StoreOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion,
VectorScatterOpConversion,
VectorExpandLoadOpConversion,
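
For a 1-D 'vector.load' on a scalar memref with the default layout, the shared
pattern sketched above resolves the strided element address, bitcasts it to a
vector pointer, and emits an aligned LLVM load. Roughly (the exact constants,
index arithmetic, and alignment are illustrative and depend on the memref type
and data layout):

```mlir
// vector.load %A[%i, %j] : memref<200x100xf32>, vector<8xf32>
// converts to something like:
%c100 = llvm.mlir.constant(100 : index) : i64
%row  = llvm.mul %i, %c100 : i64
%off  = llvm.add %row, %j : i64
%gep  = llvm.getelementptr %base[%off] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
%vptr = llvm.bitcast %gep : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
%v    = llvm.load %vptr {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
```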
