[MLIR][test] Add lit coverage for cf.br/cond_br/switch under narrow-type emulation#198053
Merged
Conversation
|
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Alan Li (lialan) Changes
Full diff: https://github.com/llvm/llvm-project/pull/198053.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 4d6c54d74d2a9..0df84bad2dd65 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/STLFunctionalExtras.h"
namespace mlir {
+class ConversionTarget;
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
@@ -104,6 +105,15 @@ void populateMemRefNarrowTypeEmulationPatterns(
void populateMemRefNarrowTypeEmulationConversions(
arith::NarrowTypeEmulationConverter &typeConverter);
+/// Register patterns + dynamic legality so that cf branch ops carrying
+/// memref values whose element type is being emulated have both their
+/// operand types and their successor block-argument types rewritten to the
+/// container element type. Thin wrapper over
+/// cf::populateCFStructuralTypeConversionsAndLegality.
+void populateMemRefNarrowTypeEmulationCFPatterns(
+ const arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns, ConversionTarget &target);
+
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 1c5e07f89b338..e5e90e3f5ff88 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MLIRAffineUtils
MLIRArithDialect
MLIRArithTransforms
+ MLIRControlFlowTransforms
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUDialect
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index a11e14faa5475..51a147f5fa79f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
@@ -813,3 +814,10 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
newElemTy, layoutAttr, ty.getMemorySpace());
});
}
+
+void memref::populateMemRefNarrowTypeEmulationCFPatterns(
+ const arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns, ConversionTarget &target) {
+ cf::populateCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
+ target);
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-cf.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-cf.mlir
new file mode 100644
index 0000000000000..dc67b776553a0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-cf.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8 arith-compute-bitwidth=1" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
+
+// Sub-byte memref type carried through cf.br block args. The cf branch
+// pattern (registered by cf::populateCFStructuralTypeConversionsAndLegality)
+// must rewrite both the cf.br operand type and the successor block-arg type
+// to the i8 container, so the downstream uses in the successor block see an
+// i8 source.
+
+// CHECK-LABEL: func.func @cf_br_block_arg_narrow_type
+// CHECK-SAME: %[[ARG:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK: cf.br ^[[BB1:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK: ^[[BB1]](%[[BARG:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK: return %[[BARG]]
+// CHECK-NOT: memref<{{[0-9]+}}xi4>
+func.func @cf_br_block_arg_narrow_type(%arg: memref<8xi4>) -> memref<8xi4> {
+ cf.br ^bb1(%arg : memref<8xi4>)
+^bb1(%a: memref<8xi4>):
+ return %a : memref<8xi4>
+}
+
+// -----
+
+// Sub-byte memref carried through both successors of a cf.cond_br. Both
+// branch operand types and both successor block-arg types must be rewritten
+// to the i8 container.
+
+// CHECK-LABEL: func.func @cf_cond_br_block_arg_narrow_type
+// CHECK-SAME: %[[COND:[A-Za-z0-9_]+]]: i1
+// CHECK-SAME: %[[A:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK-SAME: %[[B:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK: cf.cond_br %[[COND]], ^[[BBT:.+]](%[[A]] : memref<{{[0-9]+}}xi8>), ^[[BBF:.+]](%[[B]] : memref<{{[0-9]+}}xi8>)
+// CHECK: ^[[BBT]](%[[XT:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK: return %[[XT]]
+// CHECK: ^[[BBF]](%[[XF:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK: return %[[XF]]
+// CHECK-NOT: memref<{{[0-9]+}}xi4>
+func.func @cf_cond_br_block_arg_narrow_type(%cond: i1, %a: memref<8xi4>, %b: memref<8xi4>) -> memref<8xi4> {
+ cf.cond_br %cond, ^bb1(%a : memref<8xi4>), ^bb2(%b : memref<8xi4>)
+^bb1(%x: memref<8xi4>):
+ return %x : memref<8xi4>
+^bb2(%y: memref<8xi4>):
+ return %y : memref<8xi4>
+}
+
+// -----
+
+// Sub-byte memref carried through the default and case successors of a
+// cf.switch. The branch pattern must rewrite the operand type at every
+// successor edge and the matching block-arg type at every successor.
+
+// CHECK-LABEL: func.func @cf_switch_block_arg_narrow_type
+// CHECK-SAME: %[[FLAG:[A-Za-z0-9_]+]]: i32
+// CHECK-SAME: %[[ARG:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK: cf.switch %[[FLAG]] : i32, [
+// CHECK: default: ^[[BBD:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK: 0: ^[[BB0:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK: 1: ^[[BB1:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK: ]
+// CHECK: ^[[BBD]](%[[XD:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK: return %[[XD]]
+// CHECK: ^[[BB0]](%[[X0:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK: return %[[X0]]
+// CHECK: ^[[BB1]](%[[X1:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK: return %[[X1]]
+// CHECK-NOT: memref<{{[0-9]+}}xi4>
+func.func @cf_switch_block_arg_narrow_type(%flag: i32, %arg: memref<8xi4>) -> memref<8xi4> {
+ cf.switch %flag : i32, [
+ default: ^bb1(%arg : memref<8xi4>),
+ 0: ^bb2(%arg : memref<8xi4>),
+ 1: ^bb3(%arg : memref<8xi4>)
+ ]
+^bb1(%x: memref<8xi4>):
+ return %x : memref<8xi4>
+^bb2(%y: memref<8xi4>):
+ return %y : memref<8xi4>
+^bb3(%z: memref<8xi4>):
+ return %z : memref<8xi4>
+}
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index bec83a8dcbef9..5465c5ccbc610 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -35,8 +36,9 @@ struct TestEmulateNarrowTypePass
void getDependentDialects(DialectRegistry ®istry) const override {
registry
- .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
- vector::VectorDialect, affine::AffineDialect>();
+ .insert<arith::ArithDialect, cf::ControlFlowDialect, func::FuncDialect,
+ memref::MemRefDialect, vector::VectorDialect,
+ affine::AffineDialect>();
}
StringRef getArgument() const final { return "test-emulate-narrow-int"; }
StringRef getDescription() const final {
@@ -104,6 +106,9 @@ struct TestEmulateNarrowTypePass
vector::populateVectorNarrowTypeEmulationPatterns(
typeConverter, patterns, disableAtomicRMW, assumeAligned);
+ memref::populateMemRefNarrowTypeEmulationCFPatterns(
+ typeConverter, patterns, target);
+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
@@ -133,6 +138,7 @@ struct TestEmulateNarrowTypePass
llvm::cl::desc("assume store offsets are aligned to container element "
"boundaries"),
llvm::cl::init(false)};
+
};
struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
krzysz00
reviewed
May 18, 2026
…e emulation Wire cf::populateCFStructuralTypeConversionsAndLegality through the in-tree narrow-type emulation test pass so callers can rewrite cf.br / cf.cond_br / cf.switch operand and successor block-argument types when emulating sub-byte element types. Add lit coverage that exercises memref<NxiW> across cf.br / cf.cond_br / cf.switch, sub-byte integer scalars across cf.br, and sub-byte integer vectors across cf.br.
pedroMVicente
pushed a commit
to pedroMVicente/llvm-project
that referenced
this pull request
May 19, 2026
…ype emulation (llvm#198053) Wires `cf::populateCFStructuralTypeConversionsAndLegality` into the in-tree `TestEmulateNarrowType` pass and adds lit coverage that exercises `cf.br` / `cf.cond_br` / `cf.switch` operand and successor block-argument rewriting when emulating sub-byte element types: * `memref<NxiW>` carried across `cf.br` / `cf.cond_br` / `cf.switch`. * Sub-byte integer scalars across `cf.br`. * Sub-byte integer vectors across `cf.br`. This PR initially added thin wrapper functions (`memref::populateMemRefNarrowTypeEmulationCFPatterns`, `vector::populateVectorNarrowTypeEmulationCFPatterns`) over `cf::populateCFStructuralTypeConversionsAndLegality`. Per review feedback those wrappers were redundant, so callers (including the in-tree test pass) now call `cf::populateCFStructuralTypeConversionsAndLegality` directly. Net contribution is the test-pass plumbing and the new lit tests demonstrating that the existing cf structural type conversion correctly handles narrow-type-emulated values.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Wires
cf::populateCFStructuralTypeConversionsAndLegalityinto thein-tree
TestEmulateNarrowTypepass and adds lit coverage thatexercises
cf.br/cf.cond_br/cf.switchoperand and successorblock-argument rewriting when emulating sub-byte element types:
memref<NxiW>carried acrosscf.br/cf.cond_br/cf.switch.cf.br.cf.br.This PR initially added thin wrapper functions
(
memref::populateMemRefNarrowTypeEmulationCFPatterns,vector::populateVectorNarrowTypeEmulationCFPatterns) overcf::populateCFStructuralTypeConversionsAndLegality. Per reviewfeedback those wrappers were redundant, so callers (including the
in-tree test pass) now call
cf::populateCFStructuralTypeConversionsAndLegalitydirectly. Net contribution is the test-pass plumbing and the new lit
tests demonstrating that the existing cf structural type conversion
correctly handles narrow-type-emulated values.