Skip to content

Commit

Permalink
[MLIR][SPIRV] Support two memory access attributes in OpCopyMemory.
Browse files Browse the repository at this point in the history
This commit augments spv.CopyMemory's implementation to support 2 memory
access operands. Hence, more closely following the spec. The following
changes are introduces:

- Customize logic for spv.CopyMemory serialization and deserialization.
- Add 2 additional attributes for source memory access operand.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D82710
  • Loading branch information
ergawy authored and antiagainst committed Jul 2, 2020
1 parent 8119a37 commit ef2f46e
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 30 deletions.
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
Expand Up @@ -198,7 +198,7 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
```
copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use
storage-class ssa-use
(`[` memory-access `]`)?
(`[` memory-access `]` (`, [` memory-access `]`)?)?
` : ` spirv-element-type
```

Expand All @@ -215,12 +215,16 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
SPV_AnyPtr:$target,
SPV_AnyPtr:$source,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
OptionalAttr<I32Attr>:$alignment,
OptionalAttr<SPV_MemoryAccessAttr>:$source_memory_access,
OptionalAttr<I32Attr>:$source_alignment
);

let results = (outs);

let verifier = [{ return verifyCopyMemory(*this); }];

let autogenSerialization = 0;
}

// -----
Expand Down
96 changes: 75 additions & 21 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -28,7 +28,11 @@
using namespace mlir;

// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kMemoryAccessAttrName[] = "memory_access";
static constexpr const char kSourceMemoryAccessAttrName[] =
"source_memory_access";
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kClusterSize[] = "cluster_size";
Expand Down Expand Up @@ -157,6 +161,8 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
return success();
}

template <const char memoryAccessAttrName[] = kMemoryAccessAttrName,
const char alignmentAttrName[] = kAlignmentAttrName>
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
OperationState &state) {
// Parse an optional list of attributes staring with '['
Expand All @@ -166,7 +172,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
}

spirv::MemoryAccess memoryAccessAttr;
if (parseEnumStrAttr(memoryAccessAttr, parser, state)) {
if (parseEnumStrAttr(memoryAccessAttr, parser, state, memoryAccessAttrName)) {
return failure();
}

Expand All @@ -175,27 +181,41 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
Attribute alignmentAttr;
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
state.attributes)) {
return failure();
}
}
return parser.parseRSquare();
}

template <typename MemoryOpTy>
static void
printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
SmallVectorImpl<StringRef> &elidedAttrs) {
template <typename MemoryOpTy,
const char memoryAccessAttrName[] = kMemoryAccessAttrName,
const char alignmentAttrName[] = kAlignmentAttrName,
bool first = true>
static void printMemoryAccessAttribute(
MemoryOpTy memoryOp, OpAsmPrinter &printer,
SmallVectorImpl<StringRef> &elidedAttrs,
Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
Optional<llvm::APInt> alignmentAttrValue = None) {
// Print optional memory access attribute.
if (auto memAccess = memoryOp.memory_access()) {
elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.memory_access())) {
elidedAttrs.push_back(memoryAccessAttrName);

if (!first) {
printer << ", ";
}

printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";

// Print integer alignment attribute.
if (auto alignment = memoryOp.alignment()) {
elidedAttrs.push_back(kAlignmentAttrName);
printer << ", " << alignment;
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.alignment())) {
elidedAttrs.push_back(alignmentAttrName);
printer << ", " << alignment;
}
}
printer << "]";
}
Expand Down Expand Up @@ -243,17 +263,19 @@ static LogicalResult verifyCastOp(Operation *op,
return success();
}

template <typename MemoryOpTy>
template <typename MemoryOpTy,
const char memoryAccessAttrName[] = kMemoryAccessAttrName,
const char alignmentAttrName[] = kAlignmentAttrName>
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
// ODS checks for attributes values. Just need to verify that if the
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
auto memAccessAttr = op->getAttr(memoryAccessAttrName);
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
if (op->getAttr(kAlignmentAttrName)) {
if (op->getAttr(alignmentAttrName)) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
Expand All @@ -270,11 +292,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
}

if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kAlignmentAttrName)) {
if (!op->getAttr(alignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
if (op->getAttr(kAlignmentAttrName)) {
if (op->getAttr(alignmentAttrName)) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
Expand Down Expand Up @@ -2839,6 +2861,10 @@ static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {

SmallVector<StringRef, 4> elidedAttrs;
printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
printMemoryAccessAttribute<decltype(copyMemory), kSourceMemoryAccessAttrName,
kSourceAlignmentAttrName, false>(
copyMemory, printer, elidedAttrs, copyMemory.source_memory_access(),
copyMemory.source_alignment());

printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

Expand All @@ -2861,9 +2887,23 @@ static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
parseEnumStrAttr(sourceStorageClass, parser) ||
parser.parseOperand(sourcePtrInfo) ||
parseMemoryAccessAttributes(parser, state) ||
parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
parser.parseType(elementType)) {
parseMemoryAccessAttributes(parser, state)) {
return failure();
}

if (!parser.parseOptionalComma()) {
// Parse 2nd memory access attributes.
if (parseMemoryAccessAttributes<kSourceMemoryAccessAttrName,
kSourceAlignmentAttrName>(parser, state)) {
return failure();
}
}

if (parser.parseColon() || parser.parseType(elementType)) {
return failure();
}

if (parser.parseOptionalAttrDict(state.attributes)) {
return failure();
}

Expand All @@ -2890,7 +2930,21 @@ static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
"both operands must be pointers to the same type");
}

return verifyMemoryAccessAttribute(copyMemory);
if (failed(verifyMemoryAccessAttribute(copyMemory))) {
return failure();
}

// TODO (ergawy): According to the spec:
//
// If two masks are present, the first applies to Target and cannot include
// MakePointerVisible, and the second applies to Source and cannot include
// MakePointerAvailable.
//
// Add such verification here.

return verifyMemoryAccessAttribute<decltype(copyMemory),
kSourceMemoryAccessAttrName,
kSourceAlignmentAttrName>(copyMemory);
}

//===----------------------------------------------------------------------===//
Expand Down
77 changes: 74 additions & 3 deletions mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
Expand Up @@ -400,7 +400,8 @@ class Deserializer {
/// Method to deserialize an operation in the SPIR-V dialect that is a mirror
/// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
/// == 1 and autogenSerialization == 1 in ODS.
template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
template <typename OpTy>
LogicalResult processOp(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "unsupported deserialization for ")
<< OpTy::getOperationName() << " op";
}
Expand Down Expand Up @@ -1566,8 +1567,8 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
return success();
}

