@@ -24,10 +24,10 @@ SPDX-License-Identifier: MIT
2424#include " llvm/IR/Constants.h"
2525#include " llvm/IR/DerivedTypes.h"
2626#include " llvm/IR/Function.h"
27+ #include " llvm/IR/IRBuilder.h"
2728#include " llvm/IR/Instructions.h"
2829#include " llvm/Support/Debug.h"
2930#include < unordered_map>
30- #include " Probe/Assertion.h"
3131
3232#include " llvmWrapper/IR/DerivedTypes.h"
3333#include " llvmWrapper/Support/TypeSize.h"
@@ -922,14 +922,66 @@ static Instruction* simplifyConstIndirectRegion(Instruction* Inst) {
922922 return Inst;
923923}
924924
925- static Value *simplifyRegionWrite (Instruction *Inst) {
926- IGC_ASSERT (GenXIntrinsic::isWrRegion (Inst));
927- Value *NewVal = Inst->getOperand (GenXIntrinsic::GenXRegion::NewValueOperandNum);
925+ // Fold a bitcast that feeds the new-value operand of a wrregion:
926+ //                                    ==>  %oldval.cast = bitcast(%oldval)
927+ //   %2 = bitcast(%1)                      %3 = wrregion(%oldval.cast, %1, ...)
928+ //   %3 = wrregion(%oldval, %2, ...)       %2 = bitcast(%3)
929+ // so it can be baled later.
//
// Parameters:
//   WrR - instruction for which GenXIntrinsic::isWrRegion holds (asserted).
//   DL  - data layout, forwarded to Region::changeElementType.
//   ST  - subtarget, forwarded to genx::getExecSizeAllowedBits.
//
// Returns the newly created bitcast (of the new wrregion back to WrR's type),
// or nullptr when any of the guards below rejects the transformation. This
// function neither erases WrR nor rewrites its uses; it only creates new
// instructions and hands back the replacement value.
930+ static Value *simplifyBitCastWithRegionWrite (Instruction *WrR,
931+ const DataLayout &DL,
932+ const GenXSubtarget &ST) {
933+ using namespace GenXIntrinsic ::GenXRegion;
934+ IGC_ASSERT (GenXIntrinsic::isWrRegion (WrR));
935+ Value *NewVal = WrR->getOperand (NewValueOperandNum);
// Only applies when the value being written is produced by a bitcast.
936+ auto *BCI = dyn_cast<BitCastInst>(NewVal);
937+ if (!BCI)
938+ return nullptr ;
// Bail out when WrR has exactly one use and that user satisfies
// GenXIntrinsic::isWritePredefReg.
939+ if (WrR->hasOneUse () && GenXIntrinsic::isWritePredefReg (WrR->user_back ()))
940+ return nullptr ;
// The wrregion is retyped to the element type of the bitcast's *source*.
941+ auto *NewScalarTy = BCI->getSrcTy ()->getScalarType ();
942+ // Do not change register category to predicate.
943+ if (NewScalarTy->isIntegerTy (1 ))
944+ return nullptr ;
945+ auto *OldVal = WrR->getOperand (OldValueOperandNum);
// Bail out when the old value satisfies GenXIntrinsic::isReadWritePredefReg.
946+ if (GenXIntrinsic::isReadWritePredefReg (OldVal))
947+ return nullptr ;
// Type the old value must be cast to; a null result means the old value's
// type cannot be expressed with the new element type, so give up.
948+ auto *NewVecTy = genx::changeVectorType (OldVal->getType (), NewScalarTy);
949+ if (!NewVecTy)
950+ return nullptr ;
// Re-express WrR's region in terms of the new element type; give up if
// changeElementType reports failure.
951+ Region R = makeRegionFromBaleInfo (WrR, BaleInfo ());
952+ if (!R.changeElementType (NewScalarTy, &DL))
953+ return nullptr ;
954+ // Transformation is not profitable for 2D regions or if it will require
955+ // legalization.
// NOTE(review): getExecSizeAllowedBits presumably yields the permitted
// execution widths for WrR on this subtarget - confirm against its definition.
956+ if (R.is2D () || R.NumElements > llvm::PowerOf2Floor (
957+ genx::getExecSizeAllowedBits (WrR, &ST)))
958+ return nullptr ;
// The builder is positioned at WrR, so both bitcasts are created before it.
// Creation order is: old-value cast, new wrregion, result cast.
959+ IRBuilder<> IRB (WrR);
960+ auto *OldValCast =
961+ IRB.CreateBitCast (OldVal, NewVecTy, OldVal->getName () + " .cast" );
// The new wrregion writes the bitcast's source operand directly. WrR is
// passed to createWrRegion, presumably as the insertion point - TODO confirm.
962+ auto *NewWrR = R.createWrRegion (OldValCast, BCI->getOperand (0 ),
963+ WrR->getName (), WrR, WrR->getDebugLoc ());
// Cast back to WrR's original type so the result can stand in for WrR.
964+ auto *NewBCI = IRB.CreateBitCast (NewWrR, WrR->getType (), BCI->getName ());
965+ return NewBCI;
966+ }
928967
968+ static Value *simplifyRegionWrite (Instruction *WrR) {
969+ using namespace GenXIntrinsic ::GenXRegion;
970+ IGC_ASSERT (GenXIntrinsic::isWrRegion (WrR));
971+ Value *NewVal = WrR->getOperand (NewValueOperandNum);
972+
973+ // Replace C with B if R - whole region
974+ // C = wrregion(A, B, R)
975+ if (std::none_of (
976+ WrR->user_begin (), WrR->user_end (),
977+ [](auto *U) { return GenXIntrinsic::isWritePredefReg (U); }) &&
978+ makeRegionFromBaleInfo (WrR, BaleInfo ()).isWhole (WrR->getType ()) &&
979+ NewVal->getType () == WrR->getType ())
980+ return NewVal;
929981 // Replace C with A
930982 // C = wrregion(A, undef, R)
931983 if (isa<UndefValue>(NewVal))
932- return Inst ->getOperand (GenXIntrinsic::GenXRegion:: OldValueOperandNum);
984+ return WrR ->getOperand (OldValueOperandNum);
933985
934986 // When A and undef have the same type, replace C with A
935987 // B = rdregion(A, R)
@@ -941,29 +993,68 @@ static Value *simplifyRegionWrite(Instruction *Inst) {
941993 // C = wrregion(A, B, R)
942994 //
943995 if (GenXIntrinsic::isRdRegion (NewVal)) {
944- Instruction *B = cast<Instruction>(NewVal);
945- Region InnerR = makeRegionFromBaleInfo (B , BaleInfo ());
946- Region OuterR = makeRegionFromBaleInfo (Inst , BaleInfo ());
996+ Instruction *RdR = cast<Instruction>(NewVal);
997+ Region InnerR = makeRegionFromBaleInfo (RdR , BaleInfo ());
998+ Region OuterR = makeRegionFromBaleInfo (WrR , BaleInfo ());
947999 if (OuterR != InnerR)
9481000 return nullptr ;
9491001
950- auto OldValB = B ->getOperand (GenXIntrinsic::GenXRegion:: OldValueOperandNum);
951- if (GenXIntrinsic::isReadPredefReg (OldValB ))
1002+ auto OldValRdR = RdR ->getOperand (OldValueOperandNum);
1003+ if (GenXIntrinsic::isReadPredefReg (OldValRdR ))
9521004 return nullptr ;
953- auto OldValC = Inst ->getOperand (GenXIntrinsic::GenXRegion:: OldValueOperandNum);
954- if ((isa<UndefValue>(OldValC ) &&
955- OldValB ->getType () == OldValC ->getType ()) ||
956- OldValB == OldValC )
957- return OldValB ;
1005+ auto OldValWrR = WrR ->getOperand (OldValueOperandNum);
1006+ if ((isa<UndefValue>(OldValWrR ) &&
1007+ OldValRdR ->getType () == OldValWrR ->getType ()) ||
1008+ OldValRdR == OldValWrR )
1009+ return OldValRdR ;
9581010 }
959-
9601011 return nullptr ;
9611012}
9621013
1014+ // Fold a bitcast of a rdregion result into a bitcast of the rdregion input:
1015+ //   %2 = rdregion(%1, ...)   ==>   %3 = bitcast(%1)
1016+ //   %3 = bitcast(%2)               %2 = rdregion(%3, ...)
1017+ // so it can be baled later.
//
// Parameters:
//   BCI - the bitcast whose operand 0 is examined.
//   DL  - data layout, forwarded to Region::changeElementType.
//   ST  - subtarget, forwarded to genx::getExecSizeAllowedBits.
//
// Returns the newly created rdregion (reading from a bitcast of the original
// rdregion's input), or nullptr when any of the guards below rejects the
// transformation. This function neither erases BCI or the old rdregion nor
// rewrites their uses.
1018+ static Value *simplifyBitCastFromRegionRead (BitCastInst *BCI,
1019+ const DataLayout &DL,
1020+ const GenXSubtarget &ST) {
1021+ using namespace GenXIntrinsic ::GenXRegion;
// The bitcast operand must be a rdregion instruction whose only use is
// this bitcast (hasOneUse).
1022+ Instruction *RdR = dyn_cast<Instruction>(BCI->getOperand (0 ));
1023+ if (!RdR || !GenXIntrinsic::isRdRegion (RdR) || !RdR->hasOneUse ())
1024+ return nullptr ;
1025+ auto *OldVal = RdR->getOperand (OldValueOperandNum);
// Bail out when the rdregion input satisfies GenXIntrinsic::isReadPredefReg.
1026+ if (GenXIntrinsic::isReadPredefReg (OldVal))
1027+ return nullptr ;
// The rdregion is retyped to the element type of the bitcast's *destination*.
1028+ auto *NewScalarTy = BCI->getDestTy ()->getScalarType ();
1029+ // Do not change register category to predicate.
1030+ if (NewScalarTy->isIntegerTy (1 ))
1031+ return nullptr ;
// Type the rdregion input must be cast to; a null result means the input's
// type cannot be expressed with the new element type, so give up.
1032+ auto *NewVecTy = genx::changeVectorType (OldVal->getType (), NewScalarTy);
1033+ if (!NewVecTy)
1034+ return nullptr ;
// Re-express the rdregion's region in terms of the new element type; give
// up if changeElementType reports failure.
1035+ Region R = makeRegionFromBaleInfo (RdR, BaleInfo ());
1036+ if (!R.changeElementType (NewScalarTy, &DL))
1037+ return nullptr ;
1038+ // Transformation is not profitable for 2D regions or if it will require
1039+ // legalization.
// NOTE(review): getExecSizeAllowedBits presumably yields the permitted
// execution widths for RdR on this subtarget - confirm against its definition.
1040+ if (R.is2D () || R.NumElements > llvm::PowerOf2Floor (
1041+ genx::getExecSizeAllowedBits (RdR, &ST)))
1042+ return nullptr ;
// Cast the whole rdregion input first; the builder is positioned at BCI, so
// the new bitcast is created before it and reuses BCI's name.
1043+ auto *NewBCI =
1044+ IRBuilder<>(BCI).CreateBitCast (OldVal, NewVecTy, BCI->getName ());
// Then read the retyped region from the cast value. BCI is passed to
// createRdRegion, presumably as the insertion point - TODO confirm.
1045+ auto *NewRdR =
1046+ R.createRdRegion (NewBCI, RdR->getName (), BCI, RdR->getDebugLoc ());
1047+ return NewRdR;
1048+ }
1049+
9631050static Value *simplifyRegionRead (Instruction *Inst) {
9641051 IGC_ASSERT (GenXIntrinsic::isRdRegion (Inst));
9651052 Value *Input = Inst->getOperand (GenXIntrinsic::GenXRegion::OldValueOperandNum);
966- if (isa<UndefValue>(Input))
1053+ if (!GenXIntrinsic::isReadPredefReg (Input) &&
1054+ makeRegionFromBaleInfo (Inst, BaleInfo ()).isWhole (Input->getType ()) &&
1055+ Input->getType () == Inst->getType ())
1056+ return Input;
1057+ else if (isa<UndefValue>(Input))
9671058 return UndefValue::get (Inst->getType ());
9681059 else if (auto C = dyn_cast<Constant>(Input)) {
9691060 if (auto Splat = C->getSplatValue ()) {
@@ -990,7 +1081,8 @@ static Value *simplifyRegionRead(Instruction *Inst) {
9901081}
9911082
9921083// Simplify a region read or write.
993- Value *llvm::genx::simplifyRegionInst (Instruction *Inst, const DataLayout *DL) {
1084+ Value *llvm::genx::simplifyRegionInst (Instruction *Inst, const DataLayout *DL,
1085+ const GenXSubtarget *ST) {
9941086 if (Inst->use_empty ())
9951087 return nullptr ;
9961088
@@ -1013,11 +1105,17 @@ Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL) {
10131105 if (Constant *C = ConstantFoldGenX (Inst, *DL))
10141106 return C;
10151107
1108+ if (auto *BCI = dyn_cast<BitCastInst>(Inst); BCI && DL && ST)
1109+ return simplifyBitCastFromRegionRead (BCI, *DL, *ST);
10161110 ID = GenXIntrinsic::getGenXIntrinsicID (Inst);
10171111 switch (ID) {
10181112 case GenXIntrinsic::genx_wrregionf:
10191113 case GenXIntrinsic::genx_wrregioni:
1020- return simplifyRegionWrite (Inst);
1114+ if (auto *Res = simplifyRegionWrite (Inst))
1115+ return Res;
1116+ if (DL && ST)
1117+ return simplifyBitCastWithRegionWrite (Inst, *DL, *ST);
1118+ break ;
10211119 case GenXIntrinsic::genx_rdregionf:
10221120 case GenXIntrinsic::genx_rdregioni:
10231121 return simplifyRegionRead (Inst);
@@ -1027,12 +1125,13 @@ Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL) {
10271125 return nullptr ;
10281126}
10291127
1030- bool llvm::genx::simplifyRegionInsts (Function *F, const DataLayout *DL) {
1128+ bool llvm::genx::simplifyRegionInsts (Function *F, const DataLayout *DL,
1129+ const GenXSubtarget *ST) {
10311130 bool Changed = false ;
10321131 for (auto &BB : F->getBasicBlockList ()) {
10331132 for (auto I = BB.begin (); I != BB.end ();) {
10341133 Instruction *Inst = &*I++;
1035- if (auto V = simplifyRegionInst (Inst, DL)) {
1134+ if (auto V = simplifyRegionInst (Inst, DL, ST )) {
10361135 Inst->replaceAllUsesWith (V);
10371136 Inst->eraseFromParent ();
10381137 Changed = true ;
0 commit comments