Skip to content

Commit a5ddd92

Browse files
yijia1212ezhulenev
authored andcommitted
[mlir][async] Allow to call async.execute inside async.func
This change added support of calling async execute inside async.func. Ex. ``` async.func @async_func_call_func() -> !async.token { %token = async.execute { %c0 = arith.constant 0 : index memref.store %arg0, %arg1[%c0] : memref<1xf32> async.yield } async.await %token : !async.token return } ``` Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D141730
1 parent d1fbe2b commit a5ddd92

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,9 @@ void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
840840

841841
target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
842842
[coros](Operation *op) {
843+
auto exec = op->getParentOfType<ExecuteOp>();
843844
auto func = op->getParentOfType<func::FuncOp>();
844-
return coros->find(func) == coros->end();
845+
return exec || coros->find(func) == coros->end();
845846
});
846847
}
847848

mlir/test/Dialect/Async/async-to-async-runtime.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,31 @@ async.func @async_func_await(%arg0: f32, %arg1: !async.value<f32>)
455455
// CHECK-SAME: !async.value<f32>
456456
// CHECK: async.coro.suspend %[[SAVED]]
457457
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
458+
459+
// -----
460+
// Async execute inside async func
461+
462+
// CHECK-LABEL: @execute_in_async_func
463+
async.func @execute_in_async_func(%arg0: f32, %arg1: memref<1xf32>)
464+
-> !async.token {
465+
%token = async.execute {
466+
%c0 = arith.constant 0 : index
467+
memref.store %arg0, %arg1[%c0] : memref<1xf32>
468+
async.yield
469+
}
470+
async.await %token : !async.token
471+
return
472+
}
473+
// Call outlind async execute Function
474+
// CHECK: %[[RES:.*]] = call @async_execute_fn(
475+
// CHECK-SAME: %[[VALUE:arg[0-9]+]],
476+
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]
477+
// CHECK-SAME: ) : (f32, memref<1xf32>) -> !async.token
478+
479+
// Function outlined from the async.execute operation.
480+
// CHECK-LABEL: func private @async_execute_fn(
481+
// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
482+
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
483+
// CHECK-SAME: ) -> !async.token
484+
// CHECK: %[[CST:.*]] = arith.constant 0 : index
485+
// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]

mlir/test/mlir-cpu-runner/async-func.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,19 @@ async.func @async_func_passed_memref(%arg0 : !async.value<memref<f32>>) -> !asyn
6464
return
6565
}
6666

67+
async.func @async_execute_in_async_func(%arg0 : !async.value<memref<f32>>) -> !async.token {
68+
%token0 = async.execute {
69+
%unwrapped = async.await %arg0 : !async.value<memref<f32>>
70+
%0 = memref.load %unwrapped[] : memref<f32>
71+
%1 = arith.addf %0, %0 : f32
72+
memref.store %1, %unwrapped[] : memref<f32>
73+
async.yield
74+
}
75+
76+
async.await %token0 : !async.token
77+
return
78+
}
79+
6780

6881
func.func @main() {
6982
%false = arith.constant 0 : i1
@@ -140,6 +153,17 @@ func.func @main() {
140153
// CHECK-NEXT: [0.5]
141154
call @printMemrefF32(%6) : (memref<*xf32>) -> ()
142155

156+
// ------------------------------------------------------------------------ //
157+
// async.execute inside async.func
158+
// ------------------------------------------------------------------------ //
159+
%token4 = async.call @async_execute_in_async_func(%result1) : (!async.value<memref<f32>>) -> !async.token
160+
async.await %token4 : !async.token
161+
162+
// CHECK: Unranked Memref
163+
// CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
164+
// CHECK-NEXT: [1]
165+
call @printMemrefF32(%6) : (memref<*xf32>) -> ()
166+
143167
memref.dealloc %5 : memref<f32>
144168

145169
return

0 commit comments

Comments
 (0)