First, we can create the IncorrectComparisonOperatorProvider class, which is a BaseMetadataProvider that will record the original operator and the replacement operator for a comparison original_node. This provider will be used by the transformer to store the original and replacement operators, so that it can later use that information to update the comparison operator.

Next, we can create the IncorrectComparisonOperatorTransformer class, which is a CSTTransformer that will replace the comparison operator of a comparison node with an incorrect operator. This transformer will use the IncorrectComparisonOperatorProvider to store the original and replacement operators, so that it can later use that information to update the comparison operator. It will also have a visit_Comparison method that will take a comparison node and iterate through its comparisons, check if the operator is a correct operator and if it is, it will use the provider to store the original and replacement operator, and then update the comparison operator with the incorrect operator in the leave_Comparison method.

Finally, we can create the ReverseIncorrectComparisonOperatorTransformer class, which is a CSTTransformer that will reverse the changes made by the IncorrectComparisonOperatorTransformer by replacing the comparison operator of a comparison node with the original operator. This transformer will use the IncorrectComparisonOperatorProvider to retrieve the original and replacement operators, so that it can later use that information to update the comparison operator. It will also have a visit_Comparison method that will take a comparison node and iterate through its comparisons, and then use the provider to retrieve the original and replacement operator, and then update the comparison operator with the original operator in the leave_Comparison method.

In [24]:
from libcst import parse_module, Module, Expr, Pass, Comment, CSTTransformer, Comparison, CSTNode, ComparisonTarget, MetadataWrapper, MetadataDependent, BaseMetadataProvider, BatchableMetadataProvider, BaseCompOp
from libcst import matchers
from typing import List, Dict, Union, Tuple, Type
from libcst import Equal, GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match
from libcst import matchers as m
from libcst.metadata import PositionProvider



import libcst as cst

In [2]:
class IsComparisonProvider(cst.BatchableMetadataProvider[Dict]):
    """
    Marks Name nodes found as Comparison Operators.
    """
    def __init__(self) -> None:
        super().__init__()

    def visit_Comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> None:
        # Mark the child Name node as a parameter
        self.set_metadata(node, dict(comparison=original_node.comparisons[0], 
                                    original=original_node.comparisons[0].operator.__class__))

                                
                                  
class ComparisonPrinter(cst.CSTVisitor):
    METADATA_DEPENDENCIES = (IsComparisonProvider, PositionProvider,)

    def visit_Comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> None:
        # Only print out names that are parameters
        if self.get_metadata(IsComparisonProvider, node):
            pos = self.get_metadata(PositionProvider, node).start
            print(f"{node} found at line {pos.line}, column {pos.column}")

