Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Add fir.cuda_allocate operation #88586

Merged
merged 3 commits into from
Apr 16, 2024

Conversation

clementval
Copy link
Contributor

Allocatable with cuda device attribute have special semantic for the allocate statement. In flang the allocate statement is lowered to a sequence of runtime call initializing the descriptor and then allocating the descriptor data. This new operation will replace the last runtime call and abstract all the device memory allocation needed.
The lowering patch will follow.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Apr 12, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Allocatable with cuda device attribute have special semantic for the allocate statement. In flang the allocate statement is lowered to a sequence of runtime call initializing the descriptor and then allocating the descriptor data. This new operation will replace the last runtime call and abstract all the device memory allocation needed.
The lowering patch will follow.


Full diff: https://github.com/llvm/llvm-project/pull/88586.diff

5 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+31)
  • (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+1)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+19)
  • (added) flang/test/Fir/cuf-invalid.fir (+50)
  • (added) flang/test/Fir/cuf.mlir (+70)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dff1cdb20cbfef..5f1a097edf09cc 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3190,4 +3190,35 @@ def fir_CUDADataTransferOp : fir_Op<"cuda_data_transfer", []> {
   }];
 }
 
+def fir_CUDAAllocateOp : fir_Op<"cuda_allocate", [AttrSizedOperandSegments,
+    MemoryEffects<[MemAlloc<DefaultResource>]>]> {
+  let summary = "Perform the device allocation of data of an allocatable";
+
+  let description = [{
+    The fir.cuda_allocate operation performs the allocation on the device
+    of the data of an allocatable. The descriptor passed to the operation
+    is initialized before with the standard flang runtime calls.
+  }];
+
+  let arguments = (ins AnyRefOrBoxType:$box,
+                       Optional<AnyRefOrBoxType>:$errmsg,
+                       Optional<AnyIntegerType>:$stream,
+                       Optional<AnyRefOrBoxType>:$pinned,
+                       Optional<AnyRefOrBoxType>:$source,
+                       UnitAttr:$hasStat);
+
+  let results = (outs AnyIntegerType:$stat);
+
+  let assemblyFormat = [{
+    $box `:` qualified(type($box))
+    ( `source` `(` $source^ `:` qualified(type($source) )`)` )?
+    ( `errmsg` `(` $errmsg^ `:` type($errmsg) `)` )?
+    ( `stream` `(` $stream^ `:` type($stream) `)` )?
+    ( `pinned` `(` $pinned^ `:` type($pinned) `)` )?
+    attr-dict `->` type($stat)
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 4c6a8064991ab0..3b876e4642da9a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -625,6 +625,7 @@ def AnyRefOrBoxLike : TypeConstraint<Or<[AnyReferenceLike.predicate,
 def AnyRefOrBox : TypeConstraint<Or<[fir_ReferenceType.predicate,
     fir_HeapType.predicate, fir_PointerType.predicate,
     IsBaseBoxTypePred]>, "any reference or box">;
+def AnyRefOrBoxType : Type<AnyRefOrBox.predicate, "any legal ref or box type">;
 
 def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
     fir_ShapeShiftType.predicate]>, "any legal shape type">;
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 8ab74103cb6a80..88710880174d21 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3993,6 +3993,25 @@ mlir::LogicalResult fir::CUDAKernelOp::verify() {
   return mlir::success();
 }
 
+mlir::LogicalResult fir::CUDAAllocateOp::verify() {
+  if (getPinned() && getStream())
+    return emitOpError("pinned and stream cannot appears at the same time");
+  if (!fir::unwrapRefType(getBox().getType()).isa<fir::BaseBoxType>())
+    return emitOpError(
+        "expect box to be a reference to/or a class or box type value");
+  if (getSource() &&
+      !fir::unwrapRefType(getSource().getType()).isa<fir::BaseBoxType>())
+    return emitOpError(
+        "expect source to be a reference to/or a class or box type value");
+  if (getErrmsg() &&
+      !fir::unwrapRefType(getErrmsg().getType()).isa<fir::BoxType>())
+    return emitOpError(
+        "expect errmsg to be a reference to/or a box type value");
+  if (getErrmsg() && !getHasStat())
+    return emitOpError("expect stat attribute when errmsg is provided");
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // FIROpsDialect
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
new file mode 100644
index 00000000000000..a5258acffa123e
--- /dev/null
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -0,0 +1,50 @@
+// RUN: fir-opt -split-input-file -verify-diagnostics %s
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca i32
+  %pinned = fir.alloca i1
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %s = fir.load %1 : !fir.ref<i32>
+  // expected-error@+1{{'fir.cuda_allocate' op pinned and stream cannot appears at the same time}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) pinned(%pinned : !fir.ref<i1>) -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %1 = fir.alloca i32
+  // expected-error@+1{{'fir.cuda_allocate' op expect box to be a reference to/or a class or box type value}}
+  %2 = fir.cuda_allocate %1 : !fir.ref<i32> -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c100 = arith.constant 100 : index
+  %7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
+  %8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
+  %9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
+  // expected-error@+1{{'fir.cuda_allocate' op expect stat attribute when errmsg is provided}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %1 = fir.alloca i32
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  // expected-error@+1{{'fir.cuda_allocate' op expect errmsg to be a reference to/or a box type value}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%1 : !fir.ref<i32>) {hasStat} -> i32
+  return
+}
diff --git a/flang/test/Fir/cuf.mlir b/flang/test/Fir/cuf.mlir
new file mode 100644
index 00000000000000..bf4f866d464bbd
--- /dev/null
+++ b/flang/test/Fir/cuf.mlir
@@ -0,0 +1,70 @@
+// RUN: fir-opt --split-input-file %s | fir-opt --split-input-file | FileCheck %s
+
+// Simple round trip test of operations.
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca i32
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %s = fir.load %1 : !fir.ref<i32>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> stream(%{{.*}} : i32) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "b", uniq_name = "_QFsub1Eb"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %5:2 = hlfir.declare %1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %12 = fir.convert %5#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> source(%12 : !fir.ref<!fir.box<none>>) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> source(%{{.*}} : !fir.ref<!fir.box<none>>) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %pinned = fir.alloca i1
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> pinned(%pinned : !fir.ref<i1>) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> pinned(%{{.*}} : !fir.ref<i1>) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c100 = arith.constant 100 : index
+  %7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
+  %8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
+  %9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {hasStat} -> i32
+  return
+}

@@ -3190,4 +3190,36 @@ def fir_CUDADataTransferOp : fir_Op<"cuda_data_transfer", []> {
}];
}

