In [114]:
import lark

grammar = """
start:      expression

expression: term   | term "+" term     -> add | term "-" term     -> sub
term:       factor | factor "*" factor -> mul | factor "/" factor -> div
factor:     power  | "+" factor        -> pos | "-" factor        -> neg
power:      call ["**" factor]
call:       atom   | call trailer
atom:       "(" expression ")" | CNAME -> symbol | NUMBER -> literal

trailer:    "(" arglist ")"
arglist:    expression ("," expression)*

%import common.CNAME
%import common.NUMBER
%import common.WS

%ignore WS
"""

parser = lark.Lark(grammar)


In [115]:
print(parser.parse("2 + 2").pretty())

start
  add
    term
      factor
        power
          call
            literal	2
    term
      factor
        power
          call
            literal	2



In [52]:
class AST:
    _fields = ()
    def __init__(self, *args, line=None):
        for n, x in zip(self._fields, args):
            setattr(self, n, x)
        self.line = line
    def __repr__(self):
        return "{0}({1})".format(type(self).__name__, ", ".join(repr(getattr(self, n)) for n in self._fields))

class Literal(AST):
    _fields = ("value",)
    def __str__(self):
        return str(self.value)

class Symbol(AST):
    _fields = ("symbol",)
    def __str__(self):
        return self.symbol

class Call(AST):
    _fields = ("function", "arguments")
    def __str__(self):
        return "{0}({1})".format(str(self.function), ", ".join(str(x) for x in self.arguments))

In [122]:
def toast(ptnode):
    if ptnode.data in ("add", "sub", "mul", "div", "pos", "neg"):
        return Call(Symbol(ptnode.data), [toast(x) for x in ptnode.children])
    elif ptnode.data == "power" and len(ptnode.children) == 2:
        return Call(Symbol("pow"), [toast(ptnode.children[0]), toast(ptnode.children[1])])
    elif ptnode.data == "call" and len(ptnode.children) == 2:
        return Call(toast(ptnode.children[0]), toast(ptnode.children[1]))
    elif ptnode.data == "symbol":
        return Symbol(ptnode.children[0])
    elif ptnode.data == "literal":
        return Literal(float(ptnode.children[0]))
    elif ptnode.data == "arglist":
        return [toast(x) for x in ptnode.children]
    else:
        return toast(ptnode.children[0])    # many other cases, all of them simple pass-throughs

print(toast(parser.parse("2 + 2")))


add(2.0, 2.0)


In [128]:
class SymbolTable:
    def __init__(self, parent=None, **symbols):
        self.parent = parent
        self.symbols = symbols

    def __getitem__(self, symbol):
        if symbol in self.symbols:
            return self.symbols[symbol]
        elif self.parent is not None:
            return self.parent[symbol]
        else:
            raise KeyError(symbol)

    def __setitem__(self, symbol, value):
        self.symbols[symbol] = value

builtins = SymbolTable()
builtins["add"] = lambda x, y: x + y
builtins["sub"] = lambda x, y: x - y
builtins["mul"] = lambda x, y: x * y
builtins["div"] = lambda x, y: x / y
builtins["pos"] = lambda x: x
builtins["neg"] = lambda x: -x
builtins["pow"] = lambda x, y: x**y


In [139]:
def interpreter(astnode, symboltable):
    if isinstance(astnode, Literal):
        return astnode.value
    elif isinstance(astnode, Symbol):
        return symboltable[astnode.symbol]
    elif isinstance(astnode, Call):
        return interpreter(astnode.function, symboltable)(*[interpreter(x, symboltable) for x in astnode.arguments])

interpreter(toast(parser.parse("2 + 2")), SymbolTable(builtins))


4.0