In [26]:
from libcst import parse_module, Module, Expr, Pass, Comment, CSTTransformer, Comparison, CSTNode, ComparisonTarget
from libcst import matchers
from typing import  Union, Tuple, Type, Dict 
import re
from libcst import Equal, GreaterThanEqual
from libcst import matchers as m


Reference documentation for the libcst position metadata and parser APIs used below:

https://libcst.readthedocs.io/en/latest/metadata.html#position-metadata

https://libcst.readthedocs.io/en/latest/_modules/libcst/metadata/position_provider.html#PositionProvider

https://libcst.readthedocs.io/en/latest/_modules/libcst/metadata/position_provider.html#WhitespaceInclusivePositionProvidingCodegenState


https://libcst.readthedocs.io/en/latest/parser.html

In [27]:

from libcst.codemod import CodemodContext, Codemod
from libcst.metadata import MetadataWrapper
from libcst import Equal, GreaterThanEqual, GreaterThan
from typing import List
from copy import deepcopy
import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer, ContextAwareVisitor
from libcst.metadata import BatchableMetadataProvider, PositionProvider, CodePosition, CodeRange
import libcst.matchers as m
import uuid

class PositionContextUpdater(ContextAwareTransformer):
    """Keep the positions recorded in ``context.scratch`` in sync with the current tree.

    Each mutating transformer records its edits in ``context.scratch``, keyed by the
    start ``CodePosition`` of the edited node, with the node's ``CodeRange`` under
    ``"original_position"`` and the node it produced under ``"updated_node"``.
    An edit that changes a node's length moves every node after it on the same line,
    so this pass walks the freshly mutated tree and, for every node matching a recorded
    ``"updated_node"``, shifts the entries that follow it on that line.
    The tree itself is returned unchanged.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)

    def on_visit(self, node: "CSTNode") -> bool:
        # Always descend: any node in the tree may be a recorded edit.
        return True

    def update_positions(self, meta_pos: CodeRange) -> None:
        """Record ``meta_pos`` as the current range of the edit starting at ``meta_pos.start``.

        The change in the edited node's end column (new end minus recorded end) is
        applied to every entry that starts later on the same line, and those entries
        are re-keyed by their new start position. Only column shifts are handled;
        edits that change the number of lines are not propagated.

        NOTE(review): the delta is measured on the node's *end* column but applied to
        entries on its *start* line; for nodes spanning several lines this may not be
        the intended correction — confirm.

        Args:
            meta_pos: range of the edited node in the current tree.

        Raises:
            KeyError: if no scratch entry starts at ``meta_pos.start``.
        """
        scratch = self.context.scratch
        key = next(
            (k for k, entry in scratch.items() if entry["original_position"].start == meta_pos.start),
            None,
        )
        if key is None:
            raise KeyError(f"no scratch entry starts at {meta_pos.start}")

        delta = meta_pos.end.column - scratch[key]["original_position"].end.column
        scratch[key]["original_position"] = meta_pos

        # Rebuild the mapping in one pass rather than popping/re-inserting keys while
        # iterating: with a positive delta, moving an entry onto a key that has not
        # been processed yet would silently overwrite that entry.
        rebuilt = {}
        for entry_key, entry in scratch.items():
            entry_range = entry["original_position"]
            if entry_range.start.line == meta_pos.start.line and entry_key.column > key.column:
                entry_range = CodeRange(
                    start=CodePosition(line=entry_range.start.line, column=entry_range.start.column + delta),
                    end=CodePosition(line=entry_range.end.line, column=entry_range.end.column + delta),
                )
                entry["original_position"] = entry_range
                entry_key = entry_range.start
            rebuilt[entry_key] = entry
        # Mutate the shared dict in place: other objects hold a reference to it.
        scratch.clear()
        scratch.update(rebuilt)

    def on_leave(self, node: "CSTNode", updated_node: "CSTNode") -> "CSTNode":
        """Refresh recorded positions when ``node`` is the result of a recorded edit."""
        meta_pos = self.get_metadata(PositionProvider, node)
        entry = self.context.scratch.get(meta_pos.start, None)
        # Several nested nodes can share a start position, so also require the node
        # to equal the one the transformer produced.
        if entry and node.deep_equals(entry["updated_node"]):
            self.update_positions(meta_pos)
        return updated_node

    def get_positions(self):
        """Return the currently recorded range of every edit, keyed by its start position."""
        return {key: entry["original_position"] for key, entry in self.context.scratch.items()}

        
class Bugger(Codemod):
    """Apply a sequence of bug-injecting transformers, and revert them on request.

    All transformers share one ``CodemodContext``. Through ``context.scratch`` each
    one sees which nodes have already been modified. After every transformer, a
    ``PositionContextUpdater`` pass re-aligns the recorded positions with the new tree.
    """

    def __init__(self, transformers: List[Type[ContextAwareTransformer]]) -> None:
        """
        Args:
            transformers: transformer *classes*; each is instantiated with the shared
                context and must provide ``mutate(tree, reverse)``.
        """
        self.context = CodemodContext()
        Codemod.__init__(self, self.context)
        self.transformers = [transformer(self.context) for transformer in transformers]
        self.position_updater = PositionContextUpdater(self.context)
        # context.scratch maps the start CodePosition of each modified node to a
        # record of the edit (position, original/updated node, author id).
        self.debug = False
        # Module after each transformer step of the most recent debug run.
        self.debug_steps = []

    def apply(self, tree: Module, debug: bool = False) -> Module:
        """Run every transformer over ``tree`` and return the resulting module.

        Args:
            tree: module to transform.
            debug: when True, each transformer is run with ``reverse=True``, so it
                reverts the edits it authored instead of making new ones. The module
                after each step is stored in ``debug_steps``.
        """
        self.debug = debug
        if debug:
            # Start a fresh step log so steps from earlier debug runs are not mixed in.
            self.debug_steps = []
        return self.transform_module(tree)

    def transform_module_impl(self, tree: Module) -> Module:
        tainted = tree
        for transformer in self.transformers:
            # self.debug doubles as the transformers' ``reverse`` flag.
            tainted = transformer.mutate(tainted, self.debug)
            tainted = self.position_updater.transform_module(tainted)
            if self.debug:
                self.debug_steps.append(tainted)
        return tainted



# ComparisonTransformer

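`gen_context_transfomer(op1, op2)` builds a transformer that takes two comparison-operator classes and replaces every occurrence of the first with the second. Comparisons already recorded as modified in the shared context scratch (e.g. by an earlier transformer) are left untouched. Running it with `reverse=True` restores only the operators this transformer changed.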

In [28]:
def gen_context_transfomer(op1, op2):
    """Build a transformer class that swaps comparison operator ``op1`` for ``op2``.

    Args:
        op1: libcst comparison-operator class to look for (e.g. ``Equal``).
        op2: libcst comparison-operator class to substitute (e.g. ``GreaterThan``);
            it is instantiated with its default arguments.

    Returns:
        A ``ContextAwareTransformer`` subclass. Instances record every swap in
        ``context.scratch``. ``mutate(tree, reverse=True)`` reverts the swaps made by
        that same instance.
    """
    class ComparisonTransformer(ContextAwareTransformer):
        METADATA_DEPENDENCIES = (PositionProvider,)

        def __init__(self, context: CodemodContext):
            self.op1 = op1
            self.op2 = op2
            super().__init__(context)
            # Unique author tag so reverse mode only undoes this instance's edits.
            self.id = f"{self.__class__.__name__}-{uuid.uuid4().hex[:4]}"
            self.reverse = False

        def transform_module_impl(self, tree: cst.Module) -> cst.Module:
            return tree.visit(self)

        def mutate(self, tree: cst.Module, reverse: bool = False) -> cst.Module:
            """Inject the swap (``reverse=False``) or revert it (``reverse=True``)."""
            self.reverse = reverse
            return self.transform_module(tree)

        def leave_ComparisonTarget(
            self, original_node: cst.ComparisonTarget, updated_node: cst.ComparisonTarget
        ) -> cst.ComparisonTarget:
            meta_pos = self.get_metadata(PositionProvider, original_node)
            # Entries already recorded at this node's start position (by any transformer).
            already_modified = [
                entry for entry in self.context.scratch.values()
                if entry["original_position"].start == meta_pos.start
            ]
            if not self.reverse and not already_modified and original_node.operator.__class__ == self.op1:
                # Build on updated_node so edits already applied to children in this
                # pass (e.g. a nested comparison) are kept in the tree.
                updated_node = updated_node.with_changes(operator=self.op2())
                self.context.scratch[meta_pos.start] = {
                    "modified": True,
                    "original_position": meta_pos,
                    "original_operator": original_node.operator.__class__,
                    "updated_operator": self.op2,
                    "original_node": original_node,
                    "updated_node": updated_node,
                    "author": self.id,
                }
            elif self.reverse and already_modified and already_modified[0]["author"] == self.id:
                entry = already_modified[0]
                old_node = entry["original_node"]
                # Record the restored node so PositionContextUpdater can find it.
                entry["updated_node"] = old_node
                updated_node = old_node
            return updated_node

        def __repr__(self):
            return super().__repr__() + ':' + self.op1.__name__ + ':' + self.op2.__name__

    return ComparisonTransformer

In [29]:
## Test
from libcst import (Equal, GreaterThanEqual, LessThan, GreaterThan, 
                    LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match)
import collections

# Textual spelling of each operator -> libcst node class.
# Only the comparison operators (==, >=, >, <, <=, !=, not in, in, is, is not) can be
# swapped by ComparisonTransformer; `not`, `and` and `or` are unary/boolean operators
# and would produce invalid code inside a comparison.
str2op = {
    '==': Equal,
    '>=': GreaterThanEqual,
    '>': GreaterThan,
    '<': LessThan,
    '<=': LessThanEqual,
    '!=': NotEqual,
    'not in': NotIn,
    'in': In,
    'is': Is,
    'is not': IsNot,
    'not': Not,
    'and': And,
    'or': Or,
}

op1 = str2op['==']
op2 = str2op['>']
op3 = str2op['>=']
transformers = [gen_context_transfomer(op1, op2), gen_context_transfomer(op2, op3)]
bugger = Bugger(transformers)

# Script to mutate
script = "x == 1 + 2 == 3 + 2 != 3 + 4 > 3"

# Parse the script into a CST
module = cst.parse_module(script)
# First swaps == for >, then > for >=; the > introduced by the first step must not be modified.
tainted = bugger.apply(module)
# debug=True makes every transformer revert the edits it authored, one step per transformer.
clean = bugger.apply(tainted, debug=True)

assert tainted.code == "x > 1 + 2 > 3 + 2 != 3 + 4 >= 3"
assert clean.code == module.code, "reverting the injected bugs should restore the original code"
assert len(bugger.debug_steps) == len(transformers)

print("original code")
print(module.code)
print("Bugged Code")
print(tainted.code)
print("Debugged Code")
print(clean.code)
print("Now we debug step by step starting again from bugged code")
print(tainted.code)
for step_number, step in enumerate(bugger.debug_steps, start=1):
    print("Step", step_number)
    print(step.code)

original code
x == 1 + 2 == 3 + 2 != 3 + 4 > 3
Bugged Code
x > 1 + 2 > 3 + 2 != 3 + 4 >= 3
Debugged Code
x == 1 + 2 == 3 + 2 != 3 + 4 > 3
Now we debug step by step starting again from bugged code
x > 1 + 2 > 3 + 2 != 3 + 4 >= 3
Step 1
x == 1 + 2 == 3 + 2 != 3 + 4 >= 3
Step 2
x == 1 + 2 == 3 + 2 != 3 + 4 > 3


# ForgettingToUpdateVariable

Finds every assignment and replaces its right-hand side (the value) with its first target, e.g. `x = 4` → `x = x`.

In [30]:
from libcst import AssignTarget
class ForgettingToUpdateVariableTransformer(ContextAwareTransformer):
    """Inject a "forgot to update the variable" bug into assignments.

    Every ``Assign`` not already recorded in ``context.scratch`` has its value
    replaced by its first target, e.g. ``x = x + 1`` becomes ``x = x``.
    ``mutate(tree, reverse=True)`` restores the assignments this instance modified.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, context: CodemodContext):
        super().__init__(context)
        # Unique author tag so reverse mode only undoes this instance's edits.
        self.id = f"{self.__class__.__name__}-{uuid.uuid4().hex[:4]}"
        self.reverse = False

    def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        return tree.visit(self)

    def mutate(self, tree: cst.Module, reverse: bool = False) -> cst.Module:
        """Inject the bug (``reverse=False``) or revert it (``reverse=True``)."""
        self.reverse = reverse
        return self.transform_module(tree)

    def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.Assign:
        meta_pos = self.get_metadata(PositionProvider, original_node)
        # Entries already recorded at this node's start position (by any transformer).
        already_modified = [
            entry for entry in self.context.scratch.values()
            if entry["original_position"].start == meta_pos.start
        ]
        if not self.reverse and not already_modified:
            # For chained assignments (a = b = 1) only the first target is used.
            new_value = original_node.targets[0].target
            updated_node = original_node.with_changes(value=new_value)
            self.context.scratch[meta_pos.start] = {
                "modified": True,
                "original_position": meta_pos,
                "original_value": original_node.value,
                "updated_value": new_value,
                "original_node": original_node,
                "updated_node": updated_node,
                "author": self.id,
            }
        elif self.reverse and already_modified and already_modified[0]["author"] == self.id:
            entry = already_modified[0]
            old_node = entry["original_node"]
            # Record the restored node so PositionContextUpdater can find it.
            entry["updated_node"] = old_node
            updated_node = old_node
        return updated_node


In [31]:
## Test
from libcst import (Equal, GreaterThanEqual, LessThan, GreaterThan, 
                    LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match)
import collections


transformers = [ForgettingToUpdateVariableTransformer]
bugger = Bugger(transformers)


# Script with a tuple assignment followed by two increments of x
script = "x,y = 1 + 2, 3+4 \nx = x + 1\nx = x + 1"

# Parse the script into a CST
module = cst.parse_module(script)
# Replace every assignment's value with its target, then revert the injected bugs.
tainted = bugger.apply(module)
clean = bugger.apply(tainted, debug=True)

assert clean.code == module.code, "reverting the injected bugs should restore the original code"

print("old_code")
print(module.code)
print("tainted_code")
print(tainted.code)
print("clean_code")
print(clean.code)

old_code
x,y = 1 + 2, 3+4 
x = x + 1
x = x + 1
tainted_code
x,y = x,y 
x = x
x = x
clean_code
x,y = 1 + 2, 3+4 
x = x + 1
x = x + 1


# InfiniteWhileLoop

This transformer finds `while` loops and makes them infinite by replacing their condition with `True`.

In [32]:
import random
import re

class InfiniteWhileTransformer(ContextAwareTransformer):
    """Inject an infinite-loop bug by replacing each ``while`` condition with ``True``.

    Loops already recorded in ``context.scratch`` are skipped.
    ``mutate(tree, reverse=True)`` restores the loops this instance modified.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, context: CodemodContext):
        super().__init__(context)
        # Unique author tag so reverse mode only undoes this instance's edits.
        self.id = f"{self.__class__.__name__}-{uuid.uuid4().hex[:4]}"
        self.reverse = False

    def mutate(self, tree: cst.Module, reverse: bool = False) -> cst.Module:
        """Inject the bug (``reverse=False``) or revert it (``reverse=True``)."""
        self.reverse = reverse
        return self.transform_module(tree)

    def leave_While(self, original_node: cst.While, updated_node: cst.While) -> cst.While:
        meta_pos = self.get_metadata(PositionProvider, original_node)
        # Entries already recorded at this node's start position (by any transformer).
        already_modified = [
            entry for entry in self.context.scratch.values()
            if entry["original_position"].start == meta_pos.start
        ]
        if not self.reverse and not already_modified:
            changes = {"test": cst.Name("True")}
            if updated_node.whitespace_after_while.empty:
                # `while(cond):` has no space after the keyword; `whileTrue:` is invalid.
                changes["whitespace_after_while"] = cst.SimpleWhitespace(" ")
            # with_changes keeps the else clause, leading comments/blank lines, colon
            # whitespace and any edits already applied to nested loops in this pass.
            updated_node = updated_node.with_changes(**changes)
            self.context.scratch[meta_pos.start] = {
                "modified": True,
                "original_position": meta_pos,
                "original_node": original_node,
                "updated_node": updated_node,
                "author": self.id,
            }
        elif self.reverse and already_modified and already_modified[0]["author"] == self.id:
            entry = already_modified[0]
            old_node = entry["original_node"]
            # Record the restored node so PositionContextUpdater can find it, matching
            # the other transformers' reverse handling.
            entry["updated_node"] = old_node
            updated_node = old_node
        return updated_node

In [33]:

## Test
from libcst import (Equal, GreaterThanEqual, LessThan, GreaterThan, 
                    LessThanEqual, NotEqual, NotIn, In, Is, IsNot, Not, And, Or, Match)
import collections


transformers = [InfiniteWhileTransformer]
bugger = Bugger(transformers)


# Script containing a multi-line while loop
script = "while x < 10:\n\tprint(x)\n\tx = x + 1"

# Parse the script into a CST
module = cst.parse_module(script)
# Replace the loop condition with True, then revert the injected bug.
tainted = bugger.apply(module)
clean = bugger.apply(tainted, debug=True)

assert clean.code == module.code, "reverting the injected bugs should restore the original code"

print("old_code")
print(module.code)
print("tainted_code")
print(tainted.code)
print("clean_code")
print(clean.code)

old_code
while x < 10:
	print(x)
	x = x + 1
tainted_code
while True:
	print(x)
	x = x + 1
clean_code
while x < 10:
	print(x)
	x = x + 1


In [34]:
# Show the recorded edits ordered by where they start in the source: line first, then column.
scratch = bugger.context.scratch
dict(sorted(
    scratch.items(),
    key=lambda item: (item[1]["original_position"].start.line, item[1]["original_position"].start.column),
))

{CodePosition(line=1, column=0): {'modified': True,
  'original_position': CodeRange(start=CodePosition(line=1, column=0), end=CodePosition(line=3, column=10)),
  'original_node': While(
      test=Comparison(
          left=Name(
              value='x',
              lpar=[],
              rpar=[],
          ),
          comparisons=[
              ComparisonTarget(
                  operator=LessThan(
                      whitespace_before=SimpleWhitespace(
                          value=' ',
                      ),
                      whitespace_after=SimpleWhitespace(
                          value=' ',
                      ),
                  ),
                  comparator=Integer(
                      value='10',
                      lpar=[],
                      rpar=[],
                  ),
              ),
          ],
          lpar=[],
          rpar=[],
      ),
      body=IndentedBlock(
          body=[
              SimpleStatementLine(
                  body