In [1]:
from libcst import parse_module, Module, Expr, Pass, Comment, CSTTransformer, Comparison, CSTNode, ComparisonTarget
from libcst import matchers
from typing import  Union, Tuple, Type, Dict 
import re
from libcst import Equal, GreaterThanEqual
from libcst import matchers as m


Docs: 

https://libcst.readthedocs.io/en/latest/metadata.html#position-metadata

https://libcst.readthedocs.io/en/latest/_modules/libcst/metadata/position_provider.html#PositionProvider

https://libcst.readthedocs.io/en/latest/_modules/libcst/metadata/position_provider.html#WhitespaceInclusivePositionProvidingCodegenState


https://libcst.readthedocs.io/en/latest/parser.html

In [2]:
# Use difflib to show what changed between the original and the transformed source text.
import difflib

def printdiff(original_node, updated_node):
    """Return a unified diff of two pieces of source text as one string.

    Despite the parameter names, both arguments are used as strings
    (``splitlines`` is called on them), e.g. ``module.code`` before and
    after a transformation.

    Args:
        original_node: Source text before the change.
        updated_node: Source text after the change.

    Returns:
        The unified diff joined into a single string; empty when the two
        inputs are identical.
    """
    # Line endings are kept so the diff lines can be joined back verbatim.
    before_lines = original_node.splitlines(keepends=True)
    after_lines = updated_node.splitlines(keepends=True)
    diff_lines = difflib.unified_diff(before_lines, after_lines)
    return "".join(diff_lines)

In [3]:
import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.metadata import BatchableMetadataProvider, PositionProvider

def gen_context_transfomer(op1, op2):
    """Build a transformer class that rewrites ``op1`` comparisons into ``op2``.

    The generated transformer records every node it rewrites in
    ``context.scratch`` and skips any ``ComparisonTarget`` that deep-equals
    an entry already recorded there, so that several generated transformers
    sharing one ``CodemodContext`` do not rewrite each other's output.

    Args:
        op1: libcst comparison-operator class to look for (e.g. ``cst.Equal``).
        op2: libcst comparison-operator class to substitute
            (e.g. ``cst.GreaterThan``). It is instantiated with no arguments.

    Returns:
        The ``ApplyTransformer`` class (not an instance). Instantiate it with
        a ``CodemodContext``.
    """
    class IsComparisonProvider(cst.BatchableMetadataProvider[Dict]):
        """Tags each ``ComparisonTarget`` with its current and desired operator class."""

        def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> None:
            original = type(node.operator)
            # Both entries are classes so the transformer can compare them
            # directly; the replacement operator is instantiated only when
            # a change is actually applied.
            target = op2 if original is op1 else original
            self.set_metadata(
                node,
                dict(
                    comparison=node,
                    original=original,
                    target=target,
                    author=self.__class__.__name__,
                ),
            )

    class ApplyTransformer(ContextAwareTransformer):
        """Applies the ``op1`` -> ``op2`` swap and logs each change in ``context.scratch``."""

        METADATA_DEPENDENCIES = (IsComparisonProvider, PositionProvider, )

        def leave_ComparisonTarget(
            self,
            original_node: cst.ComparisonTarget,
            updated_node: cst.ComparisonTarget,
        ) -> cst.ComparisonTarget:
            scratch = self.context.scratch
            # TODO this check should be done based on position
            # A single pass over scratch serves both the membership test and
            # the diagnostic print below.
            already_modified = [
                info for seen, info in scratch.items() if original_node.deep_equals(seen)
            ]
            if already_modified:
                print(
                    "already in scratch",
                    [(x["original_position"].start, x["original_position"].end) for x in already_modified],
                )
                return updated_node

            print("not in scratch")
            meta = self.get_metadata(IsComparisonProvider, original_node)
            meta_pos = self.get_metadata(PositionProvider, original_node)
            if meta['original'] is not meta['target']:
                # Build the replacement from updated_node, not the original,
                # so edits already applied to children are not thrown away.
                updated_node = updated_node.with_changes(operator=meta['target']())  # OP2

                # TODO here the new position of the node should be computed instead of using the original one
                scratch[updated_node] = {
                    "modified": True,
                    "author": self.__class__.__name__,
                    "original_position": meta_pos,
                }
                print("added to scratch", meta['original'], meta['target'], meta_pos)
            return updated_node

        def __repr__(self):
            # super().__repr__ is already bound; passing self again raised TypeError.
            return super().__repr__() + ':' + op1.__name__ + ':' + op2.__name__

    return ApplyTransformer

In [4]:

from libcst.codemod import CodemodContext, Codemod
from libcst.metadata import MetadataWrapper
from libcst import Equal, GreaterThanEqual, GreaterThan

