Skip to content

Commit

Permalink
[flang] Make code more homogenous in CodeGen
Browse files Browse the repository at this point in the history
This patch just make the code more similar
in each conversion.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D129071
  • Loading branch information
clementval committed Jul 4, 2022
1 parent b37dafd commit 12d26ce
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Expand Up @@ -1432,12 +1432,11 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
mlir::LogicalResult
matchAndRewrite(fir::EmboxOp embox, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::ValueRange operands = adaptor.getOperands();
assert(!embox.getShape() && "There should be no dims on this embox op");
auto [boxTy, dest, eleSize] =
consDescriptorPrefix(embox, rewriter, /*rank=*/0,
/*lenParams=*/adaptor.getOperands().drop_front(1));
dest = insertBaseAddress(rewriter, embox.getLoc(), dest,
adaptor.getOperands()[0]);
auto [boxTy, dest, eleSize] = consDescriptorPrefix(
embox, rewriter, /*rank=*/0, /*lenParams=*/operands.drop_front(1));
dest = insertBaseAddress(rewriter, embox.getLoc(), dest, operands[0]);
if (isDerivedTypeWithLenParams(boxTy)) {
TODO(embox.getLoc(),
"fir.embox codegen of derived with length parameters");
Expand All @@ -1456,11 +1455,11 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
mlir::LogicalResult
matchAndRewrite(fir::cg::XEmboxOp xbox, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto [boxTy, dest, eleSize] = consDescriptorPrefix(
xbox, rewriter, xbox.getOutRank(),
adaptor.getOperands().drop_front(xbox.lenParamOffset()));
// Generate the triples in the dims field of the descriptor
mlir::ValueRange operands = adaptor.getOperands();
auto [boxTy, dest, eleSize] =
consDescriptorPrefix(xbox, rewriter, xbox.getOutRank(),
operands.drop_front(xbox.lenParamOffset()));
// Generate the triples in the dims field of the descriptor
auto i64Ty = mlir::IntegerType::get(xbox.getContext(), 64);
mlir::Value base = operands[0];
assert(!xbox.shape().empty() && "must have a shape");
Expand Down Expand Up @@ -1955,11 +1954,12 @@ struct ExtractValueOpConversion
mlir::LogicalResult
doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::ValueRange operands = adaptor.getOperands();
auto attrs = collectIndices(rewriter, extractVal.getCoor());
toRowMajor(attrs, adaptor.getOperands()[0].getType());
toRowMajor(attrs, operands[0].getType());
auto position = mlir::ArrayAttr::get(extractVal.getContext(), attrs);
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
extractVal, ty, adaptor.getOperands()[0], position);
extractVal, ty, operands[0], position);
return mlir::success();
}
};
Expand All @@ -1974,12 +1974,12 @@ struct InsertValueOpConversion
mlir::LogicalResult
doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::ValueRange operands = adaptor.getOperands();
auto attrs = collectIndices(rewriter, insertVal.getCoor());
toRowMajor(attrs, adaptor.getOperands()[0].getType());
toRowMajor(attrs, operands[0].getType());
auto position = mlir::ArrayAttr::get(insertVal.getContext(), attrs);
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
insertVal, ty, adaptor.getOperands()[0], adaptor.getOperands()[1],
position);
insertVal, ty, operands[0], operands[1], position);
return mlir::success();
}
};
Expand Down Expand Up @@ -2123,8 +2123,7 @@ struct XArrayCoorOpConversion
// that was just computed.
if (baseIsBoxed) {
// Use stride in bytes from the descriptor.
mlir::Value stride =
loadStrideFromBox(loc, adaptor.getOperands()[0], i, rewriter);
mlir::Value stride = loadStrideFromBox(loc, operands[0], i, rewriter);
auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, stride);
offset = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset);
} else {
Expand All @@ -2143,17 +2142,16 @@ struct XArrayCoorOpConversion
if (baseIsBoxed) {
// Working with byte offsets. The base address is read from the fir.box.
// and need to be casted to i8* to do the pointer arithmetic.
mlir::Type baseTy =
getBaseAddrTypeFromBox(adaptor.getOperands()[0].getType());
mlir::Type baseTy = getBaseAddrTypeFromBox(operands[0].getType());
mlir::Value base =
loadBaseAddrFromBox(loc, baseTy, adaptor.getOperands()[0], rewriter);
loadBaseAddrFromBox(loc, baseTy, operands[0], rewriter);
mlir::Type voidPtrTy = getVoidPtrType();
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
llvm::SmallVector<mlir::Value> args{offset};
auto addr =
rewriter.create<mlir::LLVM::GEPOp>(loc, voidPtrTy, base, args);
if (coor.subcomponent().empty()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(coor, baseTy, addr);
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(coor, ty, addr);
return mlir::success();
}
auto casted = rewriter.create<mlir::LLVM::BitcastOp>(loc, baseTy, addr);
Expand All @@ -2168,8 +2166,7 @@ struct XArrayCoorOpConversion
// row-major layout here.
for (auto i = coor.subcomponentOffset(); i != coor.indicesOffset(); ++i)
args.push_back(operands[i]);
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(coor, baseTy, casted,
args);
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(coor, ty, casted, args);
return mlir::success();
}

Expand All @@ -2195,20 +2192,18 @@ struct XArrayCoorOpConversion
}
}
// Cast the base address to a pointer to T.
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, ty,
adaptor.getOperands()[0]);
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, ty, operands[0]);
} else {
// Operand #0 must have a pointer type. For subcomponent slicing, we
// want to cast away the array type and have a plain struct type.
mlir::Type ty0 = adaptor.getOperands()[0].getType();
mlir::Type ty0 = operands[0].getType();
auto ptrTy = ty0.dyn_cast<mlir::LLVM::LLVMPointerType>();
assert(ptrTy && "expected pointer type");
mlir::Type eleTy = ptrTy.getElementType();
while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
eleTy = arrTy.getElementType();
auto newTy = mlir::LLVM::LLVMPointerType::get(eleTy);
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, newTy,
adaptor.getOperands()[0]);
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, newTy, operands[0]);
}
llvm::SmallVector<mlir::Value> args = {offset};
for (auto i = coor.subcomponentOffset(); i != coor.indicesOffset(); ++i)
Expand Down

0 comments on commit 12d26ce

Please sign in to comment.