Skip to content

Commit

Permalink
[mlir][openacc][NFC] Use oilist in assembly format
Browse files Browse the repository at this point in the history
Use the oilist syntax in assembly format where appropriate.
This makes the dialect format more flexible as an order
is not imposed for the clauses.

Reviewed By: PeteSteinfeld, razvanlupusoru

Differential Revision: https://reviews.llvm.org/D148154
  • Loading branch information
clementval committed Apr 12, 2023
1 parent 9cbdfcd commit 2326480
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 54 deletions.
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-init.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ subroutine acc_init
logical :: ifCondition = .TRUE.

!$acc init
!CHECK: acc.init{{$}}
!CHECK: acc.init{{ *}}{{$}}

!$acc init if(.true.)
!CHECK: [[IF1:%.*]] = arith.constant true
Expand All @@ -27,4 +27,4 @@ subroutine acc_init
!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
!CHECK: acc.init device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}

end subroutine acc_init
end subroutine acc_init
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-shutdown.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ subroutine acc_shutdown
logical :: ifCondition = .TRUE.

!$acc shutdown
!CHECK: acc.shutdown{{$}}
!CHECK: acc.shutdown{{ *}}{{$}}

!$acc shutdown if(.true.)
!CHECK: [[IF1:%.*]] = arith.constant true
Expand All @@ -27,4 +27,4 @@ subroutine acc_shutdown
!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
!CHECK: acc.shutdown device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}

end subroutine acc_shutdown
end subroutine acc_shutdown
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-wait.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ subroutine acc_update
logical :: ifCondition = .TRUE.

!$acc wait
!CHECK: acc.wait{{$}}
!CHECK: acc.wait{{ *}}{{$}}

!$acc wait if(.true.)
!CHECK: [[IF1:%.*]] = arith.constant true
Expand All @@ -21,7 +21,7 @@ subroutine acc_update
!$acc wait(1, 2)
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32){{$}}
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32){{ *}}{{$}}

!$acc wait(1) async
!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
Expand Down
108 changes: 60 additions & 48 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,23 @@ def OpenACC_DataOp : OpenACC_Op<"data",
}];

let assemblyFormat = [{
( `if` `(` $ifCond^ `)` )?
( `copy` `(` $copyOperands^ `:` type($copyOperands) `)` )?
( `copyin` `(` $copyinOperands^ `:` type($copyinOperands) `)` )?
( `copyin_readonly` `(` $copyinReadonlyOperands^ `:`
type($copyinReadonlyOperands) `)` )?
( `copyout` `(` $copyoutOperands^ `:` type($copyoutOperands) `)` )?
( `copyout_zero` `(` $copyoutZeroOperands^ `:`
type($copyoutZeroOperands) `)` )?
( `create` `(` $createOperands^ `:` type($createOperands) `)` )?
( `create_zero` `(` $createZeroOperands^ `:`
type($createZeroOperands) `)` )?
( `no_create` `(` $noCreateOperands^ `:` type($noCreateOperands) `)` )?
( `present` `(` $presentOperands^ `:` type($presentOperands) `)` )?
( `deviceptr` `(` $deviceptrOperands^ `:` type($deviceptrOperands) `)` )?
( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )?
oilist(
`if` `(` $ifCond `)`
| `copy` `(` $copyOperands `:` type($copyOperands) `)`
| `copyin` `(` $copyinOperands `:` type($copyinOperands) `)`
| `copyin_readonly` `(` $copyinReadonlyOperands `:`
type($copyinReadonlyOperands) `)`
| `copyout` `(` $copyoutOperands `:` type($copyoutOperands) `)`
| `copyout_zero` `(` $copyoutZeroOperands `:`
type($copyoutZeroOperands) `)`
| `create` `(` $createOperands `:` type($createOperands) `)`
| `create_zero` `(` $createZeroOperands `:`
type($createZeroOperands) `)`
| `no_create` `(` $noCreateOperands `:` type($noCreateOperands) `)`
| `present` `(` $presentOperands `:` type($presentOperands) `)`
| `deviceptr` `(` $deviceptrOperands `:` type($deviceptrOperands) `)`
| `attach` `(` $attachOperands `:` type($attachOperands) `)`
)
$region attr-dict-with-keyword
}];
let hasVerifier = 1;
Expand Down Expand Up @@ -272,15 +274,17 @@ def OpenACC_EnterDataOp : OpenACC_Op<"enter_data", [AttrSizedOperandSegments]> {
}];

