In [231]:
from extract import Extractor
import inspect as ins
import ast

In [238]:
#from cowait import task
#path = ins.getfile(task)
path = ins.getfile(Extractor)

file = open(path)
code = file.read()
file.close()

tree = ast.parse(code)
print(code)
print(ast.dump(tree))

from os import listdir
from os.path import isfile, isdir
import numpy as np
from tqdm import tqdm

class Extractor:

	def __init__(self, fileType: str):
		self.fileType = fileType

	def find_files(self, start: str):
		files = []

		for fd in listdir(start):
			path = f'{start}/{fd}'

			if isfile(path):
				ext = f'.{self.fileType}'
				if fd[-len(ext):] == ext: files.append(path)
			else:
				f = self.find_files(path)
				files += f if f else []
	
		return files
	
	def get_content(self, file: str, skipError: bool=True):
		if not isfile(file):
			raise ValueError(f'{file} is not a file')

		try:
			with open(file) as f:
				return f.read()
		except Exception as e:
			if skipError: print(f'Parsing error, skipping file: {file}, {e}')
			else: raise e
	
	def extract(self, dir: str):
		if not isdir(dir):
			raise ValueError(f'{dir} is not a directory')
		
		files = self.find_files(dir)
		return np.array([self.get_content(f) for f in tqdm(files)])

		




