In [89]:
!pip install libcst



In [90]:
from libcst import parse_module, Module, Expr, Pass, Comment, CSTTransformer, Comparison, CSTNode, ComparisonTarget
from libcst import matchers
from typing import List, Dict, Union, Tuple, Type
import re
from libcst import Equal, GreaterThanEqual
from libcst import matchers as m

## REFERENCES

- [Nodes](https://libcst.readthedocs.io/en/latest/nodes.html#libcst-nodes)

In [91]:
# Use difflib to show the changes to verify type annotations were added as expected.
import difflib

def printdiff(vanilla, tainted):
    return (
        "".join(
            difflib.unified_diff(vanilla.splitlines(1), tainted.splitlines(1))
        )
    )

In [165]:
import libcst as cst


def gen_transfomers(op1, op2):
    class IsOperatorProvider(cst.BatchableMetadataProvider[Dict]):
    
        MATCHER = m.call_if_inside(m.BinaryOperation(operator=[getattr(m, op1.__name__)()]))

        def visit_Comparison(self, node: cst.Param) -> None:
            if not self.MATCHER.match(node): return
            # Mark the node as an equal comparison node
            self.set_metadata(node, dict(comparison=node.comparisons[0], 
                                    original=node.comparisons[0].operator.__class__))
    
    class ApplyTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (IsOperatorProvider, )
    
        def leave_Comparison(self, vanilla:cst.Name, tainted: cst.Name) -> None:
            meta = self.get_metadata(IsOperatorProvider, vanilla)
            if meta:
                updated_comparison = meta['comparison'].with_changes(operator=op2()) # OP2
                tainted = vanilla.with_changes(comparisons=[updated_comparison])
            return tainted

        def __repr__(self):
            return super().__repr__(self) + ':' + op1.__name__ +':' + op2.__name__
    
    class ReverseTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (IsOperatorProvider, )

        def leave_Comparison(self, tainted:cst.Name, vanilla: cst.Name) -> None:
            meta = self.get_metadata(IsOperatorProvider, tainted)
            if meta:
                updated_comparison = meta['comparison'].with_changes(operator=meta['original']())
                vanilla = tainted.with_changes(comparisons=[updated_comparison])
            return vanilla
        
        def __repr__(self):
            return super().__repr__(self) + ':' + op1.__name__ +':' + op2.__name__

    return ApplyTransformer, ReverseTransformer    

In [166]:
import itertools
itertools.chain([1,2,3], [4,5,6])

<itertools.chain at 0x7fbf38d033d0>

In [167]:
def composite_factory(mutators: List[Transform]):
    apply_transforms = [m[0] for m in mutators]
    reverse_transforms = [m[1] for m in mutators]

    METADATA_PROVIDERS = list(itertools.chain(*itertools.chain(t.METADATA_DEPENDENCIES for t in apply_transforms)))
    print(METADATA_PROVIDERS)
    
    class CompositeTaintOperatorProvider(cst.BatchableMetadataProvider[Dict]):
    
        def visit_Comparison(self, node: cst.Param) -> None:
            for p in METADATA_PROVIDERS:
                if not p.MATCHER(node): continue

                meta = self.get_metadata(node)
                if meta is None:
                    p.visit_Comparison(node) # apply the matcher
                    # Mark the node as an equal comparison node
                    self.set_metadata(node, {'tainted': True})
                elif meta.get('tainted') is None:
                    p.visit_Comparison(node)
                else:
                    pass # tainted already
                
    class ApplyCompositeTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (CompositeTaintOperatorProvider,)

        def leave_Comparison(self, vanilla:cst.Name, tainted: cst.Name) -> None:
            for apply in apply_transforms:
                vanilla = tainted = apply().leave_Comparison(vanilla, tainted) 
                
            return tainted

    class ReverseCompositeTransformer(cst.CSTTransformer):
        METADATA_DEPENDENCIES = (CompositeTaintOperatorProvider,)

        def leave_Comparison(self, vanilla:cst.Name, tainted: cst.Name) -> None:
            for reverse in reverse_transforms:
                vanilla = tainted = reverse().leave_Comparison(vanilla, tainted) 
                
            return tainted

    return ApplyCompositeTransformer, ReverseCompositeTransformer

In [168]:
from libcst import (Equal, GreaterThanEqual, LessThan, GreaterThan, 
                    LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match)
import collections

str2op = dict([
    ('==', Equal),
    ('>=', GreaterThanEqual),
    ('>', GreaterThan),
    ('<', LessThan),
    ('=<', LessThanEqual),
    ('!=', NotEqual),
    ('not in', NotIn),
    ('in', In),
    ('is', Is),
    ('is not', IsNot),
    ('not', Not),
    ('and', And),
    ('or', Or),
    ('or', Or),
    # ('match', Match),
])

op2str = {v:k for k,v in str2op.items()}

Transform = collections.namedtuple('Transform', ['apply', 'reverse'])
transforms = collections.defaultdict(dict)

for op1 in [ Equal , GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or,]:
    for op2 in [ Equal, GreaterThanEqual, LessThan, GreaterThan, LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or,]:
        if op1 is op2: continue
        apply, reverse = gen_transfomers(op1, op2)
        transforms[op2str[op1]][op2str[op2]] = Transform(apply, reverse)


In [169]:
len(transforms)

13

In [170]:
import random
random.choice(transforms)

{}

In [163]:
sequence_of_transforms_to_apply = (
    "x >= z and x == y" ,
    [transforms['==']['>='],
     transforms['and']['or']],
    "x >= z or x >= y" ,
)

# pre_wrapper = cst.MetadataWrapper(module)
# change_equal_to_op1, _ = gen_transfomers(Equal, op1)
# new_code = pre_wrapper.visit(change_equal_to_op1)
# print(new_code.code)
# print(op1.__name__, '--->', op2.__name__)
# print('  :', code, )
# print('->:',bugged.code, )
# vanilla = wrapper.visit(reverse())
# print('<-:', vanilla.code)
# print( '---')

code, TLIST, result =  sequence_of_transforms_to_apply

module = cst.parse_module(code)
wrapper = cst.MetadataWrapper(module)
apply, reverse = composite_factory(TLIST)
bugged = wrapper.visit(apply())
apply

TypeError: descriptor '__repr__' of 'object' object needs an argument

In [99]:
# printdiff(cst.parse_expression("1 == 2").code, cst.parse_express('1 >= 2').code)