Skip to content

Commit

Permalink
Merge 3ae5747 into 3c07c25
Browse files Browse the repository at this point in the history
  • Loading branch information
thesamovar committed Jun 22, 2013
2 parents 3c07c25 + 3ae5747 commit 375a1ea
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 0 deletions.
162 changes: 162 additions & 0 deletions brian2/codegen/ast_parser.py
@@ -0,0 +1,162 @@
import ast

__all__ = ['NodeRenderer',
'NumpyNodeRenderer',
'CPPNodeRenderer',
]

class NodeRenderer(object):
expression_ops = {
# BinOp
'Add': '+',
'Sub': '-',
'Mult': '*',
'Div': '/',
'Pow': '**',
'Mod': '%',
# Compare
'Lt': '<',
'LtE': '<=',
'Gt': '>',
'GtE': '>=',
'Eq': '==',
'NotEq': '!=',
# Unary ops
'Not': 'not',
'Invert': '~',
'UAdd': '+',
'USub': '-',
# Bool ops
'And': 'and',
'Or': 'or',
}

def render_expr(self, expr):
node = ast.parse(expr, mode='eval')
return self.render_node(node.body)

def render_code(self, code):
lines = []
for node in ast.parse(code).body:
lines.append(self.render_node(node))
return '\n'.join(lines)

def render_node(self, node):
nodename = node.__class__.__name__
methname = 'render_'+nodename
if not hasattr(self, methname):
raise SyntaxError("Unknown syntax: "+nodename)
return getattr(self, methname)(node)

def render_Name(self, node):
return node.id

def render_Num(self, node):
return repr(node.n)

def render_Call(self, node):
if len(node.keywords):
raise ValueError("Keyword arguments not supported.")
elif node.starargs is not None:
raise ValueError("*args not supported")
elif node.kwargs is not None:
raise ValueError("**kwds not supported")
return '%s(%s)' % (self.render_node(node.func),
', '.join(self.render_node(arg) for arg in node.args))

def render_BinOp_parentheses(self, left, right, op):
# This function checks whether or not you can ommit parentheses assuming Python
# precedence relations, hopefully this is the same in C++ and Java, but we'll need
# to check it
exprs = ['%s %s %s', '(%s) %s %s', '%s %s (%s)', '(%s) %s (%s)']
nr = NodeRenderer()
L = nr.render_node(left)
R = nr.render_node(right)
O = NodeRenderer.expression_ops[op.__class__.__name__]
refexpr = '(%s) %s (%s)' % (L, O, R)
refexprdump = ast.dump(ast.parse(refexpr))
for expr in exprs:
e = expr % (L, O, R)
if ast.dump(ast.parse(e))==refexprdump:
return expr % (self.render_node(left),
self.expression_ops[op.__class__.__name__],
self.render_node(right),
)

def render_BinOp(self, node):
return self.render_BinOp_parentheses(node.left, node.right, node.op)

def render_BoolOp(self, node):
# TODO: for the moment we always parenthesise boolean ops because precedence
# might be different in different languages and it's safer - also because it's
# a bit more complicated to write the parenthesis rule
op = node.op
left = node.values[0]
remaining = node.values[1:]
while len(remaining):
right = remaining[0]
remaining = remaining[1:]
s = self.render_BinOp_parentheses(left, right, op)
op = self.expression_ops[node.op.__class__.__name__]
return (' '+op+' ').join('(%s)' % self.render_node(v) for v in node.values)

def render_Compare(self, node):
if len(node.comparators)>1:
raise SyntaxError("Can only handle single comparisons like a<b not a<b<c")
return self.render_BinOp_parentheses(node.left, node.comparators[0], node.ops[0])

def render_UnaryOp(self, node):
return '%s(%s)' % (self.expression_ops[node.op.__class__.__name__],
self.render_node(node.operand))

def render_Assign(self, node):
if len(node.targets)>1:
raise SyntaxError("Only support syntax like a=b not a=b=c")
return '%s = %s' % (self.render_node(node.targets[0]),
self.render_node(node.value))


class NumpyNodeRenderer(NodeRenderer):
expression_ops = NodeRenderer.expression_ops.copy()
expression_ops.update({
# Unary ops
'Not': 'logical_not',
'Invert': 'logical_not',
# Bool ops
'And': '*',
'Or': '+',
})


