Skip to content

Commit

Permalink
[MLIR] Add async.value type to Async dialect
Browse files Browse the repository at this point in the history
Return values from async regions as !async.value<...>.

Reviewed By: mehdi_amini, csigg

Differential Revision: https://reviews.llvm.org/D88510
  • Loading branch information
ezhulenev committed Sep 30, 2020
1 parent c3193e4 commit 655af65
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 16 deletions.
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Async/IR/Async.h
Expand Up @@ -22,12 +22,28 @@
namespace mlir {
namespace async {

namespace detail {
struct ValueTypeStorage;
} // namespace detail

/// The token type to represent asynchronous operation completion.
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
public:
using Base::Base;
};

/// The value type to represent values returned from asynchronous operations.
class ValueType
: public Type::TypeBase<ValueType, Type, detail::ValueTypeStorage> {
public:
using Base::Base;

/// Get or create an async ValueType with the provided value type.
static ValueType get(Type valueType);

Type getValueType();
};

} // namespace async
} // namespace mlir

Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
Expand Up @@ -39,4 +39,24 @@ def Async_TokenType : DialectType<AsyncDialect,
}];
}

class Async_ValueType<Type type>
: DialectType<AsyncDialect,
And<[
CPred<"$_self.isa<::mlir::async::ValueType>()">,
SubstLeaves<"$_self",
"$_self.cast<::mlir::async::ValueType>().getValueType()",
type.predicate>
]>, "async value type with " # type.description # " underlying type"> {
let typeDescription = [{
`async.value` represents a value returned by asynchronous operations,
which may or may not be available currently, but will be available at some
point in the future.
}];

Type valueType = type;
}

def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
"async value type">;

#endif // ASYNC_BASE_TD
20 changes: 11 additions & 9 deletions mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
Expand Up @@ -40,24 +40,24 @@ def Async_ExecuteOp : Async_Op<"execute"> {
state). All dependencies must be made explicit with async execute arguments
(`async.token` or `async.value`).

Example:

```mlir
%0 = async.execute {
"compute0"(...)
async.yield
} : !async.token
%done, %values = async.execute {
%0 = "compute0"(...) : !some.type
async.yield %1 : f32
} : !async.token, !async.value<!some.type>

%1 = "compute1"(...)
%1 = "compute1"(...) : !some.type
```
}];

// TODO: Take async.tokens/async.values as arguments.
let arguments = (ins );
let results = (outs Async_TokenType:$done);
let results = (outs Async_TokenType:$done,
Variadic<Async_AnyValueType>:$values);
let regions = (region SizedRegion<1>:$body);

let assemblyFormat = "$body attr-dict `:` type($done)";
let printer = [{ return ::mlir::async::print(p, *this); }];
let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
}

def Async_YieldOp :
Expand All @@ -71,6 +71,8 @@ def Async_YieldOp :
let arguments = (ins Variadic<AnyType>:$operands);

let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";

let verifier = [{ return ::mlir::async::verify(*this); }];
}

#endif // ASYNC_OPS
120 changes: 117 additions & 3 deletions mlir/lib/Dialect/Async/IR/Async.cpp
Expand Up @@ -19,15 +19,16 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::async;
namespace mlir {
namespace async {

void AsyncDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
>();
addTypes<TokenType>();
addTypes<ValueType>();
}

/// Parse a type registered to this dialect.
Expand All @@ -39,16 +40,129 @@ Type AsyncDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "token")
return TokenType::get(getContext());

if (keyword == "value") {
Type ty;
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
return Type();
}
return ValueType::get(ty);
}

parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
return Type();
}

/// Print a type registered to this dialect.
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<TokenType>([&](Type) { os << "token"; })
.Case<TokenType>([&](TokenType) { os << "token"; })
.Case<ValueType>([&](ValueType valueTy) {
os << "value<";
os.printType(valueTy.getValueType());
os << '>';
})
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
}

//===----------------------------------------------------------------------===//
/// ValueType
//===----------------------------------------------------------------------===//

namespace detail {

// Storage for `async.value<T>` type, the only member is the wrapped type.
struct ValueTypeStorage : public TypeStorage {
ValueTypeStorage(Type valueType) : valueType(valueType) {}

/// The hash key used for uniquing.
using KeyTy = Type;
bool operator==(const KeyTy &key) const { return key == valueType; }

/// Construction.
static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
Type valueType) {
return new (allocator.allocate<ValueTypeStorage>())
ValueTypeStorage(valueType);
}

Type valueType;
};

} // namespace detail

ValueType ValueType::get(Type valueType) {
return Base::get(valueType.getContext(), valueType);
}

Type ValueType::getValueType() { return getImpl()->valueType; }

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(YieldOp op) {
// Get the underlying value types from async values returned from the
// parent `async.execute` operation.
auto executeOp = op.getParentOfType<ExecuteOp>();
auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
return result.getType().cast<ValueType>().getValueType();
});

if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin()))
return op.emitOpError("Operand types do not match the types returned from "
"the parent ExecuteOp");

return success();
}

//===----------------------------------------------------------------------===//
/// ExecuteOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, ExecuteOp op) {
p << "async.execute ";
p.printRegion(op.body());
p.printOptionalAttrDict(op.getAttrs());
p << " : ";
p.printType(op.done().getType());
if (!op.values().empty())
p << ", ";
llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
p.printType(result.getType());
});
}

static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
MLIRContext *ctx = result.getContext();

// Parse asynchronous region.
Region *body = result.addRegion();
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
/*enableNameShadowing=*/false))
return failure();

// Parse operation attributes.
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs))
return failure();
result.addAttributes(attrs);

// Parse result types.
SmallVector<Type, 4> resultTypes;
if (parser.parseColonTypeList(resultTypes))
return failure();

// First result type must be an async token type.
if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
return failure();
parser.addTypesToList(resultTypes, result.types);

return success();
}

} // namespace async
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
38 changes: 34 additions & 4 deletions mlir/test/Dialect/Async/ops.mlir
@@ -1,16 +1,46 @@
// RUN: mlir-opt %s | FileCheck %s

// CHECK-LABEL: @identity
func @identity(%arg0 : !async.token) -> !async.token {
// CHECK-LABEL: @identity_token
func @identity_token(%arg0 : !async.token) -> !async.token {
// CHECK: return %arg0 : !async.token
return %arg0 : !async.token
}

// CHECK-LABEL: @identity_value
func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
// CHECK: return %arg0 : !async.value<f32>
return %arg0 : !async.value<f32>
}

// CHECK-LABEL: @empty_async_execute
func @empty_async_execute() -> !async.token {
%0 = async.execute {
%done = async.execute {
async.yield
} : !async.token

return %0 : !async.token
// CHECK: return %done : !async.token
return %done : !async.token
}

// CHECK-LABEL: @return_async_value
func @return_async_value() -> !async.value<f32> {
%done, %values = async.execute {
%cst = constant 1.000000e+00 : f32
async.yield %cst : f32
} : !async.token, !async.value<f32>

// CHECK: return %values : !async.value<f32>
return %values : !async.value<f32>
}

// CHECK-LABEL: @return_async_values
func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
%done, %values:2 = async.execute {
%cst1 = constant 1.000000e+00 : f32
%cst2 = constant 2.000000e+00 : f32
async.yield %cst1, %cst2 : f32, f32
} : !async.token, !async.value<f32>, !async.value<f32>

// CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
}

0 comments on commit 655af65

Please sign in to comment.