-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[OpenACC][CIR] Implement 'reduction' combiner lowering for 5 ops #162906
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -527,16 +527,140 @@ void OpenACCRecipeBuilderBase::createFirstprivateRecipeCopy( | |||||
// doesn't restore it aftewards. | ||||||
void OpenACCRecipeBuilderBase::createReductionRecipeCombiner( | ||||||
mlir::Location loc, mlir::Location locEnd, mlir::Value mainOp, | ||||||
mlir::acc::ReductionRecipeOp recipe, size_t numBounds) { | ||||||
mlir::acc::ReductionRecipeOp recipe, size_t numBounds, QualType origType, | ||||||
llvm::ArrayRef<OpenACCReductionRecipe::CombinerRecipe> combinerRecipes) { | ||||||
mlir::Block *block = | ||||||
createRecipeBlock(recipe.getCombinerRegion(), mainOp.getType(), loc, | ||||||
numBounds, /*isInit=*/false); | ||||||
builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back()); | ||||||
CIRGenFunction::LexicalScope ls(cgf, loc, block); | ||||||
|
||||||
mlir::BlockArgument lhsArg = block->getArgument(0); | ||||||
mlir::Value lhsArg = block->getArgument(0); | ||||||
mlir::Value rhsArg = block->getArgument(1); | ||||||
llvm::MutableArrayRef<mlir::BlockArgument> boundsRange = | ||||||
block->getArguments().drop_front(2); | ||||||
|
||||||
if (llvm::any_of(combinerRecipes, [](auto &r) { return r.Op == nullptr; })) { | ||||||
cgf.cgm.errorNYI(loc, "OpenACC Reduction combiner not generated"); | ||||||
mlir::acc::YieldOp::create(builder, locEnd, block->getArgument(0)); | ||||||
return; | ||||||
} | ||||||
|
||||||
// apply the bounds so that we can get our bounds emitted correctly. | ||||||
for (mlir::BlockArgument boundArg : llvm::reverse(boundsRange)) | ||||||
std::tie(lhsArg, rhsArg) = | ||||||
createBoundsLoop(lhsArg, rhsArg, boundArg, loc, /*inverse=*/false); | ||||||
|
||||||
// Emitter for when we know this isn't a struct or array we have to loop | ||||||
// through. This should work for the 'field' once the get-element call has | ||||||
// been made. | ||||||
auto emitSingleCombiner = | ||||||
[&](mlir::Value lhsArg, mlir::Value rhsArg, | ||||||
const OpenACCReductionRecipe::CombinerRecipe &combiner) { | ||||||
mlir::Type elementTy = | ||||||
mlir::cast<cir::PointerType>(lhsArg.getType()).getPointee(); | ||||||
CIRGenFunction::DeclMapRevertingRAII declMapRAIILhs{cgf, combiner.LHS}; | ||||||
cgf.setAddrOfLocalVar( | ||||||
combiner.LHS, Address{lhsArg, elementTy, | ||||||
cgf.getContext().getDeclAlign(combiner.LHS)}); | ||||||
CIRGenFunction::DeclMapRevertingRAII declMapRAIIRhs{cgf, combiner.RHS}; | ||||||
cgf.setAddrOfLocalVar( | ||||||
combiner.RHS, Address{rhsArg, elementTy, | ||||||
cgf.getContext().getDeclAlign(combiner.RHS)}); | ||||||
|
||||||
[[maybe_unused]] mlir::LogicalResult stmtRes = | ||||||
cgf.emitStmt(combiner.Op, /*useCurrentScope=*/true); | ||||||
}; | ||||||
|
||||||
// Emitter for when we know this is either a non-array or element of an array | ||||||
// (which also shouldn't be an array type?). This function should generate the | ||||||
// loop to do this on each individual array or struct element (if necessary). | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand the comment. If we know this is a non-array, what does the part about a loop to do this on each individual array element mean? |
||||||
auto emitCombiner = [&](mlir::Value lhsArg, mlir::Value rhsArg, QualType Ty) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if (const auto *RD = Ty->getAsRecordDecl()) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if (combinerRecipes.size() == 1 && | ||||||
cgf.getContext().hasSameType(Ty, combinerRecipes[0].LHS->getType())) { | ||||||
// If this is a 'top level' operator on the type we can just emit this | ||||||
// as a simple one. | ||||||
emitSingleCombiner(lhsArg, rhsArg, combinerRecipes[0]); | ||||||
} else { | ||||||
// else we have to handle each individual field after after a | ||||||
// get-element. | ||||||
for (const auto &[field, combiner] : | ||||||
llvm::zip_equal(RD->fields(), combinerRecipes)) { | ||||||
mlir::Type fieldType = cgf.convertType(field->getType()); | ||||||
auto fieldPtr = cir::PointerType::get(fieldType); | ||||||
|
||||||
mlir::Value lhsField = builder.createGetMember( | ||||||
loc, fieldPtr, lhsArg, field->getName(), field->getFieldIndex()); | ||||||
mlir::Value rhsField = builder.createGetMember( | ||||||
loc, fieldPtr, rhsArg, field->getName(), field->getFieldIndex()); | ||||||
|
||||||
emitSingleCombiner(lhsField, rhsField, combiner); | ||||||
} | ||||||
} | ||||||
|
||||||
} else { | ||||||
// if this is a single-thing (because we should know this isn't an array, | ||||||
// as Sema wouldn't let us get here), we can just do a normal emit call. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe assert that it's not an array? |
||||||
emitSingleCombiner(lhsArg, rhsArg, combinerRecipes[0]); | ||||||
} | ||||||
}; | ||||||
|
||||||
if (const auto *cat = cgf.getContext().getAsConstantArrayType(origType)) { | ||||||
// If we're in an array, we have to emit the combiner for each element of | ||||||
// the array. | ||||||
auto itrTy = mlir::cast<cir::IntType>(cgf.PtrDiffTy); | ||||||
auto itrPtrTy = cir::PointerType::get(itrTy); | ||||||
|
||||||
mlir::Value zero = | ||||||
builder.getConstInt(loc, mlir::cast<cir::IntType>(cgf.PtrDiffTy), 0); | ||||||
mlir::Value itr = | ||||||
cir::AllocaOp::create(builder, loc, itrPtrTy, itrTy, "itr", | ||||||
cgf.cgm.getSize(cgf.getPointerAlign())); | ||||||
builder.CIRBaseBuilderTy::createStore(loc, zero, itr); | ||||||
|
||||||
builder.setInsertionPointAfter(builder.createFor( | ||||||
loc, | ||||||
/*condBuilder=*/ | ||||||
[&](mlir::OpBuilder &b, mlir::Location loc) { | ||||||
auto loadItr = cir::LoadOp::create(builder, loc, {itr}); | ||||||
mlir::Value arraySize = builder.getConstInt( | ||||||
loc, mlir::cast<cir::IntType>(cgf.PtrDiffTy), cat->getZExtSize()); | ||||||
auto cmp = builder.createCompare(loc, cir::CmpOpKind::lt, loadItr, | ||||||
arraySize); | ||||||
builder.createCondition(cmp); | ||||||
}, | ||||||
/*bodyBuilder=*/ | ||||||
[&](mlir::OpBuilder &b, mlir::Location loc) { | ||||||
auto loadItr = cir::LoadOp::create(builder, loc, {itr}); | ||||||
auto lhsElt = builder.getArrayElement( | ||||||
loc, loc, lhsArg, cgf.convertType(cat->getElementType()), loadItr, | ||||||
/*shouldDecay=*/true); | ||||||
auto rhsElt = builder.getArrayElement( | ||||||
loc, loc, rhsArg, cgf.convertType(cat->getElementType()), loadItr, | ||||||
/*shouldDecay=*/true); | ||||||
|
||||||
emitCombiner(lhsElt, rhsElt, cat->getElementType()); | ||||||
builder.createYield(loc); | ||||||
}, | ||||||
/*stepBuilder=*/ | ||||||
[&](mlir::OpBuilder &b, mlir::Location loc) { | ||||||
auto loadItr = cir::LoadOp::create(builder, loc, {itr}); | ||||||
auto inc = cir::UnaryOp::create(builder, loc, loadItr.getType(), | ||||||
cir::UnaryOpKind::Inc, loadItr); | ||||||
builder.CIRBaseBuilderTy::createStore(loc, inc, itr); | ||||||
builder.createYield(loc); | ||||||
})); | ||||||
|
||||||
mlir::acc::YieldOp::create(builder, locEnd, lhsArg); | ||||||
} else if (origType->isArrayType()) { | ||||||
cgf.cgm.errorNYI(loc, | ||||||
"OpenACC Reduction combiner non-constant array recipe"); | ||||||
} else { | ||||||
emitCombiner(lhsArg, rhsArg, origType); | ||||||
} | ||||||
|
||||||
builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back()); | ||||||
mlir::acc::YieldOp::create(builder, locEnd, block->getArgument(0)); | ||||||
} | ||||||
|
||||||
} // namespace clang::CIRGen |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.