In [65]:
# coding: utf-8
import sys
import re
import string
import ast
from ast import AST
import tokenize
import token
from numbers import Number
import json
import io

In [66]:
def posFromText(text, textPos):
    """Convert a flat character offset into a {'line', 'ch'} position.

    The prefix text[:textPos + 1] is inspected: 'line' is the number of
    newline-separated lines in that prefix (so it is 1-based) and 'ch' is
    the length of the prefix's final line.
    """
    upToPos = text[:textPos + 1]
    lineCount = upToPos.count("\n") + 1
    lastLine = upToPos.rsplit("\n", 1)[-1]
    return {'line': lineCount, 'ch': len(lastLine)}
    

In [67]:
def regexEnd(text, start, parentEnd, nodeStart):
    """Find the bracketed span -- (...), [...] or {...} -- that begins at text[start].

    Parameters:
        text      -- the source text being scanned.
        start     -- offset in `text` where the span is expected to begin.
        parentEnd -- last offset (inclusive) that may be searched.
        nodeStart -- {'line', 'ch'} position of the node that owns the span.

    Returns:
        (end, myEnd, matched) where `end` is the offset just past the span,
        `myEnd` is the {'line', 'ch'} position of its last character and
        `matched` is the span text; or None when text[start] is not one of
        the bracket characters or no span starting at `start` is found.

    Notes:
        The pattern is greedy and '.' does not match a newline, so the span
        always stays on one line and extends to the LAST matching closer on
        that line.
    """
    if start >= len(text):
        return None
    first = text[start]

    # Raw strings: "\(" and friends are invalid escapes in a normal literal.
    if first in parens:
        outer = re.compile(r"\((.*)\)")
    elif first in sqParens:
        outer = re.compile(r"\[(.*)\]")
    elif first in brackets:
        outer = re.compile(r"\{(.*)\}")
    else:
        return None

    m = outer.search(text[start:parentEnd+1])
    # No match at all, or a match that begins later than `start` (for example
    # when text[start] is a closing bracket): the offsets computed below would
    # be wrong, so report "not bracketed" instead.
    if m is None or m.start() != 0:
        return None

    matched = m.group(0)
    lines = matched.split("\n")
    end = start + len(matched)
    endLine = nodeStart['line'] + len(lines) - 1
    endCh = len(lines[-1]) - 1
    # A span that stays on the node's own line is offset by the node's column.
    # (The previous test compared against a value that could never be equal,
    # so the column was never added.)
    if len(lines) == 1:
        endCh += nodeStart['ch']
    return (end, {'line': endLine, 'ch': endCh}, matched)

In [68]:
def findNodeStart(node):
    """Return the {'line', 'ch'} start position of an AST node.

    Nodes carrying `lineno` report their own position and a Module reports
    line 1, column 0. Any other node is treated as a wrapper: the search
    descends through first children until a positioned node is found.
    Returns None if the chain of first children runs out.
    """
    current = node
    while current is not None:
        if hasattr(current, 'lineno'):
            return {'line': current.lineno, 'ch': current.col_offset}
        if type(current).__name__ == "Module":
            return {'line': 1, 'ch': 0}
        # wrapper node: continue with its first child, if it has one
        current = next(ast.iter_child_nodes(current), None)
    return None

In [69]:
def findLiteralEnd(snippet, nodeStart): # what to do about multiline literal???
    """Work out where the literal at the start of `snippet` ends.

    Parameters:
        snippet   -- source text beginning at the literal's first character.
        nodeStart -- {'line', 'ch'} position of the literal's node.

    Returns:
        (i, nodeEnd, literal). For a bracketed literal (as detected by
        regexEnd) `literal` is a list of {'syntok': char} entries, one per
        character of the bracketed span. Otherwise `literal` is the text up
        to (not including) the first punctuation character, or the whole
        snippet when it contains no punctuation. `i` is the offset just past
        the literal and `nodeEnd` its {'line', 'ch'} end position.
    """
    # An empty snippet cannot be bracketed; skip the regex lookup for it.
    bracketed = regexEnd(snippet, 0, len(snippet), nodeStart) if snippet else None
    if bracketed:
        i, nodeEnd, punct = bracketed
        literal = [{'syntok': str(p)} for p in punct]
    else:
        # Default to "literal runs to the end of the snippet"; previously a
        # snippet without punctuation was cut down to its first character
        # and `i` was left pointing at the last index.
        i = len(snippet)
        for pos, character in enumerate(snippet):
            if character in punctuation:
                i = pos
                break
        literal = snippet[:i]
        nodeEnd = {'line': nodeStart['line'], 'ch': len(literal) + nodeStart['ch']}

    print("literal |"+str(literal)+"|")
    return (i, nodeEnd, literal)