class CPPNodeRenderer(NodeRenderer):
expression_ops = NodeRenderer.expression_ops.copy()
expression_ops.update({
# Unary ops
'Not': '!',
'Invert': '!',
# Bool ops
'And': '&&',
'Or': '||',
})

def render_BinOp(self, node):
if node.op.__class__.__name__=='Pow':
return 'pow(%s, %s)' % (self.render_node(node.left),
self.render_node(node.right))
else:
return NodeRenderer.render_BinOp(self, node)

def render_Assign(self, node):
return NodeRenderer.render_Assign(self, node)+';'


if __name__=='__main__':
# print precedence(ast.parse('c(d)**2 and 3', mode='eval').body)
# print NodeRenderer().render_expr('a-(b-c)+d')
# print NodeRenderer().render_expr('a and b or c')
for renderer in [NodeRenderer(), NumpyNodeRenderer(), CPPNodeRenderer()]:
name = renderer.__class__.__name__
print name+'\n'+'='*len(name)
print renderer.render_expr('a+b*c(d, e)+e**f')
print renderer.render_expr('a and -b and c and 1.2')
print renderer.render_code('a=b\nc=d+e')
99 changes: 99 additions & 0 deletions brian2/tests/test_syntax_translation.py
@@ -0,0 +1,99 @@
'''
Tests the brian2.codegen.syntax package
'''
from brian2.utils.stringtools import get_identifiers
from brian2.codegen.ast_parser import (NodeRenderer, NumpyNodeRenderer,
CPPNodeRenderer,
)
from numpy.testing import assert_raises, assert_equal
from numpy.random import rand, randint
import numpy as np
from scipy import weave
import nose

def generate_expressions(N=100, numvars=5, numfloats=1, numints=1, complexity=5, depth=3):
ops = ['+', '*', '-', '/', '**']
vars = [chr(ord('a')+i) for i in xrange(numvars)]
consts = [rand() for _ in xrange(numfloats)]+range(1, 1+numints)
varsconsts = [str(x) for x in vars+consts]
for _ in xrange(N):
expr = 'a'
for _ in xrange(depth):
s = 'a'
for _ in xrange(complexity):
op = ops[randint(len(ops))]
var = vars[randint(numvars)]
s = s+op+var
op = ops[randint(len(ops))]
expr = '(%s)%s(%s)'%(expr, op, s)
yield (vars, [], expr)


def parse_expressions(renderer, evaluator, numvalues=10):
exprs = list(generate_expressions())
additional_exprs = '''
a<b
a<=b
a>b
a>=b
a==b
a!=b
a+1
1+a
a%2
a%2.0
1+3
a>1 and b>1
'''
exprs = exprs+[('abc', [], l.strip()) for l in additional_exprs.split('\n') if l.strip()]
for varids, funcids, expr in exprs:
pexpr = renderer.render_expr(expr)
n = 0
for _ in xrange(numvalues):
# assign some random values
ns = dict((v, rand()) for v in varids)
try:
r1 = eval(expr, ns)
except (ZeroDivisionError, ValueError, OverflowError):
continue
n += 1
r2 = evaluator(pexpr, ns)
assert_equal(r1, r2)


def numpy_evaluator(expr, ns):
ns = ns.copy()
for k in ns.keys():
if not k.startswith('_'):
ns[k] = np.array([ns[k]])
x = eval(expr, ns)
if isinstance(x, np.ndarray):
return x[0]
else:
return x


def cpp_evaluator(expr, ns):
return weave.inline('return_val = %s;' % expr, ns.keys(), local_dict=ns,
compiler='gcc')


def test_parse_expressions_python():
parse_expressions(NodeRenderer(), eval)


def test_parse_expressions_numpy():
parse_expressions(NumpyNodeRenderer(), numpy_evaluator)


def test_parse_expressions_cpp():
# Skipy this test because we haven't handled e.g. 1.2%2 or 2%1.3 yet
raise nose.SkipTest()
parse_expressions(CPPNodeRenderer(), cpp_evaluator)


if __name__=='__main__':
test_parse_expressions_python()
test_parse_expressions_numpy()
test_parse_expressions_cpp()

0 comments on commit 375a1ea

Please sign in to comment.