In [1]:
from libcst import parse_module, Module, Expr, Pass, Comment, CSTTransformer, Comparison, CSTNode, ComparisonTarget
from libcst import matchers
from typing import List, Dict, Union, Tuple, Type
import re
from libcst import Equal, GreaterThanEqual
from libcst import matchers as m

## REFERENCES

- [Nodes](https://libcst.readthedocs.io/en/latest/nodes.html#libcst-nodes)

In [4]:
# Use difflib to show the changes to verify type annotations were added as expected.
import difflib

def printdiff(vanilla, tainted):
    return (
        "".join(
            difflib.unified_diff(vanilla.splitlines(1), tainted.splitlines(1))
        )
    )

In [210]:
import libcst as cst


def gen_transfomers(op1, op2):
    class IsComparisonProvider(cst.BatchableMetadataProvider[Dict]):
    

        def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> None:
            # Mark the node as an equal comparison node
            target = op2() if node.operator.__class__ == op1 else node.operator.__class__
            self.set_metadata(node, dict(comparison=node, 
                                    original=node.operator.__class__,target=target))
                         
    
    class ApplyTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (IsComparisonProvider, )
    
        def leave_ComparisonTarget(self, original_node:cst.ComparisonTarget, updated_node: cst.ComparisonTarget) -> None:
            meta = self.get_metadata(IsComparisonProvider, original_node)
            if meta['original'] != meta['target']:
                updated_node = meta['comparison'].with_changes(operator=meta['target']) # OP2
            return updated_node
            
        def __repr__(self):
            return super().__repr__(self) + ':' + op1.__name__ +':' + op2.__name__
    
    class ReverseTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (IsComparisonProvider, )

        def leave_ComparisonTarget(self, original_node:cst.ComparisonTarget, updated_node: cst.ComparisonTarget) -> None:
            meta = self.get_metadata(IsComparisonProvider, original_node)
            if meta:
                updated_node = meta['comparison'].with_changes(operator=meta['original']())
            return updated_node
        
        def __repr__(self):
            return super().__repr__(self) + ':' + op1.__name__ +':' + op2.__name__

    return ApplyTransformer, ReverseTransformer    

In [212]:
## Test
from libcst import (Equal, GreaterThanEqual, LessThan, GreaterThan, 
                    LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match)
import collections

str2op = dict([
    ('==', Equal),
    ('>=', GreaterThanEqual),
    ('>', GreaterThan),
    ('<', LessThan),
    ('=<', LessThanEqual),
    ('!=', NotEqual),
    ('not in', NotIn),
    ('in', In),
    ('is', Is),
    ('is not', IsNot),
    ('not', Not),
    ('and', And),
    ('or', Or),
    ('or', Or),
    # ('match', Match),
])

op2str = {v:k for k,v in str2op.items()}

# Get the script as a string
script = "x == 1 + 2 != 3 + 4 == 3"

# Parse the script into a CST
module = cst.parse_module(script)

# Use the gen_transfomers function to generate the ApplyTransformer and ReverseTransformer classes
ApplyTransformer, ReverseTransformer = gen_transfomers(str2op['=='], str2op['>'])

wrapper = cst.MetadataWrapper(module)
# Apply the ApplyTransformer to the CST
tainted = wrapper.visit(ApplyTransformer())
reverted = wrapper.visit(ReverseTransformer())

print(tainted.code)
print(reverted.code)

x > 1 + 2 != 3 + 4 > 3
x == 1 + 2 != 3 + 4 == 3


In [71]:
wrapper._metadata

{__main__.gen_transfomers.<locals>.IsComparisonProvider: mappingproxy({ComparisonTarget(
                   operator=Equal(
                       whitespace_before=SimpleWhitespace(
                           value=' ',
                       ),
                       whitespace_after=SimpleWhitespace(
                           value=' ',
                       ),
                   ),
                   comparator=BinaryOperation(
                       left=Integer(
                           value='1',
                           lpar=[],
                           rpar=[],
                       ),
                       operator=Add(
                           whitespace_before=SimpleWhitespace(
                               value=' ',
                           ),
                           whitespace_after=SimpleWhitespace(
                               value=' ',
                           ),
                       ),
                       right=Integer(
                   

In [63]:

class NodeCollector(cst.CSTVisitor):
    def __init__(self):
        self.nodes = set()
    def on_leave(self, original_node: cst.CSTNode) -> None:
        self.nodes.add(original_node)
        
    def get_nodes(self):
        return list(self.nodes)
collector = NodeCollector()
module.visit(collector)
nodes_module = collector.get_nodes()
collector_tainted = NodeCollector()
tainted.visit(collector_tainted)
nodes_tainted = collector_tainted.get_nodes()

In [60]:
print(nodes_tainted[3])
print(nodes_module[0])
print(nodes_tainted[0].deep_equals(nodes_module[0]))
node_counters = []
for node in nodes_tainted:
    node_counter=0
    for node2 in nodes_module:
        if node.deep_equals(node2):
            # print("equal")
            # print(node, node2)
            node_counter+=1
    node_counters.append(node_counter)
print(node_counters)    

SimpleWhitespace(
    value=' ',
)
Name(
    value='x',
    lpar=[],
    rpar=[],
)
True
[1, 1, 1, 6, 1, 1, 6, 6, 6, 0, 0, 1, 6, 0, 0, 0, 0, 1, 0, 0, 1, 6, 1]


In [62]:
for i,node in enumerate(nodes_tainted):
    if not node.deep_equals(nodes_module[i]):
        print("not equal")
        print(node, nodes_module[i])

not equal
Add(
    whitespace_before=SimpleWhitespace(
        value=' ',
    ),
    whitespace_after=SimpleWhitespace(
        value=' ',
    ),
) SimpleWhitespace(
    value=' ',
)
not equal
Integer(
    value='3',
    lpar=[],
    rpar=[],
) Comparison(
    left=Name(
        value='x',
        lpar=[],
        rpar=[],
    ),
    comparisons=[
        ComparisonTarget(
            operator=Equal(
                whitespace_before=SimpleWhitespace(
                    value=' ',
                ),
                whitespace_after=SimpleWhitespace(
                    value=' ',
                ),
            ),
            comparator=BinaryOperation(
                left=Integer(
                    value='1',
                    lpar=[],
                    rpar=[],
                ),
                operator=Add(
                    whitespace_before=SimpleWhitespace(
                        value=' ',
                    ),
                    whitespace_after=SimpleWhitespace(
 

In [167]:
def composite_factory(mutators: List[Transform]):
    apply_transforms = [m[0] for m in mutators]
    reverse_transforms = [m[1] for m in mutators]

    METADATA_PROVIDERS = list(itertools.chain(*itertools.chain(t.METADATA_DEPENDENCIES for t in apply_transforms)))
    print(METADATA_PROVIDERS)
    
    class CompositeTaintOperatorProvider(cst.BatchableMetadataProvider[Dict]):
    
        def visit_Comparison(self, node: cst.Param) -> None:
            for p in METADATA_PROVIDERS:
                if not p.MATCHER(node): continue

                meta = self.get_metadata(node)
                if meta is None:
                    p.visit_Comparison(node) # apply the matcher
                    # Mark the node as an equal comparison node
                    self.set_metadata(node, {'tainted': True})
                elif meta.get('tainted') is None:
                    p.visit_Comparison(node)
                else:
                    pass # tainted already
                
    class ApplyCompositeTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (CompositeTaintOperatorProvider,)

        def leave_Comparison(self, vanilla:cst.Comparison, tainted: cst.Comparison) -> None:
            for apply in apply_transforms:
                vanilla = tainted = apply().leave_Comparison(vanilla, tainted) 
                
            return tainted

    class ReverseCompositeTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (CompositeTaintOperatorProvider,)

        def leave_Comparison(self, vanilla:cst.Comparison, tainted: cst.Comparison) -> None:
            for reverse in reverse_transforms:
                vanilla = tainted = reverse().leave_Comparison(vanilla, tainted) 
                
            return tainted

    return ApplyCompositeTransformer, ReverseCompositeTransformer

In [168]:
from libcst import (Equal, GreaterThanEqual, LessThan, GreaterThan, 
                    LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match)
import collections

str2op = dict([
    ('==', Equal),
    ('>=', GreaterThanEqual),
    ('>', GreaterThan),
    ('<', LessThan),
    ('=<', LessThanEqual),
    ('!=', NotEqual),
    ('not in', NotIn),
    ('in', In),
    ('is', Is),
    ('is not', IsNot),
    ('not', Not),
    ('and', And),
    ('or', Or),
    ('or', Or),
    # ('match', Match),
])

op2str = {v:k for k,v in str2op.items()}

Transform = collections.namedtuple('Transform', ['apply', 'reverse'])
transforms = collections.defaultdict(dict)

for op1 in [ Equal , GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or,]:
    for op2 in [ Equal, GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or,]:
        if op1 is op2: continue
        apply, reverse = gen_transfomers(op1, op2)
        transforms[op2str[op1]][op2str[op2]] = Transform(apply, reverse)


In [169]:
len(transforms)

13

In [170]:
import random
random.choice(transforms)

{}

In [163]:
sequence_of_transforms_to_apply = (
    "x >= z and x == y" ,
    [transforms['==']['>='],
     transforms['and']['or']],
    "x >= z or x >= y" ,
)

# pre_wrapper = cst.MetadataWrapper(module)
# change_equal_to_op1, _ = gen_transfomers(Equal, op1)
# new_code = pre_wrapper.visit(change_equal_to_op1)
# print(new_code.code)
# print(op1.__name__, '--->', op2.__name__)
# print('  :', code, )
# print('->:',bugged.code, )
# vanilla = wrapper.visit(reverse())
# print('<-:', vanilla.code)
# print( '---')

code, TLIST, result =  sequence_of_transforms_to_apply

module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)
apply, reverse = composite_factory(TLIST)
bugged = wrapper.visit(apply())
apply

TypeError: descriptor '__repr__' of 'object' object needs an argument

In [99]:
# printdiff(cst.parse_expression("1 == 2").code, cst.parse_express('1 >= 2').code)