In [157]:
def OLDfindNextChild(children, itr):
    """Return the next child after index `itr` that is not a Store/Load marker.

    Parameters:
        children -- list of AST child nodes.
        itr      -- index of the child visited last (-1 to start from the top).

    Returns:
        (child, index) for the first node after `itr` whose class name is not
        "Store" or "Load", or (None, index) when there is none; in that case
        `index` is the position at which the scan stopped.

    The scan is iterative: the earlier recursive form called `findNextChild`,
    a name this function no longer carries after being renamed.
    """
    banned = ("Store", "Load")
    idx = itr + 1
    while idx < len(children):
        candidate = children[idx]
        if type(candidate).__name__ not in banned:
            return candidate, idx
        idx += 1
    return None, idx

In [107]:
def captureComment(text, textStart, textEnd):
    """Extract the rest of the current line, starting at `textStart`.

    Parameters:
        text      -- the source text.
        textStart -- offset where the comment begins.
        textEnd   -- offset (exclusive) limiting how far to look.

    Returns:
        (end, line): `line` is text[textStart:textEnd] cut at the first
        newline, and `end` is the offset just past it. When the slice holds
        no newline (e.g. a comment on the last line) the whole slice is
        returned; previously its final character was dropped.
    """
    line = text[textStart:textEnd].partition("\n")[0]
    return textStart + len(line), line

In [117]:
def captureStuff(text, end, nodeItem, puncStop = "", puncNL = False):
    """Visit one node and collect the punctuation that trails it.

    `nodeItem` is visited starting at offset `end`; afterwards any symbols
    (commas, spaces, ...) following it are gathered, with `puncStop` and
    `puncNL` passed on to getPunctuationBetween as its stop character and
    newline switch.

    Returns (new_end, content), where content is the visited item followed
    by the collected symbol entries.
    """
    afterNode, visited = visit(nodeItem, text, end, len(text), None)
    afterSymbols, trailing = getPunctuationBetween(text, afterNode, puncStop, puncNL)
    return afterSymbols, [visited, *trailing]

In [171]:
'''
mod = Module(stmt* body)
        | Interactive(stmt* body)
        | Expression(expr body)

        -- not really an actual node but useful in Jython's typesystem.
        | Suite(stmt* body)
'''
def visitModule(node, text, textStart, textEnd):
    """Build the tree entry for a top-level node (Module and its siblings).

    Parameters:
        node      -- the AST node; its `body` is either a list of statements
                     or a single expression.
        text      -- the full source text.
        textStart -- offset where this node's text begins.
        textEnd   -- end offset handed down when visiting a single-expression body.

    Returns:
        (end, me): `end` is the offset reached after consuming the node and
        `me` is a dict with 'type', 'start', 'end' and 'content'.
    """
    myType = type(node).__name__
    myContent = []
    myStart = posFromText(text, textStart)
    end = textStart
    # leading symbols (comments, spaces, newlines) before the first statement
    end, symbols = getPunctuationBetween(text, end, "", True)
    myContent += symbols
    print("Start:", myContent)

    if isinstance(node.body, list):
        for stmt in node.body:
            end, stuff = captureStuff(text, end, stmt, "", True)
            myContent += stuff
            print(myType+" AFTER ", stmt, myContent)
    else:
        end, expr = visit(node.body, text, end, textEnd, None)
        # keep the visited expression; it used to be computed and then dropped
        myContent.append(expr)

    # trailing symbols after the last statement
    end, symbols = getPunctuationBetween(text, end, "", True)
    myContent += symbols
    print("END:", myContent)

    myEnd = posFromText(text, end)
    me = {'type': myType, 'start': myStart, 'end': myEnd, 'content': myContent}
    print("MADE:", ast.dump(node, True, False), "\n",me,"\n")
    return end, me