let assemblyFormat = [{
( `if` `(` $ifCond^ `)` )?
( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
( `wait` `(` $waitOperands^ `:` type($waitOperands) `)` )?
( `copyin` `(` $copyinOperands^ `:` type($copyinOperands) `)` )?
( `create` `(` $createOperands^ `:` type($createOperands) `)` )?
( `create_zero` `(` $createZeroOperands^ `:`
type($createZeroOperands) `)` )?
( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )?
oilist(
`if` `(` $ifCond `)`
| `async` `(` $asyncOperand `:` type($asyncOperand) `)`
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `copyin` `(` $copyinOperands `:` type($copyinOperands) `)`
| `create` `(` $createOperands `:` type($createOperands) `)`
| `create_zero` `(` $createZeroOperands `:`
type($createZeroOperands) `)`
| `attach` `(` $attachOperands `:` type($attachOperands) `)`
)
attr-dict-with-keyword
}];

Expand Down Expand Up @@ -325,13 +329,15 @@ def OpenACC_ExitDataOp : OpenACC_Op<"exit_data", [AttrSizedOperandSegments]> {
}];

let assemblyFormat = [{
( `if` `(` $ifCond^ `)` )?
( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
( `wait` `(` $waitOperands^ `:` type($waitOperands) `)` )?
( `copyout` `(` $copyoutOperands^ `:` type($copyoutOperands) `)` )?
( `delete` `(` $deleteOperands^ `:` type($deleteOperands) `)` )?
( `detach` `(` $detachOperands^ `:` type($detachOperands) `)` )?
oilist(
`if` `(` $ifCond `)`
| `async` `(` $asyncOperand `:` type($asyncOperand) `)`
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `copyout` `(` $copyoutOperands `:` type($copyoutOperands) `)`
| `delete` `(` $deleteOperands `:` type($deleteOperands) `)`
| `detach` `(` $detachOperands `:` type($detachOperands) `)`
)
attr-dict-with-keyword
}];

Expand Down Expand Up @@ -444,9 +450,11 @@ def OpenACC_InitOp : OpenACC_Op<"init", [AttrSizedOperandSegments]> {
Optional<I1>:$ifCond);

let assemblyFormat = [{
( `device_type` `(` $deviceTypeOperands^ `:` type($deviceTypeOperands) `)` )?
( `device_num` `(` $deviceNumOperand^ `:` type($deviceNumOperand) `)` )?
( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword
oilist(
`device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
| `device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
| `if` `(` $ifCond `)`
) attr-dict-with-keyword
}];
let hasVerifier = 1;
}
Expand Down Expand Up @@ -475,9 +483,10 @@ def OpenACC_ShutdownOp : OpenACC_Op<"shutdown", [AttrSizedOperandSegments]> {
Optional<I1>:$ifCond);

let assemblyFormat = [{
( `device_type` `(` $deviceTypeOperands^ `:` type($deviceTypeOperands) `)` )?
( `device_num` `(` $deviceNumOperand^ `:` type($deviceNumOperand) `)` )?
( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword
oilist(`device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
|`device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
|`if` `(` $ifCond `)`
) attr-dict-with-keyword
}];
let hasVerifier = 1;
}
Expand Down Expand Up @@ -522,14 +531,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
}];

