diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1a7eb46f752921..1a9604882fe09c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2363,8 +2363,6 @@ void mlir::python::populateIRCore(py::module &m) { [](const std::vector &pyLocations, llvm::Optional metadata, DefaultingPyMlirContext context) { - if (pyLocations.empty()) - throw py::value_error("No locations provided"); llvm::SmallVector locations; locations.reserve(pyLocations.size()); for (auto &pyLocation : pyLocations) diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp index 1de4d73bbe4fa0..ce88b244a90df8 100644 --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -106,10 +106,18 @@ Location FusedLoc::get(ArrayRef locs, Attribute metadata, } locs = decomposedLocs.getArrayRef(); - // Handle the simple cases of less than two locations. - if (locs.empty()) - return UnknownLoc::get(context); - if (locs.size() == 1) + // Handle the simple cases of less than two locations. Ensure the metadata (if + // provided) is not dropped. + if (locs.empty()) { + if (!metadata) + return UnknownLoc::get(context); + // TODO: Investigate ASAN failure when using implicit conversion from + // Location to ArrayRef below. + return Base::get(context, ArrayRef{UnknownLoc::get(context)}, + metadata); + } + if (locs.size() == 1 && !metadata) return locs.front(); + return Base::get(context, locs, metadata); } diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index 0016c3ec6611be..f6c4f21cfd7c6d 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -21,6 +21,10 @@ func @inline_notation() -> i32 { affine.if #set0(%2) { } loc(fused<"myPass">["foo", "foo2"]) + // CHECK: } loc(fused<"myPass">["foo"]) + affine.if #set0(%2) { + } loc(fused<"myPass">["foo"]) + // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc(unknown) } diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py index 1c13c4870cbe06..ecdd02efb0aee2 100644 --- a/mlir/test/python/ir/location.py +++ b/mlir/test/python/ir/location.py @@ -78,12 +78,20 @@ def testCallSite(): # CHECK-LABEL: TEST: testFused def testFused(): with Context() as ctx: + loc_single = Location.fused([Location.name("apple")]) loc = Location.fused( [Location.name("apple"), Location.name("banana")]) attr = Attribute.parse('"sauteed"') loc_attr = Location.fused([Location.name("carrot"), Location.name("potatoes")], attr) + loc_empty = Location.fused([]) + loc_empty_attr = Location.fused([], attr) + loc_single_attr = Location.fused([Location.name("apple")], attr) ctx = None + # CHECK: file str: loc("apple") + print("file str:", str(loc_single)) + # CHECK: file repr: loc("apple") + print("file repr:", repr(loc_single)) # CHECK: file str: loc(fused["apple", "banana"]) print("file str:", str(loc)) # CHECK: file repr: loc(fused["apple", "banana"]) @@ -92,6 +100,18 @@ def testFused(): print("file str:", str(loc_attr)) # CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"]) print("file repr:", repr(loc_attr)) + # CHECK: file str: loc(unknown) + print("file str:", str(loc_empty)) + # CHECK: file repr: loc(unknown) + print("file repr:", repr(loc_empty)) + # CHECK: file str: loc(fused<"sauteed">[unknown]) + print("file str:", str(loc_empty_attr)) + # CHECK: file repr: loc(fused<"sauteed">[unknown]) + print("file repr:", repr(loc_empty_attr)) + # CHECK: file str: loc(fused<"sauteed">["apple"]) + print("file str:", str(loc_single_attr)) + # CHECK: file repr: loc(fused<"sauteed">["apple"]) + print("file repr:", repr(loc_single_attr)) run(testFused)