In [160]:
def stmtOrExpr(node):
    """Create the skeleton entry for a statement or expression node.

    The entry records the node's class name and its lineno/col_offset start;
    'end' is left as None and 'content' as a fresh empty list for the caller
    to fill in.
    """
    return {
        'type': type(node).__name__,
        'start': {'line': node.lineno, 'ch': node.col_offset},
        'end': None,
        'content': [],
    }

In [161]:
'''
Assign(expr* targets, expr value)
'''
def visitAssign(node, text, textStart, textEnd):
    """Build the entry for an Assign node.

    Every target is captured together with the symbols that follow it, then
    the assigned value is visited. Returns (end, me) where `end` is the
    offset reached and `me` the filled-in entry.
    """
    result = stmtOrExpr(node)
    pieces = result['content']
    cursor = textStart

    for target in node.targets:
        cursor, captured = captureStuff(text, cursor, target)
        pieces.extend(captured)

    cursor, valueItem = visit(node.value, text, cursor, textEnd, None)
    pieces.append(valueItem)

    result['end'] = posFromText(text, cursor)
    print("MADE:", ast.dump(node, True, False), "\n",result,"\n")
    return cursor, result
    

In [162]:
'''
Call(expr func, expr* args, keyword* keywords)
'''
def visitCall(node, text, textStart, textEnd):
    """Build the entry for a Call node.

    Order of capture: the callee, the symbols up to the first ')', each
    positional argument, each keyword argument, any remaining symbols, and
    finally the closing ')'. Returns (end, me).
    """
    result = stmtOrExpr(node)
    pieces = result['content']

    cursor, callee = visit(node.func, text, textStart, textEnd, None)
    pieces.append(callee)
    cursor, between = getPunctuationBetween(text, cursor, ")")
    pieces.extend(between)

    # positional arguments first, keywords after, each with trailing symbols
    for item in list(node.args) + list(node.keywords):
        cursor, captured = captureStuff(text, cursor, item, ")")
        pieces.extend(captured)

    cursor, between = getPunctuationBetween(text, cursor, ")")
    pieces.extend(between)
    pieces.append({'syntok': ')'})
    cursor += 1  # step over the closing parenthesis

    result['end'] = posFromText(text, cursor)
    print("MADE:", ast.dump(node, True, False), "\n",result,"\n")
    return cursor, result
    

In [163]:
def visitAttribute(node, text, textStart, textEnd):
    """Build the entry for an Attribute node (`value.attr`).

    The owning value is visited first; the '.' and the attribute name are
    then appended as tokens. The entry shares the owner's start position and
    ends len(attr) + 1 columns after the owner's end, on the start line.
    Returns (end, me).
    """
    cursor, owner = visit(node.value, text, textStart, textEnd, None)
    attrName = str(node.attr)
    consumed = 1 + len(attrName)  # the dot plus the attribute name

    ownerStart = owner['start']
    result = {
        'type': type(node).__name__,
        'start': ownerStart,
        'end': {'line': ownerStart['line'], 'ch': owner['end']['ch'] + consumed},
        'content': [owner, {'syntok': '.'}, {'syntok': attrName}],
    }
    cursor += consumed
    print("MADE:", cursor, ast.dump(node, True, False), "\n",result,"\n")
    return cursor, result

def visitSubscript(node, text, textStart, textEnd):
    """Build the entry for a Subscript node (`value[slice]`).

    The subscripted value is visited, then the slice from where the value
    left off. The entry starts at the node's own start position and ends
    where the slice entry ends. Returns (end, me).
    """
    origin = findNodeStart(node)
    cursor, target = visit(node.value, text, textStart, textEnd, None)
    cursor, sliceItem = visit(node.slice, text, cursor, textEnd, None)
    result = {
        'type': type(node).__name__,
        'start': origin,
        'end': sliceItem['end'],
        'content': [target, sliceItem],
    }
    print("MADE:", ast.dump(node, True, False), "\n",result,"\n")
    return cursor, result

