diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index 748cb7f28fc8c..ff5845343313c 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -58,11 +58,10 @@ #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \ ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp #define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \ - mlir::acc::DataOp, mlir::acc::DeclareOp + mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \ mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \ - mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \ - mlir::acc::DeclareExitOp + mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp #define ACC_DATA_CONSTRUCT_OPS \ ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS #define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \ diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp index a553653c73479..f2abeab744d17 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp @@ -81,7 +81,8 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { collectVars(op.getDataClauseOperands(), values, hostToDevice); if constexpr (!std::is_same_v && !std::is_same_v && - !std::is_same_v) { + !std::is_same_v && + !std::is_same_v) { collectVars(op.getReductionOperands(), values, hostToDevice); collectVars(op.getPrivateOperands(), values, hostToDevice); collectVars(op.getFirstprivateOperands(), values, hostToDevice); @@ -122,6 +123,8 @@ class LegalizeDataValuesInRegion collectAndReplaceInRegion(dataOp, replaceHostVsDevice); } else if (auto declareOp = dyn_cast(*op)) { collectAndReplaceInRegion(declareOp, replaceHostVsDevice); + } else if (auto hostDataOp = dyn_cast(*op)) { + collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice); } else { llvm_unreachable("unsupported acc region op"); } diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir index baa72ae416c92..9461225e9a7e0 100644 --- a/mlir/test/Dialect/OpenACC/legalize-data.mlir +++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir @@ -102,7 +102,7 @@ func.func @test(%a: memref<10xf32>) { return } -// CHECK: func.func @test +// CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) // CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> // CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) { @@ -140,7 +140,7 @@ func.func @test(%a: memref<10xf32>) { return } -// CHECK: func.func @test +// CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) // CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> // CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { @@ -178,7 +178,7 @@ func.func @test(%a: memref<10xf32>) { return } -// CHECK: func.func @test +// CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) // CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> // CHECK: acc.parallel { @@ -216,7 +216,7 @@ func.func @test(%a: memref<10xf32>) { return } -// CHECK: func.func @test +// CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) // CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> // CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { @@ -226,3 +226,23 @@ func.func @test(%a: memref<10xf32>) { // CHECK: } // CHECK: acc.yield // CHECK: } + +// ----- + +func.func @test(%a: memref<10xf32>) { + %devptr = acc.use_device varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> + acc.host_data dataOperands(%devptr : memref<10xf32>) { + func.call @foo(%a) : (memref<10xf32>) -> () + acc.terminator + } + return +} +func.func private @foo(memref<10xf32>) + +// CHECK-LABEL: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>) +// CHECK: %[[USE_DEVICE:.*]] = acc.use_device varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> +// CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) { +// DEVICE: func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> () +// CHECK: acc.terminator +// CHECK: } \ No newline at end of file