diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index e305e2fbde5b1..aaf8feb29401c 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -347,6 +347,24 @@ def OpenACC_DataBoundsOp : OpenACC_Op<"bounds", }]; let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "::mlir::Value":$extent), [{ + build($_builder, $_state, + ::mlir::acc::DataBoundsType::get($_builder.getContext()), + /*lowerbound=*/{}, /*upperbound=*/{}, extent, + /*stride=*/{}, /*strideInBytes=*/nullptr, /*startIdx=*/{}); + }] + >, + OpBuilder<(ins "::mlir::Value":$lowerbound, + "::mlir::Value":$upperbound), [{ + build($_builder, $_state, + ::mlir::acc::DataBoundsType::get($_builder.getContext()), + lowerbound, upperbound, /*extent=*/{}, + /*stride=*/{}, /*strideInBytes=*/nullptr, /*startIdx=*/{}); + }] + > + ]; } // Data entry operation does not refer to OpenACC spec terminology, but to @@ -450,6 +468,33 @@ class OpenACC_DataEntryOp:$bounds), [{ + build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{}, + bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); + }] + >, + OpBuilder<(ins "::mlir::Value":$varPtr, + "bool":$structured, + "bool":$implicit, + "const ::llvm::Twine &":$name, + CArg<"::mlir::ValueRange", "{}">:$bounds), [{ + build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{}, + bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), + /*name=*/$_builder.getStringAttr(name)); + }] + > + ]; } //===----------------------------------------------------------------------===// @@ -762,23 +807,13 @@ class OpenACC_DataExitOp, - MemWrite]>], - (ins Arg:$accPtr, - Arg:$varPtr)> { - let summary = "Represents acc copyout semantics - reverse of copyin."; - - let extraClassDeclaration = extraClassDeclarationBase # [{ - /// Check if this is a copyout with zero modifier. - bool isCopyoutZero(); - }]; - +class OpenACC_DataExitOpWithVarPtr : + OpenACC_DataExitOp, + MemWrite]>], + (ins Arg:$accPtr, + Arg:$varPtr)> { let assemblyFormat = [{ `accPtr` `(` $accPtr `:` type($accPtr) `)` (`bounds` `(` $bounds^ `)` )? @@ -787,20 +822,42 @@ def OpenACC_CopyoutOp : OpenACC_DataExitOp<"copyout", `to` `varPtr` `(` $varPtr `:` type($varPtr) `)` attr-dict }]; + + let builders = [ + OpBuilder<(ins "::mlir::Value":$accPtr, + "::mlir::Value":$varPtr, + "bool":$structured, + "bool":$implicit, + CArg<"::mlir::ValueRange", "{}">:$bounds), [{ + build($_builder, $_state, accPtr, varPtr, + bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); + }] + >, + OpBuilder<(ins "::mlir::Value":$accPtr, + "::mlir::Value":$varPtr, + "bool":$structured, + "bool":$implicit, + "const ::llvm::Twine &":$name, + CArg<"::mlir::ValueRange", "{}">:$bounds), [{ + build($_builder, $_state, accPtr, varPtr, + bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), + /*name=*/$_builder.getStringAttr(name)); + }] + > + ]; } -//===----------------------------------------------------------------------===// -// 2.7.11 delete clause -//===----------------------------------------------------------------------===// -def OpenACC_DeleteOp : OpenACC_DataExitOp<"delete", - "mlir::acc::DataClause::acc_delete", "", - [MemoryEffects<[MemRead, +class OpenACC_DataExitOpNoVarPtr : + OpenACC_DataExitOp, MemWrite]>], - (ins Arg:$accPtr)> { - let summary = "Represents acc delete semantics - reverse of create."; - - let extraClassDeclaration = extraClassDeclarationBase; - + (ins Arg:$accPtr)> { let assemblyFormat = [{ `accPtr` `(` $accPtr `:` type($accPtr) `)` (`bounds` `(` $bounds^ `)` )? @@ -808,39 +865,71 @@ def OpenACC_DeleteOp : OpenACC_DataExitOp<"delete", type($asyncOperands), $asyncOperandsDeviceType)^ `)`)? attr-dict }]; + + let builders = [ + OpBuilder<(ins "::mlir::Value":$accPtr, + "bool":$structured, + "bool":$implicit, + CArg<"::mlir::ValueRange", "{}">:$bounds), [{ + build($_builder, $_state, accPtr, + bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); + }] + >, + OpBuilder<(ins "::mlir::Value":$accPtr, + "bool":$structured, + "bool":$implicit, + "const ::llvm::Twine &":$name, + CArg<"::mlir::ValueRange", "{}">:$bounds), [{ + build($_builder, $_state, accPtr, + bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), + /*name=*/$_builder.getStringAttr(name)); + }] + > + ]; } //===----------------------------------------------------------------------===// -// 2.7.13 detach clause +// 2.7.8 copyout clause //===----------------------------------------------------------------------===// -def OpenACC_DetachOp : OpenACC_DataExitOp<"detach", - "mlir::acc::DataClause::acc_detach", "", - [MemoryEffects<[MemRead, - MemWrite]>], - (ins Arg:$accPtr)> { - let summary = "Represents acc detach semantics - reverse of attach."; +def OpenACC_CopyoutOp : OpenACC_DataExitOpWithVarPtr<"copyout", + "mlir::acc::DataClause::acc_copyout"> { + let summary = "Represents acc copyout semantics - reverse of copyin."; + let extraClassDeclaration = extraClassDeclarationBase # [{ + /// Check if this is a copyout with zero modifier. + bool isCopyoutZero(); + }]; +} + +//===----------------------------------------------------------------------===// +// 2.7.11 delete clause +//===----------------------------------------------------------------------===// +def OpenACC_DeleteOp : OpenACC_DataExitOpNoVarPtr<"delete", + "mlir::acc::DataClause::acc_delete"> { + let summary = "Represents acc delete semantics - reverse of create."; let extraClassDeclaration = extraClassDeclarationBase; +} - let assemblyFormat = [{ - `accPtr` `(` $accPtr `:` type($accPtr) `)` - (`bounds` `(` $bounds^ `)` )? - (`async` `(` custom($asyncOperands, - type($asyncOperands), $asyncOperandsDeviceType)^ `)`)? - attr-dict - }]; +//===----------------------------------------------------------------------===// +// 2.7.13 detach clause +//===----------------------------------------------------------------------===// +def OpenACC_DetachOp : OpenACC_DataExitOpNoVarPtr<"detach", + "mlir::acc::DataClause::acc_detach"> { + let summary = "Represents acc detach semantics - reverse of attach."; + let extraClassDeclaration = extraClassDeclarationBase; } //===----------------------------------------------------------------------===// // 2.14.4 host clause //===----------------------------------------------------------------------===// -def OpenACC_UpdateHostOp : OpenACC_DataExitOp<"update_host", - "mlir::acc::DataClause::acc_update_host", - "- `varPtr`: The address of variable to copy back to.", - [MemoryEffects<[MemRead, - MemWrite]>], - (ins Arg:$accPtr, - Arg:$varPtr)> { +def OpenACC_UpdateHostOp : OpenACC_DataExitOpWithVarPtr<"update_host", + "mlir::acc::DataClause::acc_update_host"> { let summary = "Represents acc update host semantics."; let extraClassDeclaration = extraClassDeclarationBase # [{ /// Check if this is an acc update self. @@ -848,15 +937,6 @@ def OpenACC_UpdateHostOp : OpenACC_DataExitOp<"update_host", return getDataClause() == acc::DataClause::acc_update_self; } }]; - - let assemblyFormat = [{ - `accPtr` `(` $accPtr `:` type($accPtr) `)` - (`bounds` `(` $bounds^ `)` )? - (`async` `(` custom($asyncOperands, - type($asyncOperands), $asyncOperandsDeviceType)^ `)`)? - `to` `varPtr` `(` $varPtr `:` type($varPtr) `)` - attr-dict - }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp index 452f39d8cae9f..fbdada9309d32 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -1,4 +1,4 @@ -//===- OpenACCOpsTest.cpp - OpenACC ops extra functiosn Tests -------------===// +//===- OpenACCOpsTest.cpp - Unit tests for OpenACC ops --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" @@ -23,7 +24,8 @@ using namespace mlir::acc; class OpenACCOpsTest : public ::testing::Test { protected: OpenACCOpsTest() : b(&context), loc(UnknownLoc::get(&context)) { - context.loadDialect(); + context.loadDialect(); } MLIRContext context; @@ -436,3 +438,169 @@ TEST_F(OpenACCOpsTest, routineOpTest) { op->removeBindNameDeviceTypeAttr(); op->removeBindNameAttr(); } + +template +void testShortDataEntryOpBuilders(OpBuilder &b, MLIRContext &context, + Location loc, DataClause dataClause) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + + OwningOpRef op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + + EXPECT_EQ(op->getVarPtr(), varPtrOp->getResult()); + EXPECT_EQ(op->getType(), memrefTy); + EXPECT_EQ(op->getDataClause(), dataClause); + EXPECT_TRUE(op->getImplicit()); + EXPECT_TRUE(op->getStructured()); + EXPECT_TRUE(op->getBounds().empty()); + EXPECT_FALSE(op->getVarPtrPtr()); + + OwningOpRef op2 = b.create(loc, varPtrOp->getResult(), + /*structured=*/false, /*implicit=*/false); + EXPECT_FALSE(op2->getImplicit()); + EXPECT_FALSE(op2->getStructured()); + + OwningOpRef extent = + b.create(loc, 1); + OwningOpRef bounds = + b.create(loc, extent->getResult()); + OwningOpRef opWithBounds = + b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, bounds->getResult()); + EXPECT_FALSE(opWithBounds->getBounds().empty()); + EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); + + OwningOpRef opWithName = + b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, "varName"); + EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); +} + +TEST_F(OpenACCOpsTest, shortDataEntryOpBuilder) { + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_private); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_firstprivate); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_reduction); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_deviceptr); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_present); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_copyin); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_create); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_no_create); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_attach); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_getdeviceptr); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_update_device); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_use_device); + testShortDataEntryOpBuilders( + b, context, loc, DataClause::acc_declare_device_resident); + testShortDataEntryOpBuilders(b, context, loc, + DataClause::acc_declare_link); + testShortDataEntryOpBuilders(b, context, loc, DataClause::acc_cache); +} + +template +void testShortDataExitOpBuilders(OpBuilder &b, MLIRContext &context, + Location loc, DataClause dataClause) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = b.create( + loc, varPtrOp->getResult(), /*structured=*/true, /*implicit=*/true); + + OwningOpRef op = + b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + + EXPECT_EQ(op->getVarPtr(), varPtrOp->getResult()); + EXPECT_EQ(op->getAccPtr(), accPtrOp->getResult()); + EXPECT_EQ(op->getDataClause(), dataClause); + EXPECT_TRUE(op->getImplicit()); + EXPECT_TRUE(op->getStructured()); + EXPECT_TRUE(op->getBounds().empty()); + + OwningOpRef op2 = + b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), + /*structured=*/false, /*implicit=*/false); + EXPECT_FALSE(op2->getImplicit()); + EXPECT_FALSE(op2->getStructured()); + + OwningOpRef extent = + b.create(loc, 1); + OwningOpRef bounds = + b.create(loc, extent->getResult()); + OwningOpRef opWithBounds = + b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, bounds->getResult()); + EXPECT_FALSE(opWithBounds->getBounds().empty()); + EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); + + OwningOpRef opWithName = + b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, "varName"); + EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); +} + +TEST_F(OpenACCOpsTest, shortDataExitOpBuilder) { + testShortDataExitOpBuilders(b, context, loc, + DataClause::acc_copyout); + testShortDataExitOpBuilders(b, context, loc, + DataClause::acc_update_host); +} + +template +void testShortDataExitNoVarPtrOpBuilders(OpBuilder &b, MLIRContext &context, + Location loc, DataClause dataClause) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = b.create( + loc, varPtrOp->getResult(), /*structured=*/true, /*implicit=*/true); + + OwningOpRef op = b.create(loc, accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + + EXPECT_EQ(op->getAccPtr(), accPtrOp->getResult()); + EXPECT_EQ(op->getDataClause(), dataClause); + EXPECT_TRUE(op->getImplicit()); + EXPECT_TRUE(op->getStructured()); + EXPECT_TRUE(op->getBounds().empty()); + + OwningOpRef op2 = b.create(loc, accPtrOp->getResult(), + /*structured=*/false, /*implicit=*/false); + EXPECT_FALSE(op2->getImplicit()); + EXPECT_FALSE(op2->getStructured()); + + OwningOpRef extent = + b.create(loc, 1); + OwningOpRef bounds = + b.create(loc, extent->getResult()); + OwningOpRef opWithBounds = + b.create(loc, accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, bounds->getResult()); + EXPECT_FALSE(opWithBounds->getBounds().empty()); + EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); + + OwningOpRef opWithName = + b.create(loc, accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, "varName"); + EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); +} + +TEST_F(OpenACCOpsTest, shortDataExitOpNoVarPtrBuilder) { + testShortDataExitNoVarPtrOpBuilders(b, context, loc, + DataClause::acc_delete); + testShortDataExitNoVarPtrOpBuilders(b, context, loc, + DataClause::acc_detach); +}