# Explode Function

In [None]:
import ast
import astor

class FunctionExploder(ast.NodeTransformer):
    """Replace each function definition with its body as straight-line code.

    A ``def`` is replaced by, in order:

    * a title annotation built from the function name (``foo_bar`` -> ``Foo Bar``),
    * a markdown annotation holding the docstring text that precedes the first
      doctest example,
    * an "Example Input" annotation followed by every statement of every
      ``>>>`` example in the docstring (the examples serve as sample inputs),
    * a "Body of Function" annotation followed by the function body with its
      top-level ``return`` statements removed.

    Annotations are produced by ``Annotator._make_annotation``; ``Annotator``
    is defined in another cell and must exist before ``visit`` is called.
    """

    def visit_FunctionDef(self, func):
        import doctest  # this cell does not import doctest; needed to split the docstring

        # Separate the docstring (if there is one) from the executable statements.
        # Without this check a docstring-less function would lose its first statement.
        docstring = ast.get_docstring(func, clean=False)
        if docstring is None:
            docstring = ''
            statements = func.body
        else:
            statements = func.body[1:]

        # DocTestParser.parse returns [text, Example, text, ..., text], so element 0
        # is always the prose that precedes the first `>>>` example.
        parts = doctest.DocTestParser().parse(docstring)
        docstring_prefix = parts[0]
        examples = [part for part in parts if isinstance(part, doctest.Example)]

        # Top-level returns are dropped so the body can run at module level.
        statements = [stmt for stmt in statements if not isinstance(stmt, ast.Return)]

        # Empty pieces (leading/trailing/double underscores) would become stray spaces.
        title = ' '.join(word.capitalize() for word in func.name.split('_') if word)

        body = [
            Annotator._make_annotation(content=title, cell_type='1'),
            Annotator._make_annotation(content=docstring_prefix, cell_type='markdown'),
            Annotator._make_annotation(content='Example Input', cell_type='1'),
        ]
        for example in examples:
            # Keep every statement of the example (e.g. `a = 7; b = 2`), not just the first.
            body.extend(ast.parse(example.source.strip()).body)
        body.append(Annotator._make_annotation(content='Body of Function', cell_type='1'))
        body.extend(statements)

        # A Module stands in for "several statements" as a single replacement node,
        # matching what the other transformers in this notebook return.
        return ast.Module(body=body, type_ignores=[])
    
# Demo input: a small function whose docstring carries one doctest example
# (`a = 7`) that FunctionExploder turns into the sample input for the body.
# NOTE(review): FunctionExploder calls Annotator._make_annotation, and Annotator
# is only defined in the "Annotate Expressions" cell below — run that cell
# first, otherwise this cell fails with a NameError when executed top-to-bottom.
code = '''

def foo(a):
    """This is a docstring
    
    >>> a = 7
    
    """
    for i in range(a):
        print(i)

'''

tree = ast.parse(code)
tree = FunctionExploder().visit(tree)
# The exploded source replaces `code`; the next cell parses it again.
code = astor.to_source(tree)
print(code)

# Rewrite Syntax Tree

In [None]:
import ast
import astor

