From 2499cef2b445c9a4772b4376bcf77bae69f6a20e Mon Sep 17 00:00:00 2001 From: Madhur Amilkanthwar Date: Fri, 15 Aug 2025 00:34:49 -0700 Subject: [PATCH] [GVN] Support rnflow pattern matching and transform --- llvm/include/llvm/Transforms/Scalar/GVN.h | 4 + llvm/lib/Transforms/Scalar/GVN.cpp | 122 ++++++++++++++++++ .../test/Transforms/GVN/PRE/rnflow-gvn-pre.ll | 59 +++++++++ 3 files changed, 185 insertions(+) create mode 100644 llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll diff --git a/llvm/include/llvm/Transforms/Scalar/GVN.h b/llvm/include/llvm/Transforms/Scalar/GVN.h index 74a4d6ce00fcc..a73d17b0680de 100644 --- a/llvm/include/llvm/Transforms/Scalar/GVN.h +++ b/llvm/include/llvm/Transforms/Scalar/GVN.h @@ -22,6 +22,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/PassManager.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/Compiler.h" @@ -44,6 +45,7 @@ class FunctionPass; class GetElementPtrInst; class ImplicitControlFlowTracking; class LoadInst; +class SelectInst; class LoopInfo; class MemDepResult; class MemoryAccess; @@ -409,6 +411,8 @@ class GVNPass : public PassInfoMixin { void addDeadBlock(BasicBlock *BB); void assignValNumForDeadCode(); void assignBlockRPONumber(Function &F); + + bool optimizeMinMaxFindingSelectPattern(SelectInst *Select); }; /// Create a legacy GVN pass. diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index b9b5b5823d780..76653e1a01eae 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -2818,6 +2818,10 @@ bool GVNPass::processInstruction(Instruction *I) { } return Changed; } + if (SelectInst *Select = dyn_cast(I)) { + if (optimizeMinMaxFindingSelectPattern(Select)) + return true; + } // Instructions with void type don't return a value, so there's // no point in trying to find redundancies in them. @@ -3410,6 +3414,124 @@ void GVNPass::assignValNumForDeadCode() { } } +bool GVNPass::optimizeMinMaxFindingSelectPattern(SelectInst *Select) { + LLVM_DEBUG( + dbgs() + << "GVN: Analyzing select instruction for minimum finding pattern\n"); + LLVM_DEBUG(dbgs() << "GVN: Select: " << *Select << "\n"); + Value *Condition = Select->getCondition(); + CmpInst *Comparison = dyn_cast(Condition); + if (!Comparison) { + LLVM_DEBUG(dbgs() << "GVN: Condition is not a comparison\n"); + return false; + } + + // Check if this is ULT comparison. + CmpInst::Predicate Pred = Comparison->getPredicate(); + if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT && + Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) { + LLVM_DEBUG(dbgs() << "GVN: Not a less-than comparison, predicate: " << Pred + << "\n"); + return false; + } + + // Check that both operands are loads. + Value *LHS = Comparison->getOperand(0); + Value *RHS = Comparison->getOperand(1); + if (!isa(LHS) || !isa(RHS)) { + LLVM_DEBUG(dbgs() << "GVN: Not both operands are loads\n"); + return false; + } + + LLVM_DEBUG(dbgs() << "GVN: Found minimum finding pattern in Block: " + << Select->getParent()->getName() << "\n"); + + // Transform the pattern. + // Hoist the chain of operations for the second load to preheader. + // Get predecessor of the block containing the select instruction. + BasicBlock *BB = Select->getParent(); + + // Get preheader of the loop. + Loop *L = LI->getLoopFor(BB); + if (!L) { + LLVM_DEBUG(dbgs() << "GVN: Could not find loop\n"); + return false; + } + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) { + LLVM_DEBUG(dbgs() << "GVN: Could not find loop preheader\n"); + return false; + } + + // Hoist the chain of operations for the second load to preheader. + // %90 = sext i32 %.05.i to i64 + // %91 = getelementptr float, ptr %0, i64 %90 ; %0 + (sext i32 %85 to i64)*4 + // %92 = getelementptr i8, ptr %91, i64 -4 ; %0 + (sext i32 %85 to i64)*4 - 4 + // %93 = load float, ptr %92, align 4 + + Value *BasePtr = nullptr, *IndexVal = nullptr, *OffsetVal = nullptr; + IRBuilder<> Builder(Preheader->getTerminator()); + if (match(RHS, + m_Load(m_GEP(m_GEP(m_Value(BasePtr), m_SExt(m_Value(IndexVal))), + m_Value(OffsetVal))))) { + LLVM_DEBUG(dbgs() << "GVN: Found pattern: " << *RHS << "\n"); + LLVM_DEBUG(dbgs() << "GVN: Found pattern: " << "\n"); + + PHINode *Phi = dyn_cast(IndexVal); + if (!Phi) { + LLVM_DEBUG(dbgs() << "GVN: IndexVal is not a PHI node\n"); + return false; + } + Value *InitialMinIndex = Phi->getIncomingValueForBlock(Preheader); + + // Insert PHI node at the top of this block. + PHINode *KnownMinPhi = + PHINode::Create(Builder.getFloatTy(), 2, "known_min", BB->begin()); + + // Build the GEP chain in the preheader. + // 1. hoist_0 = sext i32 to i64 + Value *HoistedSExt = + Builder.CreateSExt(InitialMinIndex, Builder.getInt64Ty(), "hoist_sext"); + + // 2. hoist_gep1 = getelementptr float, ptr BasePtr, i64 HoistedSExt + Value *HoistedGEP1 = Builder.CreateGEP(Builder.getFloatTy(), BasePtr, + HoistedSExt, "hoist_gep1"); + + // 3. hoist_gep2 = getelementptr i8, ptr HoistedGEP1, i64 OffsetVal + Value *HoistedGEP2 = Builder.CreateGEP(Builder.getInt8Ty(), HoistedGEP1, + OffsetVal, "hoist_gep2"); + + // 4. hoisted_load = load float, ptr HoistedGEP2 + LoadInst *NewLoad = + Builder.CreateLoad(Builder.getFloatTy(), HoistedGEP2, "hoisted_load"); + + // Replace all uses of load with new load. + RHS->replaceAllUsesWith(NewLoad); + dyn_cast(RHS)->eraseFromParent(); + + // Replace second operand of comparison with KnownMinPhi. + Comparison->setOperand(1, KnownMinPhi); + + // Create new select instruction for selecting the minimum value. + IRBuilder<> SelectBuilder(BB->getTerminator()); + SelectInst *CurrentMinSelect = + dyn_cast(SelectBuilder.CreateSelect( + Comparison, LHS, KnownMinPhi, "current_min")); + + // Populate PHI node. + KnownMinPhi->addIncoming(NewLoad, Preheader); + KnownMinPhi->addIncoming(CurrentMinSelect, BB); + std::cout << "Transformed the code\n"; + return true; + } else { + LLVM_DEBUG(dbgs() << "GVN: Could not find pattern: " << *RHS << "\n"); + std::cout << "GVN: Could not find pattern: " << "\n"; + return false; + } + return false; +} + + class llvm::gvn::GVNLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid. diff --git a/llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll b/llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll new file mode 100644 index 0000000000000..6f17d4ab30240 --- /dev/null +++ b/llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll @@ -0,0 +1,59 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6 +; Minimal test case containing only the .lr.ph.i basic block +; RUN: opt -passes=gvn -S < %s | FileCheck %s + +define void @test_lr_ph_i(ptr %0) { +; CHECK-LABEL: define void @test_lr_ph_i( +; CHECK-SAME: ptr [[TMP0:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*]]: +; CHECK-NEXT: [[HOIST_GEP1:%.*]] = getelementptr float, ptr [[TMP0]], i64 1 +; CHECK-NEXT: [[HOIST_GEP2:%.*]] = getelementptr i8, ptr [[HOIST_GEP1]], i64 -4 +; CHECK-NEXT: [[HOISTED_LOAD:%.*]] = load float, ptr [[HOIST_GEP2]], align 4 +; CHECK-NEXT: br label %[[DOTLR_PH_I:.*]] +; CHECK: [[_LR_PH_I:.*:]] +; CHECK-NEXT: [[KNOWN_MIN:%.*]] = phi float [ [[HOISTED_LOAD]], %[[ENTRY]] ], [ [[CURRENT_MIN:%.*]], %[[DOTLR_PH_I]] ] +; CHECK-NEXT: [[INDVARS_IV_I:%.*]] = phi i64 [ 1, %[[ENTRY]] ], [ [[INDVARS_IV_NEXT_I:%.*]], %[[DOTLR_PH_I]] ] +; CHECK-NEXT: [[TMP1:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[TMP10:%.*]], %[[DOTLR_PH_I]] ] +; CHECK-NEXT: [[DOT05_I:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ [[DOT1_I:%.*]], %[[DOTLR_PH_I]] ] +; CHECK-NEXT: [[INDVARS_IV_NEXT_I]] = add nsw i64 [[INDVARS_IV_I]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[INDVARS_IV_I]] +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP2]], i64 -8 +; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[TMP3]], align 4 +; CHECK-NEXT: [[TMP5:%.*]] = sext i32 [[DOT05_I]] to i64 +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i64 -4 +; CHECK-NEXT: [[TMP8:%.*]] = fcmp contract olt float [[TMP4]], [[KNOWN_MIN]] +; CHECK-NEXT: [[TMP9:%.*]] = trunc nsw i64 [[INDVARS_IV_NEXT_I]] to i32 +; CHECK-NEXT: [[DOT1_I]] = select i1 [[TMP8]], i32 [[TMP9]], i32 [[DOT05_I]] +; CHECK-NEXT: [[TMP10]] = add nsw i64 [[TMP1]], -1 +; CHECK-NEXT: [[TMP11:%.*]] = icmp samesign ugt i64 [[TMP1]], 1 +; CHECK-NEXT: [[CURRENT_MIN]] = select i1 [[TMP8]], float [[TMP4]], float [[KNOWN_MIN]] +; CHECK-NEXT: br i1 [[TMP11]], label %[[DOTLR_PH_I]], label %[[EXIT:.*]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: ret void +; +entry: + br label %.lr.ph.i + +.lr.ph.i: ; preds = %.lr.ph.i, %entry + %indvars.iv.i = phi i64 [ 1, %entry ], [ %indvars.iv.next.i, %.lr.ph.i ] + %86 = phi i64 [ 0, %entry ], [ %96, %.lr.ph.i ] + %.05.i = phi i32 [ 1, %entry ], [ %.1.i, %.lr.ph.i ] + %indvars.iv.next.i = add nsw i64 %indvars.iv.i, -1 + %87 = getelementptr float, ptr %0, i64 %indvars.iv.i + %88 = getelementptr i8, ptr %87, i64 -8 ; first load : %0 + 4 * 1 - 8 + %89 = load float, ptr %88, align 4 + %90 = sext i32 %.05.i to i64 + %91 = getelementptr float, ptr %0, i64 %90 ; %0 + 4 * 1 + %92 = getelementptr i8, ptr %91, i64 -4 ; second load : %0 + 4 * 1 - 4 + %93 = load float, ptr %92, align 4 + %94 = fcmp contract olt float %89, %93 + %95 = trunc nsw i64 %indvars.iv.next.i to i32 + %.1.i = select i1 %94, i32 %95, i32 %.05.i + %96 = add nsw i64 %86, -1 + %97 = icmp samesign ugt i64 %86, 1 + br i1 %97, label %.lr.ph.i, label %exit + +exit: + ret void +}