class Bugger(Codemod):
    """Codemod that chains two operator swaps: ``op1 -> op2``, then ``op2 -> op3``.

    Both passes share one ``CodemodContext``, so the scratch entries written
    by the first pass are visible to the second.
    """

    def __init__(self, op1, op2, op3):
        """Store the three operator classes and build one transformer class per swap."""
        self.op1, self.op2, self.op3 = op1, op2, op3
        self.apply_transformer = gen_context_transfomer(op1, op2)
        self.apply_transformer2 = gen_context_transfomer(op2, op3)

        # The shared context is created here and handed to the parent initialiser.
        self.context = CodemodContext()
        super().__init__(self.context)

    def updated_scratch_position(self) -> None:
        """Placeholder for recomputing the positions stored in scratch.

        TODO: not implemented yet; the intended behaviour when nothing is
        being updated is still undecided.
        """

    def transform_module_impl(self, tree: Module) -> Module:
        """Run the first swap on ``tree``, then the second swap on the result."""
        first_pass = self.apply_transformer(self.context)
        tainted = first_pass.transform_module(tree)
        # Design notes for the position update that is still to be written:
        # - scratch holds the original positions (the values) and the changed
        #   nodes themselves (the keys); node identity is useless here.
        # - tainted.code_for_node(x) for each scratch key gives the length of
        #   the changed node.
        # - the earliest changed node starts where its original did, because
        #   nothing before it was modified.
        # - walking the nodes in position order and accumulating the length
        #   delta of each change yields the offset for the later ones.
        second_pass = self.apply_transformer2(self.context)
        tainted = second_pass.transform_module(tainted)
        return tainted


In [5]:
# Let's write a visitor that records node positions, from which the length of a single node can be derived; there may be other approaches.
    
class LengthVisitor(cst.CSTVisitor):
    """Collects the ``PositionProvider`` metadata of every node as it is left.

    Must be run through a ``MetadataWrapper`` so that position metadata is
    available to ``get_metadata``.
    """
    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self):
        # Let the libcst base classes initialise their own state first.
        super().__init__()
        # Positions in the order the nodes were left.
        self.positions = []

    def on_leave(self, node, updated_node=None):
        """Record the position of ``node``.

        A ``CSTVisitor`` is called with the original node only, so
        ``updated_node`` is optional; it is accepted but ignored so that a
        two-argument call still works.
        """
        self.positions.append(self.get_metadata(PositionProvider, node))
        # Keep the default leave_<NodeType> dispatch working for subclasses.
        super().on_leave(node)

    def get_positions(self):
        """Return the list of positions collected so far."""
        return self.positions


In [6]:
## Test
from libcst import (Equal, GreaterThanEqual, LessThan, GreaterThan, 
                    LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match)
import collections

# Maps an operator's source spelling to the libcst operator node class.
# Note: 'not', 'and' and 'or' are not comparison operators; they are listed
# for lookup only.
str2op = {
    '==': Equal,
    '>=': GreaterThanEqual,
    '>': GreaterThan,
    '<': LessThan,
    '<=': LessThanEqual,
    # '=<' is not valid Python; the misspelled key is kept only so that any
    # existing lookup using it still resolves.
    '=<': LessThanEqual,
    '!=': NotEqual,
    'not in': NotIn,
    'in': In,
    'is': Is,
    'is not': IsNot,
    'not': Not,
    'and': And,
    'or': Or,
    # 'match': Match,
}

# Chain two swaps: `==` -> `>` in the first pass, then `>` -> `>=` in the second.
bugger = Bugger(str2op['=='], str2op['>'],str2op['>='])

# Source to transform: a chained comparison that contains both `==` and `>`.
script = "x == 1 + 2 != 3 + 4 > 3"

# Parse the script into a CST
module = cst.parse_module(script)
# The `>` introduced by the first pass must not be rewritten by the second pass;
# only the `>` that was already present in the script should become `>=`.
tainted = bugger.transform_module(module)
print(tainted.code)
# Uncomment to inspect the record of changed nodes:
# print(bugger.context.scratch)


not in scratch
added to scratch <class 'libcst._nodes.op.Equal'> <class 'libcst._nodes.op.GreaterThan'> CodeRange(start=CodePosition(line=1, column=1), end=CodePosition(line=1, column=10))
not in scratch
not in scratch
already in scratch [(CodePosition(line=1, column=1), CodePosition(line=1, column=10))]
not in scratch
not in scratch
added to scratch <class 'libcst._nodes.op.GreaterThan'> <class 'libcst._nodes.op.GreaterThanEqual'> CodeRange(start=CodePosition(line=1, column=18), end=CodePosition(line=1, column=22))
x > 1 + 2 != 3 + 4 >= 3


In [7]:
# Sort the changed nodes by where they originally started in the source.
# The key is (line, column): sorting on column alone mis-orders nodes as soon
# as the script spans more than one line.
scratch = bugger.context.scratch
dict(
    sorted(
        scratch.items(),
        key=lambda item: (
            item[1]["original_position"].start.line,
            item[1]["original_position"].start.column,
        ),
    )
)

{ComparisonTarget(
     operator=GreaterThan(
         whitespace_before=SimpleWhitespace(
             value=' ',
         ),
         whitespace_after=SimpleWhitespace(
             value=' ',
         ),
     ),
     comparator=BinaryOperation(
         left=Integer(
             value='1',
             lpar=[],
             rpar=[],
         ),
         operator=Add(
             whitespace_before=SimpleWhitespace(
                 value=' ',
             ),
             whitespace_after=SimpleWhitespace(
                 value=' ',
             ),
         ),
         right=Integer(
             value='2',
             lpar=[],
             rpar=[],
         ),
         lpar=[],
         rpar=[],
     ),
 ): {'modified': True,
  'author': 'ApplyTransformer',
  'original_position': CodeRange(start=CodePosition(line=1, column=1), end=CodePosition(line=1, column=10))},
 ComparisonTarget(
     operator=GreaterThanEqual(
         whitespace_before=SimpleWhitespace(
             value='