Module(body=[ImportFrom(module='os'

In [240]:
def get_return(node):
    class ReturnVisitor(ast.NodeVisitor):
        def __init__(self):
            self.r = None
        def visit_Return(self, node):
            self.r = node
    
    v = ReturnVisitor()
    v.visit(node)
    return v.r

class Visitor(ast.NodeVisitor):
    def __init__(self):
        self.stats = {
            'import': [],
            'from': [],
            'func': [],
            'class': [],
        }
        
    def generic_visit(self, node):
        #print(type(node).__name__)
        super().generic_visit(node)
    
    def visit_Import(self, node):
        for alias in node.names:
            self.stats["import"].append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        self.stats["from"].append((node.module, [a.name for a in node.names]))
        self.generic_visit(node)
        
    def visit_FunctionDef(self, node):
        args = [*node.args.posonlyargs, *node.args.args, *node.args.kwonlyargs]
        
        retr = get_return(node)
        
        retr = retr.value if retr else None
        
        if type(retr) == ast.Name:
            retr = retr.id
        
        funcdef = (
            node.name,
            [*map(lambda a: a.arg, args)],
            (bool(node.args.kwarg), bool(node.args.vararg)),
            retr
        )
        
        self.stats['func'].append(funcdef)
        self.generic_visit(node)
        
    def visit_AsyncFunctionDef(self, node):
        self.visit_FunctionDef(node)
        
    def visit_ClassDef(self, node):
        args = [*node.keywords]
        classdef = (node.name, [*map(lambda a: a.arg, args)])
        self.stats['class'].append(classdef)
        self.generic_visit(node)

    def report(self):
        for stat, value in self.stats.items():
            print(f'--- {stat} ---')
            if stat == 'func':
                for func in value:
                    print(f'{func[0]}({func[1]}, {sum(func[2]) * "*"}) -> {func[3]}')
            else:
                print(value)
            print()

In [241]:
visitor = Visitor()
visitor.visit(tree)
visitor.report()

--- import ---
['numpy']

--- from ---
[('os', ['listdir']), ('os.path', ['isfile', 'isdir']), ('tqdm', ['tqdm'])]

--- func ---
__init__(['self', 'fileType'], ) -> None
find_files(['self', 'start'], ) -> files
get_content(['self', 'file', 'skipError'], ) -> <_ast.Call object at 0x7fc9d3ecac10>
extract(['self', 'dir'], ) -> <_ast.Call object at 0x7fc9d3e9afd0>

--- class ---
[('Extractor', [])]



Summarizing file ASTs might be done with the approach above. However, we are more interested in finding a context for a given piece of code in order to put it through an encoder. Let's try to take a specific function and summarize all the parent nodes of the syntax tree until we encounter the `__init__.py` file which signals that we've reached the top of the library.

In [246]:
def get_obj_tree(obj):
    code = ins.getsource(obj)
    tree = ast.parse(code)
    return tree

In [283]:
def visitor_from_node(node):
    visitor = Visitor()
    visitor.visit(node)
    return visitor

In [None]:
class Visitor(ast.NodeVisitor):
    def __init__(self):
        self.stats = {
            'import': [],
            'from': [],
            'func': [],
            'class': [],
        }
        
    def visit_Import(self, node):
        for alias in node.names:
            self.stats['import'].append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        self.stats["from"].append((node.module, [a.name for a in node.names]))
        self.generic_visit(node)
        
    def visit_FunctionDef(self, node):
        self.stats['func'].append(node.name)
        self.generic_visit(node)
        
    def visit_AsyncFunctionDef(self, node):
        self.visit_FunctionDef(node)
        
    def visit_ClassDef(self, node):
        self.stats['class'].append(node.name)
        self.generic_visit(node)

    def report(self):
        for stat, value in self.stats.items():
            print(f'--- {stat} ---')
            print(value)

class ContextDeducer:
    def __init__(self, name):
        exec(f'import {name}\nref = {name}')
        self.entry = ContextDeducer.get_obj_tree(ref)
        self.name = name
    
    @staticmethod
    def get_tree(obj):
        # Takes Python reference, returns ast of the contents
        code = ins.getsource(obj)
        return ast.parse(code)
    
    @staticmethod
    def get_stats(node):
        # Visits all nodes of tree, returns visitor stats
        visitor = Visitor()
        visitor.visit(node)
        return visitor.stats
    
    def traverse(self, node = self.entry):
        stats = ContextDeducer.get_stats(node)
        
        #  TODO: filter import statements outside project code
        
        refs = []
            
        # import statements
        for module in stats['import']:
            exec(f'import {module}')
            exec(f'ref = {module}')
            refs.append(ref)

        # from import statements
        for entry in stats['from']:
            module, names = entry
            form = lambda a,b: f'{a}, {b}'
            exec(f'from {self.name}.{module} import {reduce(form, names)}')

            for name in names:
                exec(f'ref = {name}')
                refs.append(ref)
        
        # function definitions
        for func in stats['func']:
            pass
        
        # class definitions
        for cls in stats['class']:
            
            
        return refs
        

In [284]:
import cowait

tree = get_obj_tree(cowait)
print(ast.dump(tree), '\n')

visitor = visitor_from_node(tree)
visitor.report()

Module(body=[ImportFrom(module='version', names=[alias(name='version', asname=None)], level=1), ImportFrom(module='tasks', names=[alias(name='Task', asname=None), alias(name='task', asname=None), alias(name='rpc', asname=None), alias(name='join', asname=None), alias(name='wait', asname=None), alias(name='sleep', asname=None), alias(name='spawn', asname=None), alias(name='input', asname=None), alias(name='exit', asname=None)], level=1)], type_ignores=[]) 

--- import ---
[]

--- from ---
[('version', ['version']), ('tasks', ['Task', 'task', 'rpc', 'join', 'wait', 'sleep', 'spawn', 'input', 'exit'])]

--- func ---

--- class ---
[]



We can see the references to the imported subroutines. If we follow these, and pick up any function definitions along the way, we should be able to generate a context ast for every function.

In [262]:
from functools import reduce

statements = []

for module in visitor.stats['from']:
    statement = f'from cowait.{module[0]} import {reduce(lambda a,b: f"{a}, {b}", module[1])}'
    statements.append(statement)

for module in visitor.stats['import']:
    statement = f'import {module}'
    statements.append(statement)

In [266]:
for s in statements:
    exec(statement)

In [268]:
print(statements)

['from cowait.version import version', 'from cowait.tasks import Task, task, rpc, join, wait, sleep, spawn, input, exit']


In [269]:
print(version)
print(spawn)

0.4.31
<function spawn at 0x7fc9d4322280>


In [291]:
refs = []

for module in visitor.stats['from']:
    for obj in module[1]:
        exec(f'ref = {obj}')
        refs.append(ref)

In [310]:
# Code borrowed from https://github.com/pombredanne/python-ast-visualizer/blob/master/astvisualizer.py

import ast
import graphviz as gv
import subprocess
import numbers
import re
from uuid import uuid4 as uuid
import optparse
import sys

def transform_ast(code_ast):
    if isinstance(code_ast, ast.AST):
        node = {to_camelcase(k): transform_ast(getattr(code_ast, k)) for k in code_ast._fields}
        node['node_type'] = to_camelcase(code_ast.__class__.__name__)
        return node
    elif isinstance(code_ast, list):
        return [transform_ast(el) for el in code_ast]
    else:
        return code_ast


def to_camelcase(string):
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', string).lower()


class GraphRenderer:
    """
    this class is capable of rendering data structures consisting of
    dicts and lists as a graph using graphviz
    """

    graphattrs = {
        'labelloc': 't',
        'fontcolor': 'white',
        'bgcolor': '#333333',
        'margin': '0',
    }

    nodeattrs = {
        'color': 'white',
        'fontcolor': 'white',
        'style': 'filled',
        'fillcolor': '#006699',
    }

    edgeattrs = {
        'color': 'white',
        'fontcolor': 'white',
    }

    _graph = None
    _rendered_nodes = None


    @staticmethod
    def _escape_dot_label(str):
        return str.replace("\\", "\\\\").replace("|", "\\|").replace("<", "\\<").replace(">", "\\>")


    def _render_node(self, node):
        if isinstance(node, (str, numbers.Number)) or node is None:
            node_id = uuid()
        else:
            node_id = id(node)
        node_id = str(node_id)

        if node_id not in self._rendered_nodes:
            self._rendered_nodes.add(node_id)
            if isinstance(node, dict):
                self._render_dict(node, node_id)
            elif isinstance(node, list):
                self._render_list(node, node_id)
            else:
                self._graph.node(node_id, label=self._escape_dot_label(str(node)))

        return node_id


    def _render_dict(self, node, node_id):
        self._graph.node(node_id, label=node.get("node_type", "[dict]"))
        for key, value in node.items():
            if key == "node_type":
                continue
            child_node_id = self._render_node(value)
            self._graph.edge(node_id, child_node_id, label=self._escape_dot_label(key))


    def _render_list(self, node, node_id):
        self._graph.node(node_id, label="[list]")
        for idx, value in enumerate(node):
            child_node_id = self._render_node(value)
            self._graph.edge(node_id, child_node_id, label=self._escape_dot_label(str(idx)))


    def render(self, data, *, label=None):
        # create the graph
        graphattrs = self.graphattrs.copy()
        if label is not None:
            graphattrs['label'] = self._escape_dot_label(label)
        graph = gv.Digraph(graph_attr = graphattrs, node_attr = self.nodeattrs, edge_attr = self.edgeattrs)

        # recursively draw all the nodes and edges
        self._graph = graph
        self._rendered_nodes = set()
        self._render_node(data)
        self._graph = None
        self._rendered_nodes = None

        # display the graph
        graph.format = "pdf"
        graph.view()

In [328]:
import django
tree = get_obj_tree(refs[2])
print(ast.dump(tree))
transformed_ast = transform_ast(tree)
renderer = GraphRenderer()
renderer.render(transformed_ast)

Module(body=[FunctionDef(name='task', args=arguments(posonlyargs=[], args=[arg(arg='func', annotation=None, type_comment=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Expr(value=Constant(value=' Wraps a task function in a Task class ', kind=None)), If(test=Compare(left=Call(func=Attribute(value=Subscript(value=Attribute(value=Name(id='func', ctx=Load()), attr='__name__', ctx=Load()), slice=Index(value=Constant(value=0, kind=None)), ctx=Load()), attr='lower', ctx=Load()), args=[], keywords=[]), ops=[Eq()], comparators=[Subscript(value=Attribute(value=Name(id='func', ctx=Load()), attr='__name__', ctx=Load()), slice=Index(value=Constant(value=0, kind=None)), ctx=Load())]), body=[Raise(exc=Call(func=Name(id='NameError', ctx=Load()), args=[JoinedStr(values=[Constant(value='Task names must start with an uppercase character, found ', kind=None), FormattedValue(value=Attribute(value=Name(id='func', ctx=Load()), attr='__name__', ctx=Load()), conversion=-1, 

In [None]:
for ref in refs:
    try:
        tree = get_obj_tree(ref)
        visitor = visitor_from_node(tree)
        print(ref)
        #visitor.report()
        print(ast.dump(tree))
        print()
    except TypeError:
        pass