First, we create the IncorrectComparisonOperatorProvider class, a BaseMetadataProvider that records, for each comparison node, the original operator and the operator that replaced it. The transformer uses this provider to store both operators, so that the comparison operator can later be updated from that record.

Next, we create the IncorrectComparisonOperatorTransformer class, a CSTTransformer that replaces the operator of a comparison node with an incorrect one. It uses the IncorrectComparisonOperatorProvider to store the original and replacement operators so that this information is available when the operator is updated. Its visit_Comparison method iterates over the node's comparisons and, for each operator that is still the correct one, records the original and replacement operators through the provider. Its leave_Comparison method then swaps in the incorrect operator.

Finally, we create the ReverseIncorrectComparisonOperatorTransformer class, a CSTTransformer that undoes the changes made by IncorrectComparisonOperatorTransformer by restoring each comparison node's original operator. It uses the IncorrectComparisonOperatorProvider to retrieve the original and replacement operators. Its visit_Comparison method iterates over the node's comparisons and looks up the recorded operators through the provider. Its leave_Comparison method then puts the original operator back.

In [13]:
from libcst import parse_module, Module, Expr, Pass, Comment, CSTTransformer, Comparison, CSTNode, ComparisonTarget, MetadataWrapper, MetadataDependent, BaseMetadataProvider, BatchableMetadataProvider, BaseCompOp
from libcst import matchers
from typing import List, Dict, Union, Tuple, Type
from libcst import Equal, GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match
from libcst import matchers as m





In [None]:
import libcst as cst


class IsParamProvider(cst.BatchableMetadataProvider[bool]):
    """
    Flag each ``Name`` node with whether it names a function parameter.

    A ``Param`` marks its own ``name`` child ``True``. Every other ``Name``
    that is visited is marked ``False``, unless it already carries ``True``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Not read anywhere in this class; kept so instances expose the
        # same attributes as before.
        self.is_param = False

    def visit_Param(self, node: cst.Param) -> None:
        # The parameter's identifier is the Name stored on the Param node.
        param_name = node.name
        self.set_metadata(param_name, True)

    def visit_Name(self, node: cst.Name) -> None:
        # Keep a True flag set by visit_Param. This presumes the Param parent
        # is visited before this Name child (confirm against the libcst
        # traversal docs). Any unflagged Name is recorded as False.
        if self.get_metadata(type(self), node, False):
            return
        self.set_metadata(node, False)

In [2]:
module = cst.parse_module("x")
wrapper = cst.MetadataWrapper(module)

isparam = wrapper.resolve(IsParamProvider)

# "x" parses to one statement line holding one expression statement; that
# expression's value is the Name node for `x`.
# NOTE: the lookup goes through wrapper.module rather than `module`. libcst's
# MetadataWrapper copies its input by default (confirm in the libcst docs),
# so the resolved metadata is keyed on the wrapper's nodes.
first_statement = wrapper.module.body[0]
x_name_node = first_statement.body[0].value

# `x` is a bare expression, not a function parameter.
print(isparam[x_name_node])  # should print False

False


In [25]:
class IsComparisonProvider(cst.BatchableMetadataProvider[bool]):
    """
    Mark every ``Comparison`` node (e.g. ``1 <= 2``) with ``True``.

    Only ``Comparison`` nodes receive metadata. Any other node is absent from
    the resolved mapping, so look nodes up with a default, e.g.
    ``resolved.get(node, False)``.
    """

    def visit_Comparison(self, node: cst.Comparison) -> None:
        # Record the Comparison node itself (not one of its children).
        self.set_metadata(node, True)



In [27]:
# Sample source with two comparisons: `1 <= 2` in the `if` and `2 > 3` in the `elif`.
code = """
if 1 <= 2:
    pass
elif 2 > 3:
    pass
else:
    pass
"""
module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)
# Mapping from each Comparison node marked by IsComparisonProvider to True;
# that provider sets metadata only in visit_Comparison, so no other node
# types appear here.
iscomparison = wrapper.resolve(IsComparisonProvider)
# Print each marked node's repr followed by its flag.
for key, value in iscomparison.items():
    print(key, value)


Comparison(
    left=Integer(
        value='1',
        lpar=[],
        rpar=[],
    ),
    comparisons=[
        ComparisonTarget(
            operator=LessThanEqual(
                whitespace_before=SimpleWhitespace(
                    value=' ',
                ),
                whitespace_after=SimpleWhitespace(
                    value=' ',
                ),
            ),
            comparator=Integer(
                value='2',
                lpar=[],
                rpar=[],
            ),
        ),
    ],
    lpar=[],
    rpar=[],
) True
Comparison(
    left=Integer(
        value='2',
        lpar=[],
        rpar=[],
    ),
    comparisons=[
        ComparisonTarget(
            operator=GreaterThan(
                whitespace_before=SimpleWhitespace(
                    value=' ',
                ),
                whitespace_after=SimpleWhitespace(
                    value=' ',
                ),
            ),
            comparator=Integer(
                val