From 8414b254a9f7213b30358343e69b9e1e12f9307a Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Tue, 24 Feb 2026 15:18:09 +0100 Subject: [PATCH] Python: Support ternary comparisons in PatternMatchingComparator --- .../src/rewrite/python/template/comparator.py | 17 +++++ .../tests/python/template/test_comparator.py | 64 +++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/rewrite-python/rewrite/src/rewrite/python/template/comparator.py b/rewrite-python/rewrite/src/rewrite/python/template/comparator.py index 7b9c434cfc0..76f4b770b39 100644 --- a/rewrite-python/rewrite/src/rewrite/python/template/comparator.py +++ b/rewrite-python/rewrite/src/rewrite/python/template/comparator.py @@ -158,6 +158,8 @@ def _compare( return self._compare_assignment(pattern, cast(j.Assignment, target), cursor) elif isinstance(pattern, j.Parentheses): return self._compare_parentheses(pattern, cast(j.Parentheses, target), cursor) + elif isinstance(pattern, j.Ternary): + return self._compare_ternary(pattern, cast(j.Ternary, target), cursor) elif isinstance(pattern, j.Return): return self._compare_return(pattern, cast(j.Return, target), cursor) elif isinstance(pattern, py.ExpressionStatement): @@ -344,6 +346,21 @@ def _compare_assignment( return self._compare(pattern.assignment, target.assignment, cursor) + def _compare_ternary( + self, + pattern: j.Ternary, + target: j.Ternary, + cursor: 'Cursor' + ) -> bool: + """Compare two ternary (conditional) expressions.""" + if not self._compare(pattern.condition, target.condition, cursor): + return False + + if not self._compare(pattern.true_part, target.true_part, cursor): + return False + + return self._compare(pattern.false_part, target.false_part, cursor) + def _compare_parentheses( self, pattern: j.Parentheses, diff --git a/rewrite-python/rewrite/tests/python/template/test_comparator.py b/rewrite-python/rewrite/tests/python/template/test_comparator.py index e0c0d4d52e4..13fa2c3e4bd 100644 --- a/rewrite-python/rewrite/tests/python/template/test_comparator.py +++ b/rewrite-python/rewrite/tests/python/template/test_comparator.py @@ -590,6 +590,70 @@ def test_dict_element_count_mismatch(self): assert result is None +class TestTernaryMatching: + """Tests for ternary (conditional) expression comparison.""" + + def setup_method(self): + TemplateEngine.clear_cache() + + def teardown_method(self): + TemplateEngine.clear_cache() + + def test_placeholder_ternary_captures(self): + """{a} if {cond} else {b} should capture all three parts.""" + captures = {'a': capture('a'), 'cond': capture('cond'), 'b': capture('b')} + pattern_tree = TemplateEngine.get_template_tree("{a} if {cond} else {b}", captures) + target_tree = TemplateEngine.get_template_tree("x if flag else y", {}) + cursor = _make_cursor(target_tree) + + comparator = PatternMatchingComparator(captures) + result = comparator.match(pattern_tree, target_tree, cursor) + assert result is not None + assert 'a' in result + assert 'cond' in result + assert 'b' in result + + def test_concrete_ternary_match(self): + """x if True else y should match x if True else y.""" + pattern_tree = TemplateEngine.get_template_tree("x if True else y", {}) + target_tree = TemplateEngine.get_template_tree("x if True else y", {}) + cursor = _make_cursor(target_tree) + + comparator = PatternMatchingComparator({}) + result = comparator.match(pattern_tree, target_tree, cursor) + assert result is not None + + def test_ternary_condition_mismatch(self): + """x if True else y should not match x if False else y.""" + pattern_tree = TemplateEngine.get_template_tree("x if True else y", {}) + target_tree = TemplateEngine.get_template_tree("x if False else y", {}) + cursor = _make_cursor(target_tree) + + comparator = PatternMatchingComparator({}) + result = comparator.match(pattern_tree, target_tree, cursor) + assert result is None + + def test_ternary_true_part_mismatch(self): + """x if True else y should not match z if True else y.""" + pattern_tree = TemplateEngine.get_template_tree("x if True else y", {}) + target_tree = TemplateEngine.get_template_tree("z if True else y", {}) + cursor = _make_cursor(target_tree) + + comparator = PatternMatchingComparator({}) + result = comparator.match(pattern_tree, target_tree, cursor) + assert result is None + + def test_ternary_false_part_mismatch(self): + """x if True else y should not match x if True else z.""" + pattern_tree = TemplateEngine.get_template_tree("x if True else y", {}) + target_tree = TemplateEngine.get_template_tree("x if True else z", {}) + cursor = _make_cursor(target_tree) + + comparator = PatternMatchingComparator({}) + result = comparator.match(pattern_tree, target_tree, cursor) + assert result is None + + class TestDefaultFallthrough: """Tests for the default comparison behavior on unrecognized types."""