In [164]:
'''
slice = Slice(expr? lower, expr? upper, expr? step)
          | ExtSlice(slice* dims)
          | Index(expr value)
'''
def visitIndex(node, text, textStart, textEnd):
    """Build the entry for an Index slice (`[value]`).

    Emits the opening '[', any symbols before the value, the value itself,
    any symbols up to ']', and the closing ']'. The entry's start is one
    column before the value's start and its end one column after that start.
    Returns (end, me).
    """
    pieces = [{'syntok': '['}]
    cursor = textStart + 1  # step over the opening bracket
    print("Text is:", text[cursor:textEnd])

    cursor, leading = getPunctuationBetween(text,cursor)
    pieces.extend(leading)

    cursor, inner = visit(node.value, text, cursor, textEnd, None)
    pieces.append(inner)

    cursor, trailing = getPunctuationBetween(text,cursor, ']')
    pieces.extend(trailing)
    pieces.append({'syntok': ']'})
    cursor += 1  # step over the closing bracket

    innerStart = inner['start']
    origin = {'line': innerStart['line'], 'ch': innerStart['ch'] - 1}
    finish = {'line': origin['line'], 'ch': origin['ch'] + 1 }
    result = {'type': type(node).__name__, 'start': origin, 'end': finish, 'content': pieces}
    print("MADE:", result,"\n")
    return cursor, result

In [165]:
'''
 -- keyword arguments supplied to call (NULL identifier for **kwargs)
    keyword = (identifier? arg, expr value)
'''
def visitKeyword(node, text, textStart, textEnd):
    """Build the entry for a keyword argument (`arg=value`).

    Parameters:
        node      -- the keyword node; `node.arg` may be empty (see the
                     grammar note above: NULL identifier for **kwargs).
        text      -- the full source text.
        textStart -- offset where the keyword begins.
        textEnd   -- end offset handed down to the value visitor.

    Returns:
        (end, me): `end` is the offset reached after the value and `me` the
        entry, whose end position is the value's end position.
    """
    print("TEXT IS:", text[textStart:textEnd+1])
    myType = type(node).__name__
    myContent = []
    myStart = posFromText(text, textStart)
    end = textStart
    if node.arg:
        arg = str(node.arg)
        myContent.append({'syntok':arg})
        end += len(arg)
        # Collect every symbol between the name and the value. Scanning used
        # to begin one character late, which lost the '=' in `a=1` and the
        # leading space in `a = 1`.
        end, symbols = getPunctuationBetween(text, end)
        myContent += symbols
    print("KEYword content", myContent, end)
    end, value = visit(node.value, text, end, textEnd, None)
    myContent.append(value)
    myEnd = value['end']
    me = {'type': myType, 'start': myStart, 'end': myEnd, 'content': myContent}
    print("MADE:", me,"\n")
    return end, me
    

In [166]:
def visit(node, text, textStart, textEnd, nextNode):
    """Dispatch an AST node to the visitor that builds its tree entry.

    Parameters:
        node      -- the AST node to convert.
        text      -- the full source text.
        textStart -- offset in `text` where the node begins.
        textEnd   -- end offset passed through to the chosen visitor.
        nextNode  -- passed through to the generic fallback only.

    Returns:
        Whatever the selected visitor returns: an (end, entry) pair.
    """
    # 1. first, figure out if we're dealing with a literal or parent
    children = list(ast.iter_child_nodes(node))

    # necessary to filter the children, since there are some metalabels
    # like Store or Load we don't care about here. The filter is inlined so
    # this cell does not rely on a helper defined in some other cell.
    ignored = ("Store", "Load")
    child = next((c for c in children if type(c).__name__ not in ignored), None)
    # slice rather than index so tracing cannot fail at the very end of text
    print("\n",type(node).__name__, children, textStart, text[textStart:textStart+1])

    if child is None: # LITERAL
        return visitLiteral(node, text, textStart)

    myType = type(node).__name__

    visitors = {"Module": visitModule,
                "Interactive": visitModule,
                "Expression": visitModule,
                "Suite": visitModule,
                "Assign": visitAssign,
                "Attribute": visitAttribute,
                "Subscript": visitSubscript,
                "Index": visitIndex,
                "keyword": visitKeyword,
                "Call": visitCall}

    handler = visitors.get(myType)
    if handler is not None:
        return handler(node, text, textStart, textEnd)

    print("NO VISITOR FOR "+myType)
    # NOTE(review): genericVisit is not defined in any cell shown in this
    # notebook; until it is provided, reaching this line raises NameError.
    return genericVisit(node, text, textStart, textEnd, nextNode)
    
    
