diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index a47f70b168066..c60eb5cc620a7 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -445,7 +445,10 @@ class OpenACC_DataEntryOp(attr); if (deviceTypeAttr.getValue() == deviceType) return true; @@ -817,7 +820,10 @@ class OpenACC_DataExitOp(attr); if (deviceTypeAttr.getValue() == deviceType) return true; diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp index cfb8aa767b6f8..aa16421cbec51 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -77,6 +77,54 @@ TEST_F(OpenACCOpsTest, asyncOnlyTest) { testAsyncOnly(b, context, loc, dtypes); } +/* Test helper: checks hasAsyncOnly() on an op created from (loc, varPtr, structured=true, implicit=true), where varPtr is the result of an op created from a MemRefType with an empty shape and i32 element type. With no asyncOnly attribute set, the no-argument query and the query for every entry of dtypes must be false; the attribute is then set to [None], [Host] and [Host, Star] in turn and removed again after each stage. NOTE(review): the template parameter/argument lists look stripped in this text (bare "template", "OwningOpRef varPtrOp", "cast>(", "b.create(") -- restore them from the upstream patch before applying. */template +void testAsyncOnlyDataEntry(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + + TypedValue varPtr = + cast>(varPtrOp->getResult()); + OwningOpRef op = b.create(loc, varPtr, + /*structured=*/true, /*implicit=*/true); + + EXPECT_FALSE(op->hasAsyncOnly()); + for (auto d : dtypes) + EXPECT_FALSE(op->hasAsyncOnly(d)); +/* Attribute lists only DeviceType::None: the no-argument query and the explicit None query are both expected to be true. */ + auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); + op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone})); + EXPECT_TRUE(op->hasAsyncOnly()); + EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None)); + op->removeAsyncOnlyAttr(); +/* Attribute lists only DeviceType::Host: the Host query is expected to be true while the no-argument query is expected to be false. */ + auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host); + op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost})); + EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host)); + EXPECT_FALSE(op->hasAsyncOnly()); + op->removeAsyncOnlyAttr(); +/* Attribute lists Host and Star: each listed device type is expected to be true and the no-argument query is still expected to be false. */ + auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star); + op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar})); + EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star)); + 
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host)); + EXPECT_FALSE(op->hasAsyncOnly()); + + op->removeAsyncOnlyAttr(); +} +/* Runs testAsyncOnlyDataEntry eight times with the fixture's b, context, loc and dtypes. NOTE(review): the eight calls are textually identical here, presumably because each call's explicit template argument was stripped from this text -- confirm the op list against the upstream patch. */ +TEST_F(OpenACCOpsTest, asyncOnlyTestDataEntry) { + testAsyncOnlyDataEntry(b, context, loc, dtypes); + testAsyncOnlyDataEntry(b, context, loc, dtypes); + testAsyncOnlyDataEntry(b, context, loc, dtypes); + testAsyncOnlyDataEntry(b, context, loc, dtypes); + testAsyncOnlyDataEntry(b, context, loc, dtypes); + testAsyncOnlyDataEntry(b, context, loc, dtypes); + testAsyncOnlyDataEntry(b, context, loc, dtypes); + testAsyncOnlyDataEntry(b, context, loc, dtypes); +} + template void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector &dtypes) { @@ -105,6 +153,46 @@ TEST_F(OpenACCOpsTest, asyncValueTest) { testAsyncValue(b, context, loc, dtypes); } +/* Test helper: checks getAsyncValue() on an op created from (loc, varPtr, structured=true, implicit=true), built the same way as in testAsyncOnlyDataEntry. Before any async operand is attached, the no-argument query and the query for every entry of dtypes must equal a default-constructed mlir::Value; after one operand tagged with DeviceType::Nvidia is attached, only the Nvidia query must return it. NOTE(review): template parameter/argument lists look stripped in this text -- restore them from the upstream patch before applying. */template +void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + + TypedValue varPtr = + cast>(varPtrOp->getResult()); + OwningOpRef op = b.create(loc, varPtr, + /*structured=*/true, /*implicit=*/true); +/* A default-constructed Value is the reference result for "no async operand applies". */ + mlir::Value empty; + EXPECT_EQ(op->getAsyncValue(), empty); + for (auto d : dtypes) + EXPECT_EQ(op->getAsyncValue(d), empty); +/* Attach the result of val as the async operand and tag it with DeviceType::Nvidia via the device-type attribute; the no-argument query must still be empty. */ + OwningOpRef val = + b.create(loc, 1); + auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); + op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); + op->getAsyncOperandsMutable().assign(val->getResult()); + EXPECT_EQ(op->getAsyncValue(), empty); + EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult()); +/* Undo the operand assignment and the device-type attribute set above. */ + op->getAsyncOperandsMutable().clear(); + op->removeAsyncOperandsDeviceTypeAttr(); +} +/* Runs testAsyncValueDataEntry eight times with the fixture's b, context, loc and dtypes. NOTE(review): the eight calls are textually identical here, presumably because each call's explicit template argument was stripped from this text -- confirm the op list against the upstream patch. */ +TEST_F(OpenACCOpsTest, asyncValueTestDataEntry) { + testAsyncValueDataEntry(b, context, loc, dtypes); + testAsyncValueDataEntry(b, context, loc, dtypes); + testAsyncValueDataEntry(b, context, loc, dtypes); + testAsyncValueDataEntry(b, context, loc, dtypes); + 
testAsyncValueDataEntry(b, context, loc, dtypes); + testAsyncValueDataEntry(b, context, loc, dtypes); + testAsyncValueDataEntry(b, context, loc, dtypes); + testAsyncValueDataEntry(b, context, loc, dtypes); +} + template void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector &dtypes,