let assemblyFormat = [{
( `if` `(` $ifCond^ `)` )?
( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
( `device_type` `(` $deviceTypeOperands^ `:`
type($deviceTypeOperands) `)` )?
( `wait` `(` $waitOperands^ `:` type($waitOperands) `)` )?
( `host` `(` $hostOperands^ `:` type($hostOperands) `)` )?
( `device` `(` $deviceOperands^ `:` type($deviceOperands) `)` )?
oilist(
`if` `(` $ifCond `)`
| `async` `(` $asyncOperand `:` type($asyncOperand) `)`
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
| `device_type` `(` $deviceTypeOperands `:`
type($deviceTypeOperands) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `host` `(` $hostOperands `:` type($hostOperands) `)`
| `device` `(` $deviceOperands `:` type($deviceOperands) `)`
)
attr-dict-with-keyword
}];

Expand Down Expand Up @@ -564,9 +575,10 @@ def OpenACC_WaitOp : OpenACC_Op<"wait", [AttrSizedOperandSegments]> {

let assemblyFormat = [{
( `(` $waitOperands^ `:` type($waitOperands) `)` )?
( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword
oilist(`async` `(` $asyncOperand `:` type($asyncOperand) `)`
|`wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
|`if` `(` $ifCond `)`
) attr-dict-with-keyword
}];
let hasVerifier = 1;
}
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/OpenACC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@ func.func @testdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf
%ifCond = arith.constant true
acc.data if(%ifCond) present(%a : memref<10xf32>) {
}
acc.data present(%a : memref<10xf32>) if(%ifCond) {
}
acc.data present(%a, %b, %c : memref<10xf32>, memref<10xf32>, memref<10x10xf32>) {
}
acc.data copy(%a, %b, %c : memref<10xf32>, memref<10xf32>, memref<10x10xf32>) {
Expand Down Expand Up @@ -525,6 +527,8 @@ func.func @testdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf
// CHECK: [[IFCOND1:%.*]] = arith.constant true
// CHECK: acc.data if([[IFCOND1]]) present([[ARGA]] : memref<10xf32>) {
// CHECK-NEXT: }
// CHECK: acc.data if([[IFCOND1]]) present([[ARGA]] : memref<10xf32>) {
// CHECK-NEXT: }
// CHECK: acc.data present([[ARGA]], [[ARGB]], [[ARGC]] : memref<10xf32>, memref<10xf32>, memref<10x10xf32>) {
// CHECK-NEXT: }
// CHECK: acc.data copy([[ARGA]], [[ARGB]], [[ARGC]] : memref<10xf32>, memref<10xf32>, memref<10x10xf32>) {
Expand Down Expand Up @@ -565,6 +569,7 @@ func.func @testupdateop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
%ifCond = arith.constant true
acc.update async(%i64Value: i64) host(%a: memref<10xf32>)
acc.update async(%i32Value: i32) host(%a: memref<10xf32>)
acc.update async(%i32Value: i32) host(%a: memref<10xf32>)
acc.update async(%idxValue: index) host(%a: memref<10xf32>)
acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) host(%a: memref<10xf32>)
acc.update if(%ifCond) host(%a: memref<10xf32>)
Expand All @@ -583,6 +588,7 @@ func.func @testupdateop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
// CHECK: [[IFCOND:%.*]] = arith.constant true
// CHECK: acc.update async([[I64VALUE]] : i64) host([[ARGA]] : memref<10xf32>)
// CHECK: acc.update async([[I32VALUE]] : i32) host([[ARGA]] : memref<10xf32>)
// CHECK: acc.update async([[I32VALUE]] : i32) host([[ARGA]] : memref<10xf32>)
// CHECK: acc.update async([[IDXVALUE]] : index) host([[ARGA]] : memref<10xf32>)
// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) host([[ARGA]] : memref<10xf32>)
// CHECK: acc.update if([[IFCOND]]) host([[ARGA]] : memref<10xf32>)
Expand Down Expand Up @@ -610,6 +616,7 @@ acc.wait(%i32Value: i32) async(%idxValue: index)
acc.wait(%i64Value: i64) wait_devnum(%i32Value: i32)
acc.wait attributes {async}
acc.wait(%i64Value: i64) async(%idxValue: index) wait_devnum(%i32Value: i32)
acc.wait(%i64Value: i64) wait_devnum(%i32Value: i32) async(%idxValue: index)
acc.wait if(%ifCond)

// CHECK: [[I64VALUE:%.*]] = arith.constant 1 : i64
Expand All @@ -628,6 +635,7 @@ acc.wait if(%ifCond)
// CHECK: acc.wait([[I64VALUE]] : i64) wait_devnum([[I32VALUE]] : i32)
// CHECK: acc.wait attributes {async}
// CHECK: acc.wait([[I64VALUE]] : i64) async([[IDXVALUE]] : index) wait_devnum([[I32VALUE]] : i32)
// CHECK: acc.wait([[I64VALUE]] : i64) async([[IDXVALUE]] : index) wait_devnum([[I32VALUE]] : i32)
// CHECK: acc.wait if([[IFCOND]])

// -----
Expand All @@ -644,6 +652,8 @@ acc.init device_num(%i64Value : i64)
acc.init device_num(%i32Value : i32)
acc.init device_num(%idxValue : index)
acc.init if(%ifCond)
acc.init if(%ifCond) device_num(%idxValue : index)
acc.init device_num(%idxValue : index) if(%ifCond)

// CHECK: [[I64VALUE:%.*]] = arith.constant 1 : i64
// CHECK: [[I32VALUE:%.*]] = arith.constant 1 : i32
Expand All @@ -657,6 +667,8 @@ acc.init if(%ifCond)
// CHECK: acc.init device_num([[I32VALUE]] : i32)
// CHECK: acc.init device_num([[IDXVALUE]] : index)
// CHECK: acc.init if([[IFCOND]])
// CHECK: acc.init device_num([[IDXVALUE]] : index) if([[IFCOND]])
// CHECK: acc.init device_num([[IDXVALUE]] : index) if([[IFCOND]])

// -----

Expand All @@ -672,6 +684,8 @@ acc.shutdown device_num(%i64Value : i64)
acc.shutdown device_num(%i32Value : i32)
acc.shutdown device_num(%idxValue : index)
acc.shutdown if(%ifCond)
acc.shutdown if(%ifCond) device_num(%idxValue : index)
acc.shutdown device_num(%idxValue : index) if(%ifCond)

// CHECK: [[I64VALUE:%.*]] = arith.constant 1 : i64
// CHECK: [[I32VALUE:%.*]] = arith.constant 1 : i32
Expand All @@ -685,6 +699,8 @@ acc.shutdown if(%ifCond)
// CHECK: acc.shutdown device_num([[I32VALUE]] : i32)
// CHECK: acc.shutdown device_num([[IDXVALUE]] : index)
// CHECK: acc.shutdown if([[IFCOND]])
// CHECK: acc.shutdown device_num([[IDXVALUE]] : index) if([[IFCOND]])
// CHECK: acc.shutdown device_num([[IDXVALUE]] : index) if([[IFCOND]])

// -----

Expand All @@ -701,6 +717,7 @@ func.func @testexitdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
acc.exit_data copyout(%a : memref<10xf32>) attributes {async}
acc.exit_data delete(%a : memref<10xf32>) attributes {wait}
acc.exit_data async(%i64Value : i64) copyout(%a : memref<10xf32>)
acc.exit_data copyout(%a : memref<10xf32>) async(%i64Value : i64)
acc.exit_data if(%ifCond) copyout(%a : memref<10xf32>)
acc.exit_data wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) copyout(%a : memref<10xf32>)

Expand All @@ -719,6 +736,7 @@ func.func @testexitdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
// CHECK: acc.exit_data copyout([[ARGA]] : memref<10xf32>) attributes {async}
// CHECK: acc.exit_data delete([[ARGA]] : memref<10xf32>) attributes {wait}
// CHECK: acc.exit_data async([[I64VALUE]] : i64) copyout([[ARGA]] : memref<10xf32>)
// CHECK: acc.exit_data async([[I64VALUE]] : i64) copyout([[ARGA]] : memref<10xf32>)
// CHECK: acc.exit_data if([[IFCOND]]) copyout([[ARGA]] : memref<10xf32>)
// CHECK: acc.exit_data wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) copyout([[ARGA]] : memref<10xf32>)
// -----
Expand All @@ -736,6 +754,7 @@ func.func @testenterdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10
acc.enter_data copyin(%a : memref<10xf32>) attributes {async}
acc.enter_data create(%a : memref<10xf32>) attributes {wait}
acc.enter_data async(%i64Value : i64) copyin(%a : memref<10xf32>)
acc.enter_data copyin(%a : memref<10xf32>) async(%i64Value : i64)
acc.enter_data if(%ifCond) copyin(%a : memref<10xf32>)
acc.enter_data wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) copyin(%a : memref<10xf32>)

Expand All @@ -753,5 +772,6 @@ func.func @testenterdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10
// CHECK: acc.enter_data copyin([[ARGA]] : memref<10xf32>) attributes {async}
// CHECK: acc.enter_data create([[ARGA]] : memref<10xf32>) attributes {wait}
// CHECK: acc.enter_data async([[I64VALUE]] : i64) copyin([[ARGA]] : memref<10xf32>)
// CHECK: acc.enter_data async([[I64VALUE]] : i64) copyin([[ARGA]] : memref<10xf32>)
// CHECK: acc.enter_data if([[IFCOND]]) copyin([[ARGA]] : memref<10xf32>)
// CHECK: acc.enter_data wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) copyin([[ARGA]] : memref<10xf32>)

0 comments on commit 2326480

Please sign in to comment.