class SyntaxRewriter(ast.NodeTransformer):
    """Desugar ``for`` loops into explicit ``iter()``/``next()`` while-loops.

    ::

        for i in iterable:
            <body>
        else:
            <orelse>

    becomes::

        p = iter(iterable)
        while True:
            try:
                i = next(p)
            except StopIteration:
                <orelse>
                break
            <body>

    Nested loops are rewritten as well. Each nesting level gets its own
    iterator variable (``p``, ``p1``, ``p2``, ...), so an inner loop never
    overwrites the iterator of the loop that contains it. A loop whose
    ``else`` clause contains ``break``/``continue`` is left as a ``for``
    loop: moving that clause inside the generated ``while`` would retarget
    those jumps.

    :param iter_name: base name for the generated iterator variables.
    """

    def __init__(self, iter_name='p'):
        super().__init__()
        self.iter_name = iter_name
        self._depth = 0  # number of enclosing `for` loops currently being visited

    def _iterator_name(self, depth):
        """Return the iterator variable name used at nesting level *depth*."""
        return self.iter_name if depth == 0 else f'{self.iter_name}{depth}'

    @staticmethod
    def _has_jump(statements):
        """Return True if any break/continue occurs anywhere in *statements* (conservative)."""
        return any(
            isinstance(node, (ast.Break, ast.Continue))
            for stmt in statements
            for node in ast.walk(stmt)
        )

    def visit_For(self, forr):
        # Decide before recursing: rewritten inner loops contain their own `break`.
        keep_loop = self._has_jump(forr.orelse)

        depth = self._depth
        self._depth += 1
        try:
            self.generic_visit(forr)  # rewrite loops nested inside this one first
        finally:
            self._depth -= 1

        if keep_loop:
            return forr

        iterator = self._iterator_name(depth)

        # p = iter(iterable)
        assign_iter = ast.Assign(
            targets=[ast.Name(id=iterator, ctx=ast.Store())],
            value=ast.Call(
                func=ast.Name(id='iter', ctx=ast.Load()),
                args=[forr.iter],
                keywords=[],
            ),
        )

        # i = next(p)
        assign_next = ast.Assign(
            targets=[forr.target],
            value=ast.Call(
                func=ast.Name(id='next', ctx=ast.Load()),
                args=[ast.Name(id=iterator, ctx=ast.Load())],
                keywords=[],
            ),
        )

        # try:
        #     i = next(p)
        # except StopIteration:
        #     <orelse>      # runs only on natural exhaustion, like for-else
        #     break
        try_node = ast.Try(
            body=[assign_next],
            handlers=[
                ast.ExceptHandler(
                    type=ast.Name(id='StopIteration', ctx=ast.Load()),
                    name=None,
                    body=forr.orelse + [ast.Break()],
                )
            ],
            orelse=[],
            finalbody=[],
        )

        # while True:
        #     <try_node>
        #     <body>
        while_node = ast.While(
            test=ast.Constant(value=True),
            body=[try_node] + forr.body,
            orelse=[],
        )

        # Two statements replace one: a Module carries them as a single node.
        return ast.Module(body=[assign_iter, while_node], type_ignores=[])
    
# Desugar the `for` loops in the exploded code produced by the previous cell,
# and keep the rewritten source in `code` for the next cell.
tree = ast.parse(code)
tree = SyntaxRewriter().visit(tree)
code = astor.to_source(tree)
print(code)

# Annotate Expressions

In [None]:
import ast
import astor