# the next class should be a generic class that takes the operator to change as a parameter
class ComparisonMutator(cst.CSTTransformer):
    METADATA_DEPENDENCIES = (IsComparisonProvider, PositionProvider,)

    def __init__(self, operator: Type[BaseCompOp]) -> None:
        super().__init__()
        self.operator = operator

    def leave_Comparison(self, original_original_node: cst.Comparison, updated_node: cst.Comparison, updated_original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        # Only print out names that are parameters
        if self.get_metadata(IsComparisonProvider, original_node):
            pos = self.get_metadata(PositionProvider, original_node).start
            # print(f"{original_node} found at line {pos.line}, column {pos.column}")
            return updated_updated_node.with_changes(comparisons=[updated_original_node.comparisons[0].with_changes(operator=self.operator())])
        return updated_node        

In [4]:
code = """
if 1 <= 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""
comparators  = [Equal, GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match]
module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)
provider = IsComparisonProvider
data = wrapper.resolve(provider)
visitor = ComparisonPrinter()
wrapper.visit(visitor)
mutator = ComparisonMutator(Equal)
new_module = wrapper.visit(mutator)
print(new_module.code)


if 1 == 2:
    pass
elif 2 == 3:
    pass
else:
    pass



In [29]:
import libcst as cst
from libcst.matchers import visit, leave

class ComparisonMutatorDecorator(cst.matchers.MatcherDecoratableTransformer):
    def __init__(self, operator: Type[BaseCompOp]) -> None:
        super().__init__()
        self.operator = operator
        
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.Equal()]))
    def update_equal_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.GreaterThan()]))
    def update_greater_than_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.LessThan()]))
    def update_less_than_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.GreaterThanEqual()]))
    def update_greater_than_equal_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        print("found greater than equal")
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.LessThanEqual()])) 
    def update_less_than_equal_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.NotEqual()]))
    def update_not_equal_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.NotIn()]))
    def update_not_in_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.In()]))
    def update_in_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.Is()]))
    def update_is_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.IsNot()]))
    def update_is_not_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.Not()]))
    def update_not_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.And()]))
    def update_and_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.Or()]))
    def update_or_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        return updated_node.with_changes(comparisons=[original_node.comparisons[0].with_changes(operator=self.operator())])


In [34]:
code = """
if 1 >= 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""
comparators  = [Equal, GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match]
module = cst.parse_module(code)
mutator = ComparisonMutatorDecorator(comparators[0])
mutated_module = module.visit(mutator)
print(mutated_module.code)


if 1 >= 2:
    pass
elif 2 > 3:
    pass
else:
    pass



In [1]:
from libcst.matchers import  visit
import libcst as cst

class ComparisonPrinter(cst.matchers.MatcherDecoratableVisitor):
    @leave(cst.matchers.Comparison(comparisons=[cst.matchers.Equal()]))
    def print_equal_comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> None:
        pos = self.get_metadata(cst.metadata.PositionProvider, node).start
        print(f"Equal comparison found at line {pos.line}, column {pos.column}")


In [21]:
import libcst as cst

class MatcherPrinter(cst.matchers.MatcherDecoratableVisitor):
    def __init__(self, matcher: cst.matchers.BaseCompOp) -> None:
        self.matcher = matcher
        self.visit_functions = [self.handle_matching_node]
        
    def on_visit(self, node: cst.CSTNode) -> bool:
        if cst.matchers.Comparison(comparisons=[cst.matchers.Equal()]):
            return True
        return False
        
    def handle_matching_node(self, node: cst.CSTNode) -> None:
        pos = self.get_metadata(cst.metadata.PositionProvider, node).start
        print(f"A node matching {self.matcher} found at line {pos.line}, column {pos.column}")


In [22]:
code = """
if 1 == 2:
    pass
elif 3 == 4:
    pass
else:
    pass
"""

module = cst.parse_module(code)
visitor = MatcherPrinter(cst.matchers.Equal())
module.visit(visitor)

AttributeError: 'MatcherPrinter' object has no attribute '_matchers'

In [21]:
code = """
if 1 == 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""
module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)
provider = IsEqualProvider
data = wrapper.resolve(provider)
visitor = EqualVisitor()
wrapper.visit(visitor)
# visitor = ComparisonPrinter()
# wrapper.visit(visitor)
# mutator = ComparisonMutatorNotEqual()
# new_module = wrapper.visit(mutator)
# print(new_module.code)

Comparison(
    left=Integer(
        value='1',
        lpar=[],
        rpar=[],
    ),
    comparisons=[
        ComparisonTarget(
            operator=Equal(
                whitespace_before=SimpleWhitespace(
                    value=' ',
                ),
                whitespace_after=SimpleWhitespace(
                    value=' ',
                ),
            ),
            comparator=Integer(
                value='2',
                lpar=[],
                rpar=[],
            ),
        ),
    ],
    lpar=[],
    rpar=[],
)


TypeError: 'ABCMeta' object is not iterable

In [None]:
class IsComparisonTypeProvider(cst.BatchableMetadataProvider[Dict]):
    """
    Marks Name nodes found as a parameter to a function.
    """
    def __init__(self) -> None:
        super().__init__()
        #operator dict maps from strings to operators
        self.operator_dict = {'==': Equal, '!=': NotEqual, '>=': GreaterThanEqual, '<=': LessThanEqual, '>': GreaterThan, '<': LessThan, 'in': In, 'not in': NotIn, 'is': Is, 'is not': IsNot, 'and': And, 'or': Or, 'not': Not, 'match': Match}

    def set_operator(self, operator: str) -> None:
        self.operator = self.operator_dict[operator]    


    def visit_Comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> None:
        # Mark the child Name node as a parameter
        if original_node.comparisons[0].operator.__class__ == self.operator:
            self.set_metadata(node, dict(comparison=original_node.comparisons[0], 
                                        original=original_node.comparisons[0].operator.__class__))  

class ComparisonTypePrinter(cst.CSTVisitor):
    METADATA_DEPENDENCIES = (IsComparisonTypeProvider, PositionProvider,)

    def visit_Comparison(self, original_node: cst.Comparison, updated_node: cst.Comparison) -> None:
        # Only print out names that are parameters
        if self.get_metadata(IsComparisonTypeProvider, node):
            pos = self.get_metadata(PositionProvider, node).start
            print(f"{node} found at line {pos.line}, column {pos.column}")
                                          

In [None]:
# class IsParamProvider(cst.BatchableMetadataProvider[bool]):
#     """
#     Marks Name nodes found as a parameter to a function.
#     """
#     def __init__(self) -> None:
#         super().__init__()
#         self.is_param = False

#     def visit_Param(self, node: cst.Param) -> None:
#         # Mark the child Name node as a parameter
#         self.set_metadata(original_node.name, True)

#     def visit_Name(self, node: cst.Name) -> None:
#         # Mark all other Name nodes as not parameters
#         if not self.get_metadata(type(self), node, False):
#             self.set_metadata(node, False)
# class ParamPrinter(cst.CSTVisitor):
#     METADATA_DEPENDENCIES = (IsParamProvider, PositionProvider,)

#     def visit_Name(self, node: cst.Name) -> None:
#         # Only print out names that are parameters
#         if self.get_metadata(IsParamProvider, node):
#             pos = self.get_metadata(PositionProvider, node).start
#             print(f"{original_node.value} found at line {pos.line}, column {pos.column}")

In [None]:
# module = cst.parse_module("x")
# wrapper = cst.MetadataWrapper(module)

# isparam = wrapper.resolve(IsParamProvider)
# x_name_node = wrapper.module.body[0].body[0].value

# print(isparam[x_name_node])  # should print False
# module = cst.parse_module("def foo(x):\n    y = 1\n    return x + y")
# wrapper = cst.MetadataWrapper(module)
# result = wrapper.visit(ParamPrinter())  # NB: wrapper.visit not module.visit

In [None]:
#comparison transformer should be able to replace the comparisons in the code with the new ones
class ComparisonTransformer(cst.CSTTransformer):
    METADATA_DEPENDENCIES = (IsComparisonProvider, PositionProvider,)

    def leave_Comparison(self, original_original_node: cst.Comparison, updated_node: cst.Comparison, updated_original_node: cst.Comparison, updated_node: cst.Comparison) -> cst.Comparison:
        # Only print out names that are parameters
        if self.get_metadata(IsComparisonProvider, original_node):
            pos = self.get_metadata(PositionProvider, original_node).start
            print(f"{original_node} found at line {pos.line}, column {pos.column}")
            return updated_updated_node.with_changes(comparators=[cst.ComparisonTarget(cst.Name("z"))])
        return updated_node

In [None]:
code = """
if 1 <= 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""
module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)

# Mark all Comparison nodes as True
wrapper.visit(ComparisonPrinter())

Module(
    body=[
        If(
            test=Comparison(
                left=Integer(
                    value='1',
                    lpar=[],
                    rpar=[],
                ),
                comparisons=[
                    ComparisonTarget(
                        operator=LessThanEqual(
                            whitespace_before=SimpleWhitespace(
                                value=' ',
                            ),
                            whitespace_after=SimpleWhitespace(
                                value=' ',
                            ),
                        ),
                        comparator=Integer(
                            value='2',
                            lpar=[],
                            rpar=[],
                        ),
                    ),
                ],
                lpar=[],
                rpar=[],
            ),
            body=IndentedBlock(
                body=[
                    SimpleStatementLin

In [None]:

class IncorrectComparisonOperatorTransformer(CSTTransformer):
    def __init__(self, correct_operators, incorrect_operators):
        self.correct_operators = correct_operators
        self.incorrect_operators = incorrect_operators
        self.metadata = {}
        self.visit_functions = [self.visit_Comparison]


    def visit_Comparison(self, node: Comparison):
        for comparison in original_node.comparisons:
            for correct, incorrect in zip(self.correct_operators, self.incorrect_operators):
                if isinstance(comparison.operator, correct):
                    self.metadata[comparison] = (correct, incorrect)

    def leave_Comparison(self, original_node: Comparison, updated_node: Comparison) -> Comparison:
        for comparison, (correct, incorrect) in self.metadata.items():
            updated_comparison = comparison.with_changes(operator=incorrect())
            updated_node = original_updated_node.with_changes(comparisons=[updated_comparison])
        return updated_node

In [None]:
code = """
if 1 <= 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""
module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)