def visitLiteral(node, text, start):
    """Build the entry for a leaf node.

    The literal's extent is determined from the text beginning at `start`.
    A plain-string result is stored under 'literal'; anything else (a list
    of syntok entries) is stored under 'content'. Returns (end, entry) with
    `end` expressed as an offset into `text`.
    """
    origin = {'line': node.lineno, 'ch': node.col_offset}
    offset, finish, found = findLiteralEnd(text[start:], origin)

    entry = {'type': type(node).__name__, 'start': origin, 'end': finish}
    # a str is the literal text itself; otherwise it is a list of syntok
    payloadKey = 'literal' if isinstance(found, str) else 'content'
    entry[payloadKey] = found
    return (offset + start, entry)

In [172]:
def getPunctuationBetween(text, textStart, stopChar = "", allowNewline = False):
    """Collect the run of punctuation characters starting at `textStart`.

    Parameters:
        text         -- the source text.
        textStart    -- offset at which to begin scanning.
        stopChar     -- scanning stops in front of this character
                        (the default "" never matches).
        allowNewline -- when true, newline characters are collected as well.

    Returns:
        (i, content): `i` is the offset of the first character that was not
        consumed (len(text) if the run reaches the end) and `content` is a
        list of {'syntok': ...} entries, one per character, except that a
        '#' comment is captured as a single entry.
    """
    textEnd = len(text)
    extra = newline if allowNewline else ()
    content = []
    i = textStart
    # Bounds are checked before every read, so starting at, or running into,
    # the end of the text yields a normal result instead of an IndexError.
    while i < textEnd:
        char = text[i]
        if char == stopChar or not (char in punctuation or char in extra):
            break
        if char == "#":
            new_i, comment = captureComment(text, i, textEnd)
            content.append({'syntok': str(comment)})
            # always move forward, even if the capture consumed nothing
            i = max(new_i, i + 1)
        else:
            content.append({"syntok": str(char)})
            i += 1

    return i, content


In [173]:
def parse(text):
    """Parse `text`, print the AST dump, then print the visitor's result."""
    tree = ast.parse(text)
    print(ast.dump(tree, True, True))
    outcome = visit(tree, text, 0, len(text), None)
    print("\n",outcome)

In [174]:
# Character classes consulted by the scanning helpers defined above.
sqParens = {"[", "]"}
parens = {"(", ")"}
brackets = {"{", "}"}
spaces = {"\t", " "}
newline = {"\n"}  # todo may vary across platforms
punctuation = set(string.punctuation) | {" "}  # a plain space counts as punctuation too

# Sample source used to exercise the parser.
text = """# TODO: Select three indices of your choice you wish to sample from the dataset
indices = []

# Create a DataFrame of the chosen samples
samples = pd.DataFrame(data.loc[indices], columns = data.keys()).reset_index(drop = True)
print("Chosen samples of wholesale customers dataset:")
display(samples)"""


# to hurry up, reduce ast at this stage?
# match parens [] {} () otherwise those can end up in weird places

parse(text)
#print(json.dumps(main(l, tree),  indent=2))

Module(body=[Assign(targets=[Name(id='indices', ctx=Store(), lineno=2, col_offset=0)], value=List(elts=[], ctx=Load(), lineno=2, col_offset=10), lineno=2, col_offset=0), Assign(targets=[Name(id='samples', ctx=Store(), lineno=5, col_offset=0)], value=Call(func=Attribute(value=Call(func=Attribute(value=Name(id='pd', ctx=Load(), lineno=5, col_offset=10), attr='DataFrame', ctx=Load(), lineno=5, col_offset=10), args=[Subscript(value=Attribute(value=Name(id='data', ctx=Load(), lineno=5, col_offset=23), attr='loc', ctx=Load(), lineno=5, col_offset=23), slice=Index(value=Name(id='indices', ctx=Load(), lineno=5, col_offset=32)), ctx=Load(), lineno=5, col_offset=23)], keywords=[keyword(arg='columns', value=Call(func=Attribute(value=Name(id='data', ctx=Load(), lineno=5, col_offset=52), attr='keys', ctx=Load(), lineno=5, col_offset=52), args=[], keywords=[], lineno=5, col_offset=52))], lineno=5, col_offset=10), attr='reset_index', ctx=Load(), lineno=5, col_offset=10), args=[], keywords=[keyword(ar

NameError: name 'genericVisit' is not defined