@terapines-osc-mlir (Contributor) commented Sep 12, 2025

This PR adds an initial pass that converts eligible scf.if, scf.for, and scf.parallel ops into affine.if, affine.for, and affine.parallel. Compared to SCF, the Affine dialect enables stronger analysis and optimization, including dependence analysis, loop tiling/fusion, and parallelization. Promoting analyzable SCF loops to Affine opens up polyhedral-style optimizations at the core dialect level.

Related RFC: https://discourse.llvm.org/t/rfc-add-scf-to-affine-conversion-pass-in-mlir/88036
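
As a rough illustration of what "eligible" can mean here, the sketch below classifies a toy expression tree as affine or not. It is a simplified model, not the legality check used by the patch: the names `Expr`, `leaf`, `node`, and `isAffine` are hypothetical, and the rule is the textbook one (sums and differences of affine terms are affine; a product is affine only if one factor is a constant).

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Toy expression tree. This is NOT the data structure used by the patch;
// it only models loop bounds and conditions built from integer arithmetic.
struct Expr {
  enum Kind { Const, InductionVar, Symbol, Add, Sub, Mul } kind;
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

// Build a leaf (constant, loop induction variable, or loop-invariant symbol).
ExprPtr leaf(Expr::Kind k) {
  return std::make_shared<Expr>(Expr{k, nullptr, nullptr});
}

// Build a binary node.
ExprPtr node(Expr::Kind k, ExprPtr l, ExprPtr r) {
  return std::make_shared<Expr>(Expr{k, std::move(l), std::move(r)});
}

// Simplified rule: leaves are affine, add/sub preserve affineness, and a
// product is affine only when at least one factor is a literal constant.
bool isAffine(const ExprPtr &e) {
  switch (e->kind) {
  case Expr::Const:
  case Expr::InductionVar:
  case Expr::Symbol:
    return true;
  case Expr::Add:
  case Expr::Sub:
    return isAffine(e->lhs) && isAffine(e->rhs);
  case Expr::Mul:
    return isAffine(e->lhs) && isAffine(e->rhs) &&
           (e->lhs->kind == Expr::Const || e->rhs->kind == Expr::Const);
  }
  return false;
}
```

Under this model a bound such as `i * 4 + n` is accepted, while `i * i` is rejected because neither factor is a constant. The real pass reasons about MLIR values and dominance, which this sketch leaves out entirely.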

@llvmbot added the mlir label Sep 12, 2025

@llvmbot (Member) commented Sep 12, 2025

@llvm/pr-subscribers-mlir

Author: Terapines MLIR (terapines-osc-mlir)

Changes

Based on Polygeist commit 77c04bb2a7a2406ca9480bcc9e729b07d2c8d077


Patch is 104.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158267.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.h (+4)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+18)
  • (added) mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h (+14)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp (+1559)
  • (added) mlir/lib/Conversion/SCFToAffine/CMakeLists.txt (+27)
  • (added) mlir/lib/Conversion/SCFToAffine/Ops.cpp (+359)
  • (added) mlir/lib/Conversion/SCFToAffine/Ops.h (+114)
  • (added) mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp (+296)
  • (added) mlir/test/Conversion/SCFToAffine/affinecfg.mlir (+166)
  • (added) mlir/test/Conversion/SCFToAffine/affraise.mlir (+48)
  • (added) mlir/test/Conversion/SCFToAffine/affraise2.mlir (+31)
  • (added) mlir/test/Conversion/SCFToAffine/affraise3.mlir (+95)
  • (added) mlir/test/Conversion/SCFToAffine/raisescffor.mlir (+25)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..fa4bcb5bce5db 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -58,6 +58,7 @@
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
@@ -87,6 +88,9 @@
 
 namespace mlir {
 
+std::unique_ptr<Pass> replaceAffineCFGPass();
+std::unique_ptr<Pass> createRaiseSCFToAffinePass();
+
 /// Generate the code for registering conversion passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Conversion/Passes.h.inc"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..85f49448e38da 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1029,6 +1029,24 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SCFToAffine
+//===----------------------------------------------------------------------===//
+def AffineCFG : Pass<"affine-cfg"> {
+  let summary = "Replace scf.if and similar with affine.if";
+  let constructor = "mlir::replaceAffineCFGPass()";
+}
+
+def RaiseSCFToAffine : Pass<"raise-scf-to-affine"> {
+  let summary = "Raise SCF to affine";
+  let constructor = "mlir::createRaiseSCFToAffinePass()";
+  let dependentDialects = [
+    "affine::AffineDialect",
+    "scf::SCFDialect",
+  ];
+}
+
+
 //===----------------------------------------------------------------------===//
 // SCFToControlFlow
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
new file mode 100644
index 0000000000000..372d19d60fdb3
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
@@ -0,0 +1,14 @@
+#ifndef __MLIR_CONVERSION_SCFTOAFFINE_H
+#define __MLIR_CONVERSION_SCFTOAFFINE_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_RAISESCFTOAFFINEPASS
+#define GEN_PASS_DECL_AFFINECFGPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // __MLIR_CONVERSION_SCFTOAFFINE_H
\ No newline at end of file
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..d9da085378834 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -52,6 +52,7 @@ add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(PtrToLLVM)
 add_subdirectory(ReconcileUnrealizedCasts)
+add_subdirectory(SCFToAffine)
 add_subdirectory(SCFToControlFlow)
 add_subdirectory(SCFToEmitC)
 add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp b/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp
new file mode 100644
index 0000000000000..69e46749c01e1
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp
@@ -0,0 +1,1559 @@
+#include "./Ops.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+#include <deque>
+
+#define DEBUG_TYPE "affine-cfg"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::affine;
+
+namespace mlir {
+
+#define GEN_PASS_DEF_AFFINECFG
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+bool isValidIndex(Value val);
+
+bool isReadOnly(Operation *op);
+
+bool isValidSymbolInt(Value value, bool recur = true);
+bool isValidSymbolInt(Operation *defOp, bool recur) {
+  Attribute operandCst;
+  if (matchPattern(defOp, m_Constant(&operandCst)))
+    return true;
+
+  if (recur) {
+    if (isa<SelectOp, IndexCastOp, AddIOp, MulIOp, DivSIOp, DivUIOp, RemSIOp,
+            RemUIOp, SubIOp, CmpIOp, TruncIOp, ExtUIOp, ExtSIOp>(defOp))
+      if (llvm::all_of(defOp->getOperands(),
+                       [&](Value v) { return isValidSymbolInt(v, recur); }))
+        return true;
+    if (auto ifOp = mlir::dyn_cast<scf::IfOp>(defOp)) {
+      if (isValidSymbolInt(ifOp.getCondition(), recur)) {
+        if (llvm::all_of(
+                ifOp.thenBlock()->without_terminator(),
+                [&](Operation &o) { return isValidSymbolInt(&o, recur); }) &&
+            llvm::all_of(
+                ifOp.elseBlock()->without_terminator(),
+                [&](Operation &o) { return isValidSymbolInt(&o, recur); }))
+          return true;
+      }
+    }
+    if (auto ifOp = dyn_cast<affine::AffineIfOp>(defOp)) {
+      if (llvm::all_of(ifOp.getOperands(),
+                       [&](Value o) { return isValidSymbolInt(o, recur); }))
+        if (llvm::all_of(
+                ifOp.getThenBlock()->without_terminator(),
+                [&](Operation &o) { return isValidSymbolInt(&o, recur); }) &&
+            llvm::all_of(
+                ifOp.getElseBlock()->without_terminator(),
+                [&](Operation &o) { return isValidSymbolInt(&o, recur); }))
+          return true;
+    }
+  }
+  return false;
+}
+
+// isValidSymbol, even if not index
+bool isValidSymbolInt(Value value, bool recur) {
+  // Check that the value is a top level value.
+  if (affine::isTopLevelValue(value))
+    return true;
+
+  if (auto *defOp = value.getDefiningOp()) {
+    if (isValidSymbolInt(defOp, recur))
+      return true;
+    return affine::isValidSymbol(value, affine::getAffineScope(defOp));
+  }
+
+  return false;
+}
+
+struct AffineApplyNormalizer {
+  AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands,
+                        PatternRewriter &rewriter, DominanceInfo &DI);
+
+  /// Returns the AffineMap resulting from normalization.
+  AffineMap getAffineMap() { return affineMap; }
+
+  SmallVector<Value, 8> getOperands() {
+    SmallVector<Value, 8> res(reorderedDims);
+    res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
+    return res;
+  }
+
+private:
+  /// Helper function to insert `v` into the coordinate system of the current
+  /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
+  /// renumbered position.
+  AffineDimExpr renumberOneDim(Value v);
+
+  /// Maps of Value to position in `affineMap`.
+  DenseMap<Value, unsigned> dimValueToPosition;
+
+  /// Ordered dims and symbols matching positional dims and symbols in
+  /// `affineMap`.
+  SmallVector<Value, 8> reorderedDims;
+  SmallVector<Value, 8> concatenatedSymbols;
+
+  AffineMap affineMap;
+};
+
+static bool isAffineForArg(Value val) {
+  if (!mlir::isa<BlockArgument>(val))
+    return false;
+  Operation *parentOp =
+      mlir::cast<BlockArgument>(val).getOwner()->getParentOp();
+  return (
+      isa_and_nonnull<affine::AffineForOp, affine::AffineParallelOp>(parentOp));
+}
+
+static bool legalCondition(Value en, bool dim = false) {
+  if (en.getDefiningOp<affine::AffineApplyOp>())
+    return true;
+
+  if (!dim && !isValidSymbolInt(en, /*recur*/ false)) {
+    if (isValidIndex(en) || isValidSymbolInt(en, /*recur*/ true)) {
+      return true;
+    }
+  }
+
+  while (auto ic = en.getDefiningOp<IndexCastOp>())
+    en = ic.getIn();
+
+  if ((en.getDefiningOp<AddIOp>() || en.getDefiningOp<SubIOp>() ||
+       en.getDefiningOp<MulIOp>() || en.getDefiningOp<RemUIOp>() ||
+       en.getDefiningOp<RemSIOp>()) &&
+      (en.getDefiningOp()->getOperand(1).getDefiningOp<ConstantIntOp>() ||
+       en.getDefiningOp()->getOperand(1).getDefiningOp<ConstantIndexOp>()))
+    return true;
+  // if (auto IC = dyn_cast_or_null<IndexCastOp>(en.getDefiningOp())) {
+  //	if (!outer || legalCondition(IC.getOperand(), false)) return true;
+  //}
+  if (!dim)
+    if (auto BA = dyn_cast<BlockArgument>(en)) {
+      if (isa<affine::AffineForOp, affine::AffineParallelOp>(
+              BA.getOwner()->getParentOp()))
+        return true;
+    }
+  return false;
+}
+
+/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
+/// keep a correspondence between the mathematical `map` and the `operands` of
+/// a given affine::AffineApplyOp. This correspondence is maintained by
+/// iterating over the operands and forming an `auxiliaryMap` that can be
+/// composed mathematically with `map`. To keep this correspondence in cases
+/// where symbols are produced by affine.apply operations, we perform a local
+/// rewrite of symbols as dims.
+///
+/// Rationale for locally rewriting symbols as dims:
+/// ================================================
+/// The mathematical composition of AffineMap must always concatenate symbols
+/// because it does not have enough information to do otherwise. For example,
+/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
+/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
+///
+/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
+/// applied to the same mlir::Value for both s0 and s1.
+/// As a consequence mathematical composition of AffineMap always concatenates
+/// symbols.
+///
+/// When AffineMaps are used in affine::AffineApplyOp however, they may specify
+/// composition via symbols, which is ambiguous mathematically. This corner case
+/// is handled by locally rewriting such symbols that come from
+/// affine::AffineApplyOp into dims and composing through dims.
+/// TODO: Composition via symbols comes at a significant code
+/// complexity. Alternatively we should investigate whether we want to
+/// explicitly disallow symbols coming from affine.apply and instead force the
+/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
+/// extra API calls for such uses, which haven't popped up until now) and the
+/// benefit potentially big: simpler and more maintainable code for a
+/// non-trivial, recursive, procedure.
+AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
+                                             ArrayRef<Value> operands,
+                                             PatternRewriter &rewriter,
+                                             DominanceInfo &DI) {
+  assert(map.getNumInputs() == operands.size() &&
+         "number of operands does not match the number of map inputs");
+
+  LLVM_DEBUG(map.print(llvm::dbgs() << "\nInput map: "));
+
+  SmallVector<Value, 8> addedValues;
+
+  llvm::SmallSet<unsigned, 1> symbolsToPromote;
+
+  unsigned numDims = map.getNumDims();
+  unsigned numSymbols = map.getNumSymbols();
+
+  SmallVector<AffineExpr, 8> dimReplacements;
+  SmallVector<AffineExpr, 8> symReplacements;
+
+  SmallVector<SmallVectorImpl<Value> *> opsTodos;
+  auto replaceOp = [&](Operation *oldOp, Operation *newOp) {
+    for (auto [oldV, newV] :
+         llvm::zip(oldOp->getResults(), newOp->getResults()))
+      for (auto *ops : opsTodos)
+        for (auto &op : *ops)
+          if (op == oldV)
+            op = newV;
+  };
+
+  std::function<Value(Value, bool)> fix = [&](Value v,
+                                              bool index) -> Value /*legal*/ {
+    if (isValidSymbolInt(v, /*recur*/ false))
+      return v;
+    if (index && isAffineForArg(v))
+      return v;
+    auto *op = v.getDefiningOp();
+    if (!op)
+      return nullptr;
+    if (!op)
+      llvm::errs() << v << "\n";
+    assert(op);
+    if (isa<ConstantOp>(op) || isa<ConstantIndexOp>(op))
+      return v;
+    if (!isReadOnly(op)) {
+      return nullptr;
+    }
+    Operation *front = nullptr;
+    SmallVector<Value> ops;
+    opsTodos.push_back(&ops);
+    std::function<void(Operation *)> getAllOps = [&](Operation *todo) {
+      for (auto v : todo->getOperands()) {
+        if (llvm::all_of(op->getRegions(), [&](Region &r) {
+              return !r.isAncestor(v.getParentRegion());
+            }))
+          ops.push_back(v);
+      }
+      for (auto &r : todo->getRegions()) {
+        for (auto &b : r.getBlocks())
+          for (auto &o2 : b.without_terminator())
+            getAllOps(&o2);
+      }
+    };
+    getAllOps(op);
+    for (auto o : ops) {
+      Operation *next;
+      if (auto *op = o.getDefiningOp()) {
+        if (Value nv = fix(o, index)) {
+          op = nv.getDefiningOp();
+        } else {
+          return nullptr;
+        }
+        next = op->getNextNode();
+      } else {
+        auto ba = mlir::cast<BlockArgument>(o);
+        if (index && isAffineForArg(ba)) {
+        } else if (!isValidSymbolInt(o, /*recur*/ false)) {
+          return nullptr;
+        }
+        next = &ba.getOwner()->front();
+      }
+      if (front == nullptr)
+        front = next;
+      else if (DI.dominates(front, next))
+        front = next;
+    }
+    opsTodos.pop_back();
+    if (!front)
+      op->dump();
+    assert(front);
+    PatternRewriter::InsertionGuard B(rewriter);
+    rewriter.setInsertionPoint(front);
+    auto *cloned = rewriter.clone(*op);
+    replaceOp(op, cloned);
+    rewriter.replaceOp(op, cloned->getResults());
+    return cloned->getResult(0);
+  };
+  auto renumberOneSymbol = [&](Value v) {
+    for (auto i : llvm::enumerate(addedValues)) {
+      if (i.value() == v)
+        return getAffineSymbolExpr(i.index(), map.getContext());
+    }
+    auto expr = getAffineSymbolExpr(addedValues.size(), map.getContext());
+    addedValues.push_back(v);
+    return expr;
+  };
+
+  // 2. Compose affine::AffineApplyOps and dispatch dims or symbols.
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    auto t = operands[i];
+    auto decast = t;
+    while (true) {
+      if (auto idx = decast.getDefiningOp<IndexCastOp>()) {
+        decast = idx.getIn();
+        continue;
+      }
+      if (auto idx = decast.getDefiningOp<ExtUIOp>()) {
+        decast = idx.getIn();
+        continue;
+      }
+      if (auto idx = decast.getDefiningOp<ExtSIOp>()) {
+        decast = idx.getIn();
+        continue;
+      }
+      break;
+    }
+
+    if (!isValidSymbolInt(t, /*recur*/ false)) {
+      t = decast;
+    }
+
+    // Only promote one at a time, lest we end up with two dimensions
+    // multiplying each other.
+
+    if (((!isValidSymbolInt(t, /*recur*/ false) &&
+          (t.getDefiningOp<AddIOp>() || t.getDefiningOp<SubIOp>() ||
+           (t.getDefiningOp<MulIOp>() &&
+            ((isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+              isValidSymbolInt(t.getDefiningOp()->getOperand(1))) ||
+             (isValidIndex(t.getDefiningOp()->getOperand(1)) &&
+              isValidSymbolInt(t.getDefiningOp()->getOperand(0)))) &&
+            !(fix(t.getDefiningOp()->getOperand(0), false) &&
+              fix(t.getDefiningOp()->getOperand(1), false))
+
+                ) ||
+           ((t.getDefiningOp<DivUIOp>() || t.getDefiningOp<DivSIOp>()) &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1))) &&
+            (!(fix(t.getDefiningOp()->getOperand(0), false) &&
+               fix(t.getDefiningOp()->getOperand(1), false)))) ||
+           (t.getDefiningOp<DivSIOp>() &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+           (t.getDefiningOp<RemUIOp>() &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+           (t.getDefiningOp<RemSIOp>() &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+           t.getDefiningOp<ConstantIntOp>() ||
+           t.getDefiningOp<ConstantIndexOp>())) ||
+         ((decast.getDefiningOp<AddIOp>() || decast.getDefiningOp<SubIOp>() ||
+           decast.getDefiningOp<MulIOp>() || decast.getDefiningOp<RemUIOp>() ||
+           decast.getDefiningOp<RemSIOp>()) &&
+          (decast.getDefiningOp()
+               ->getOperand(1)
+               .getDefiningOp<ConstantIntOp>() ||
+           decast.getDefiningOp()
+               ->getOperand(1)
+               .getDefiningOp<ConstantIndexOp>())))) {
+      t = decast;
+      LLVM_DEBUG(llvm::dbgs() << " Replacing: " << t << "\n");
+
+      AffineMap affineApplyMap;
+      SmallVector<Value, 8> affineApplyOperands;
+
+      // llvm::dbgs() << "\nop to start: " << t << "\n";
+
+      if (auto op = t.getDefiningOp<AddIOp>()) {
+        affineApplyMap =
+            AffineMap::get(0, 2,
+                           getAffineSymbolExpr(0, op.getContext()) +
+                               getAffineSymbolExpr(1, op.getContext()));
+        affineApplyOperands.push_back(op.getLhs());
+        affineApplyOperands.push_back(op.getRhs());
+      } else if (auto op = t.getDefiningOp<SubIOp>()) {
+        affineApplyMap =
+            AffineMap::get(0, 2,
+                           getAffineSymbolExpr(0, op.getContext()) -
+                               getAffineSymbolExpr(1, op.getContext()));
+        affineApplyOperands.push_back(op.getLhs());
+        affineApplyOperands.push_back(op.getRhs());
+      } else if (auto op = t.getDefiningOp<MulIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap =
+              AffineMap::get(0, 2,
+                             getAffineSymbolExpr(0, op.getContext()) *
+                                 getAffineSymbolExpr(1, op.getContext()));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.push_back(op.getRhs());
+        }
+      } else if (auto op = t.getDefiningOp<DivSIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap = AffineMap::get(
+              0, 2,
+              getAffineSymbolExpr(0, op.getContext())
+                  .floorDiv(getAffineSymbolExpr(1, op.getContext())));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.push_back(op.getRhs());
+        }
+      } else if (auto op = t.getDefiningOp<DivUIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap = AffineMap::get(
+              0, 2,
+              getAffineSymbolExpr(0, op.getContext())
+                  .floorDiv(getAffineSymbolExpr(1, op.getContext())));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.pus...
[truncated]
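
The comment on `AffineApplyNormalizer` in the diff above explains why composing `(d0)[s0] -> (d0 + s0)` with itself must concatenate symbols. The standalone sketch below works that example numerically and then shows a lookup-then-append helper with the same shape as the `renumberOneSymbol` lambda. It does not use MLIR: `ValueName` is a hypothetical string stand-in for `mlir::Value`, and `f`, `composed`, `folded`, and `renumberSymbol` are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// f(d0)[s0] = d0 + s0, the map used in the comment.
int64_t f(int64_t d0, int64_t s0) { return d0 + s0; }

// Composing f with itself: each application brings its own symbol, so the
// result is (d0)[s0, s1] -> d0 + s0 + s1.
int64_t composed(int64_t d0, int64_t s0, int64_t s1) {
  return f(f(d0, s0), s1);
}

// The folded form (d0)[s0] -> d0 + 2 * s0. It matches `composed` only when
// both symbols are bound to the same value.
int64_t folded(int64_t d0, int64_t s0) { return d0 + 2 * s0; }

// Hypothetical stand-in for mlir::Value: an SSA name as a string.
using ValueName = std::string;

// Returns the symbol position for `v`, reusing the existing position when
// `v` was already added and appending it otherwise.
std::size_t renumberSymbol(std::vector<ValueName> &added, const ValueName &v) {
  for (std::size_t i = 0; i < added.size(); ++i)
    if (added[i] == v)
      return i;
  added.push_back(v);
  return added.size() - 1;
}
```

With `s0 == s1` the composed and folded maps agree; with different values they do not, which is why the purely mathematical composition cannot merge the two symbols. Merging becomes safe once the operands are known to be the same value, and that is the information the deduplicating lookup provides.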

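The `DivSIOp` branch in the diff builds a `floorDiv` affine expression. As documented, `arith.divsi` rounds toward zero while affine `floordiv` rounds toward negative infinity, so the two agree only when the quotient is exact or non-negative. The sketch below shows the difference in plain C++; `truncDiv` and `floorDiv` are illustrative helpers, not functions from the patch, and whether the pass's surrounding checks already exclude negative operands is not visible in the truncated diff.

```cpp
#include <cassert>
#include <cstdint>

// Truncating division: C++ `/` on signed integers rounds toward zero, the
// same rounding that arith.divsi specifies.
int64_t truncDiv(int64_t a, int64_t b) { return a / b; }

// Floor division: rounds toward negative infinity, as affine floordiv does.
// Assumes b != 0 and that a / b does not overflow.
int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  int64_t r = a % b;
  // The truncated quotient is one too large exactly when the remainder is
  // nonzero and its sign differs from the divisor's sign.
  if (r != 0 && ((r < 0) != (b < 0)))
    --q;
  return q;
}
```

For example, `truncDiv(-7, 2)` is `-3` but `floorDiv(-7, 2)` is `-4`, whereas both give `3` for `7 / 2`. A raising pattern that replaces one with the other is therefore only value-preserving when the operand signs rule out the mismatch.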
@rengolin (Member)

What's the overall motivation behind this raising pass?

@terapines-osc-mlir (Contributor, Author)

> What's the overall motivation behind this raising pass?

Some MLIR frontends, such as ClangIR and Flang, as well as downstream projects, generate SCF loops by default.
Raising these loops to Affine unlocks the advanced analyses and polyhedral-style optimizations available in the Affine dialect.

I have edited the first comment and attached the RFC link. 🤝

@NexMing (Contributor) commented Sep 15, 2025

Please update the commit message to follow the contribution guidelines: https://mlir.llvm.org/getting_started/Contributing/#commit-messages

@wsmoses (Member) commented Sep 15, 2025

First of all, thank you so much for helping with this. It would be awesome to see this upstream, as many projects would benefit from the Polygeist raising and other transformations.

The most up-to-date version of the code currently lives here (it raises not just loops but also memory, etc.): https://github.com/EnzymeAD/Enzyme-JAX/blob/main/src/enzyme_ad/jax/Passes/AffineCFG.cpp

At this point, might I suggest that we add a separate Polygeist dialect upstream and put the requisite pass under it?

cc @Pangoraw @ftynse @ivanradanov