# Mark all Comparison nodes as True
wrapper.resolve(IsComparisonProvider)

# Replace incorrect operators with correct ones
correct_operators = [cst.LessThanEqual, cst.GreaterThan]
incorrect_operators = [cst.LessThan, cst.GreaterThanEqual]
transformer = IncorrectComparisonOperatorTransformer(correct_operators, incorrect_operators)
wrapper.visit(transformer)

# Get the updated code
updated_code = wrapper.module.code
print(updated_code)



if 1 <= 2:
    pass
elif 2 > 3:
    pass
else:
    pass



In [None]:
code = """
if 1 <= 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""

module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)

# Mark all Comparison nodes as True
iscomparison = wrapper.resolve(IsComparisonProvider)

# Replace incorrect operators with correct ones
correct_operators = [cst.LessThanEqual, cst.GreaterThan]
incorrect_operators = [cst.LessThan, cst.GreaterThanEqual]
transformer = IncorrectComparisonOperatorTransformer(correct_operators, incorrect_operators)
wrapper.visit(transformer, filter_fn=lambda n: iscomparison.get(n))

# Get the updated code
updated_code = wrapper.module.code
print(updated_code)


TypeError: MetadataWrapper.visit() got an unexpected keyword argument 'filter_fn'

In [None]:
code = """
if 1 <= 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""
module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)
iscomparison = wrapper.resolve(IsComparisonProvider)
for key, value in iscomparison.items():
    print(key, value)


Comparison(
    left=Integer(
        value='1',
        lpar=[],
        rpar=[],
    ),
    comparisons=[
        ComparisonTarget(
            operator=LessThanEqual(
                whitespace_before=SimpleWhitespace(
                    value=' ',
                ),
                whitespace_after=SimpleWhitespace(
                    value=' ',
                ),
            ),
            comparator=Integer(
                value='2',
                lpar=[],
                rpar=[],
            ),
        ),
    ],
    lpar=[],
    rpar=[],
) True
Comparison(
    left=Integer(
        value='2',
        lpar=[],
        rpar=[],
    ),
    comparisons=[
        ComparisonTarget(
            operator=GreaterThan(
                whitespace_before=SimpleWhitespace(
                    value=' ',
                ),
                whitespace_after=SimpleWhitespace(
                    value=' ',
                ),
            ),
            comparator=Integer(
                val