class Annotator(ast.NodeTransformer):
    """Interleave program statements with ``epc_client.call_sync`` annotations.

    Each statement handled here is followed by an expression statement::

        epc_client.call_sync('make-code-cell-and-eval', [content, emacs_buffer, cell_type])

    where ``content`` defaults to the statement's own source. ``if``, ``while``
    and ``try`` statements are not annotated as a whole; their bodies are
    annotated statement by statement instead. Annotation calls that are
    already present (e.g. inserted by FunctionExploder) are left untouched.

    Wherever one statement expands into several, the result is wrapped in an
    ``ast.Module`` so one visited statement still maps to exactly one node.
    """

    # Compound statements whose bodies are annotated statement-by-statement
    # instead of annotating the whole statement as one unit.
    context_nodes = [ast.If, ast.While, ast.Try]

    @staticmethod
    def _make_annotation(node=None, content=None, emacs_buffer='outside', cell_type='code'):
        """Return an ``ast.Expr`` for the call

            epc_client.call_sync('make-code-cell-and-eval', [content, emacs_buffer, cell_type])

        :param node: if given, its source (via astor) is used as ``content``.
        :param content: literal content string, used when ``node`` is None.
        :param emacs_buffer: buffer name passed through to the call.
        :param cell_type: cell type string passed through (this notebook uses
            'code', 'markdown', '1' and '2' — presumably heading levels; confirm
            against the Emacs side).
        """
        func = ast.Attribute(
            value=ast.Name(id='epc_client', ctx=ast.Load()),
            attr='call_sync',
            ctx=ast.Load(),
        )
        elisp_funcname = 'make-code-cell-and-eval'
        if node is not None:
            content = astor.to_source(node).strip()
        args = [
            ast.Constant(value=elisp_funcname),
            ast.List(
                elts=[
                    ast.Constant(value=content),
                    ast.Constant(value=emacs_buffer),
                    ast.Constant(value=cell_type),
                ],
                ctx=ast.Load(),
            ),
        ]
        return ast.Expr(ast.Call(func=func, args=args, keywords=[]))

    @staticmethod
    def _is_annotation(node):
        """Return True if *node* is an ``epc_client.<attr>(...)`` call or an Expr wrapping one."""
        if isinstance(node, ast.Expr):
            node = node.value
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == 'epc_client'
        )

    @staticmethod
    def _wrap(nodes, original):
        """Bundle *nodes* into one Module node standing in for *original*."""
        return ast.copy_location(ast.Module(body=nodes, type_ignores=[]), original)

    def _annotate_nodes(self, nodes):
        """Annotate a statement list.

        Context nodes (if/while/try) are recursed into via ``visit()``; existing
        annotations are kept as-is; every other statement is followed by an
        annotation of its own source.
        """
        body = []
        for node in nodes:
            if isinstance(node, tuple(Annotator.context_nodes)):
                body.append(self.visit(node))
            elif self._is_annotation(node):
                body.append(node)
            else:
                body.append(node)
                body.append(Annotator._make_annotation(node))
        return body

    def visit_Module(self, module):
        """Annotate each top-level statement rather than the module as a whole."""
        module.body = [self.visit(stmt) for stmt in module.body]
        return module

    def visit_If(self, iff):
        header = f'if {astor.to_source(iff.test).strip()} ...'
        return self._wrap(
            [
                Annotator._make_annotation(content=header, cell_type='2'),
                Annotator._make_annotation(iff.test),
                ast.If(
                    test=iff.test,
                    body=self._annotate_nodes(iff.body),
                    orelse=self._annotate_nodes(iff.orelse),
                ),
            ],
            iff,
        )

    def visit_While(self, whilst):
        header = f'while {astor.to_source(whilst.test).strip()} ...'
        return self._wrap(
            [
                Annotator._make_annotation(content=header, cell_type='2'),
                Annotator._make_annotation(whilst.test),
                ast.While(
                    test=whilst.test,
                    body=self._annotate_nodes(whilst.body),
                    orelse=self._annotate_nodes(whilst.orelse),
                ),
            ],
            whilst,
        )

    def visit_Try(self, try_):
        handlers = [
            ast.copy_location(
                ast.ExceptHandler(
                    type=handler.type,
                    name=handler.name,  # keep the `except E as name` binding
                    body=self._annotate_nodes(handler.body),
                ),
                handler,
            )
            for handler in try_.handlers
        ]
        return ast.copy_location(
            ast.Try(
                body=self._annotate_nodes(try_.body),
                handlers=handlers,
                orelse=self._annotate_nodes(try_.orelse),
                finalbody=self._annotate_nodes(try_.finalbody),
            ),
            try_,
        )

    def visit_Assign(self, assign):
        """Annotate with the assignment's source followed by its first target.

        The annotation content is the assignment source plus the first target's
        source on a new line, so the cell also shows the assigned value.
        """
        assign_content = astor.to_source(assign)
        targets_content = astor.to_source(assign.targets[0])
        annotation = Annotator._make_annotation(content=assign_content + targets_content.strip())
        return self._wrap([assign, annotation], assign)

    def visit_Call(self, call):
        """Leave existing annotation calls alone; annotate any other call."""
        return call if self._is_annotation(call) else self.generic_visit(call)

    def visit_Expr(self, expr):
        """Annotate an expression statement, keeping it a statement (``ast.Expr``)."""
        if self._is_annotation(expr):
            return expr
        return self._wrap([expr, Annotator._make_annotation(expr)], expr)

    def generic_visit(self, node):
        """Fallback: keep *node* and follow it with an annotation of its source.

        This replaces NodeTransformer's recursive default, so unhandled compound
        statements (``for``, ``with``, ...) are annotated as one unit.
        """
        return self._wrap([node, Annotator._make_annotation(node)], node)

# Annotate each top-level statement of the rewritten code individually and
# keep the annotated source in `code`.
tree = ast.parse(code)
tree.body = [Annotator().visit(node) for node in tree.body]
code = astor.to_source(tree)
print(code)

# Whole Pipeline

In [None]:
code = """

if foo in bar:
    width, height = scene_image.size
    for i, obj in enumerate(mod_vec_payload['objects']):
        print(1)
        print(2)

# Cropping and processing the object patches from the scene image
object_arrays, object_imgs = [], []
for i, obj in tqdm(enumerate(mod_vec_payload['objects'])):
    print(3)
    print(4)

with graph.as_default():
    eprint('GOT TF GRAPH AND VECTORIZING')
    all_object_vectors = predictF2V(xception_ftr_xtrct, object_arrays)

"""

tree = ast.parse(code)
tree = FunctionExploder().visit(tree)
code = astor.to_source(tree)
print(code)

In [None]:
# Whole pipeline, step 2 of 3: desugar `for` loops, visiting each top-level
# statement with a fresh SyntaxRewriter.
tree = ast.parse(code)
tree.body = [SyntaxRewriter().visit(node) for node in tree.body]
code = astor.to_source(tree)
print(code)

In [None]:
# Whole pipeline, step 3 of 3: insert the epc_client annotation calls after
# each top-level statement of the desugared code.
tree = ast.parse(code)
tree.body = [Annotator().visit(node) for node in tree.body]
code = astor.to_source(tree)
print(code)