return emitError(unknownLoc, "unsupported OpConstantNull type: ")
<< resultType;
return emitError(unknownLoc, "unsupported OpConstantNull type: ")
<< resultType;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2509,6 +2510,76 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
return success();
}

template <>
LogicalResult
Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
SmallVector<Type, 1> resultTypes;
size_t wordIndex = 0;
SmallVector<Value, 4> operands;
SmallVector<NamedAttribute, 4> attributes;

if (wordIndex < words.size()) {
auto arg = getValue(words[wordIndex]);

if (!arg) {
return emitError(unknownLoc, "unknown result <id> : ")
<< words[wordIndex];
}

operands.push_back(arg);
wordIndex++;
}

if (wordIndex < words.size()) {
auto arg = getValue(words[wordIndex]);

if (!arg) {
return emitError(unknownLoc, "unknown result <id> : ")
<< words[wordIndex];
}

operands.push_back(arg);
wordIndex++;
}

bool isAlignedAttr = false;

if (wordIndex < words.size()) {
auto attrValue = words[wordIndex++];
attributes.push_back(opBuilder.getNamedAttr(
"memory_access", opBuilder.getI32IntegerAttr(attrValue)));
isAlignedAttr = (attrValue == 2);
}

if (isAlignedAttr && wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr(
"alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
}

if (wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access",
opBuilder.getI32IntegerAttr(words[wordIndex++])));
}

if (wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr(
"source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
}

if (wordIndex != words.size()) {
return emitError(unknownLoc,
"found more operands than expected when deserializing "
"spirv::CopyMemoryOp, only ")
<< wordIndex << " of " << words.size() << " processed";
}

Location loc = createFileLineColLoc(opBuilder);
opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);

return success();
}

// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
// various Deserializer::processOp<...>() specializations.
#define GET_DESERIALIZATION_FNS
Expand Down
48 changes: 47 additions & 1 deletion mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
Expand Up @@ -364,7 +364,8 @@ class Serializer {
/// Method to serialize an operation in the SPIR-V dialect that is a mirror of
/// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
/// 1 and autogenSerialization == 1 in ODS.
template <typename OpTy> LogicalResult processOp(OpTy op) {
template <typename OpTy>
LogicalResult processOp(OpTy op) {
return op.emitError("unsupported op serialization");
}

Expand Down Expand Up @@ -1904,6 +1905,51 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
operands);
}

template <>
LogicalResult
Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
SmallVector<uint32_t, 4> operands;
SmallVector<StringRef, 2> elidedAttrs;

for (Value operand : op.getOperation()->getOperands()) {
auto id = getValueID(operand);
assert(id && "use before def!");
operands.push_back(id);
}

if (auto attr = op.getAttr("memory_access")) {
operands.push_back(static_cast<uint32_t>(
attr.cast<IntegerAttr>().getValue().getZExtValue()));
}

elidedAttrs.push_back("memory_access");

if (auto attr = op.getAttr("alignment")) {
operands.push_back(static_cast<uint32_t>(
attr.cast<IntegerAttr>().getValue().getZExtValue()));
}

elidedAttrs.push_back("alignment");

if (auto attr = op.getAttr("source_memory_access")) {
operands.push_back(static_cast<uint32_t>(
attr.cast<IntegerAttr>().getValue().getZExtValue()));
}

elidedAttrs.push_back("source_memory_access");

if (auto attr = op.getAttr("source_alignment")) {
operands.push_back(static_cast<uint32_t>(
attr.cast<IntegerAttr>().getValue().getZExtValue()));
}

elidedAttrs.push_back("source_alignment");
emitDebugLine(functionBody, op.getLoc());
encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);

return success();
}

// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
// various Serializer::processOp<...>() specializations.
#define GET_SERIALIZATION_FNS
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
Expand Up @@ -93,6 +93,18 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32

// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Volatile"] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Volatile"] : f32

// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4], ["Volatile"] : f32

// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Aligned", 4] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Aligned", 4] : f32

// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 8], ["Aligned", 4] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 8], ["Aligned", 4] : f32

spv.Return
}
}
Expand Down

0 comments on commit ef2f46e

Please sign in to comment.