def fir_CUDAAllocateOp : fir_Op<"cuda_allocate", [AttrSizedOperandSegments,
MemoryEffects<[MemAlloc<DefaultResource>]>]> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it also have a write effect on the box reference operand?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that would make sense. I'll add it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I added also effects on errmsg, pinned and source. errmsg, pinned are diagnostic variable that will be written to if present and source will be used for source allocation so we need to read from it.

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you, Valentin!

@clementval clementval merged commit 8ee7d97 into llvm:main Apr 16, 2024
4 of 5 checks passed
@clementval clementval deleted the cuda_allocate branch April 16, 2024 03:51
Comment on lines +3999 to +4001
if (!fir::unwrapRefType(getBox().getType()).isa<fir::BaseBoxType>())
return emitOpError(
"expect box to be a reference to/or a class or box type value");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not get why it is allowed for the box argument to be a simple fir.box<> (an not a fir.ref<fir.box<>>), how can the fir.box be modified after the allocation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be a reference to a box. You are correct that the box needs to be updated after the allocation so it does not make sense to have a simple fir.box<>. I'll make the change and open a PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in #88930

clementval added a commit that referenced this pull request Apr 16, 2024
Add the fir.cuda_deallocate operation that perform device deallocation
of data hold by a descriptor. This will replace the call to
AllocatableDeallocate from the runtime.

This is a companion operation to the one added in #88586
clementval added a commit that referenced this pull request Apr 17, 2024
Add MemRead effect on the box operand as the descriptor might be read
when performing the allocation of the data.

Also update the expected type of the box operand to be a reference.
Check in the verifier that this is a reference to a box or class type.

This addresses the comment made post commit on #88586
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants