In [1]:
import ast
from structlog import get_logger
import pandas as pd
import numpy as np

In [5]:
datas = pd.DataFrame(np.random.randint(0, 10, size=(10, 2)), columns=[f'col_{n+1}' for n in range(2)])

In [6]:
logger = get_logger()

In [7]:
import operator

In [None]:
def eval_expression(node: ast.AST):
    log = logger.bind(func='eval expression')
    log.msg(node)
    return eval_node(node.body)

def eval_constant(node: ast.AST):
    log = logger.bind(func='eval constant')
    log.msg(node)
    return node.value

def eval_name(node: ast.AST):
    log = logger.bind(func='eval name')
    log.msg(node)
    if node.id in datas.columns:
        return datas[node.id]    

def eval_binop(node: ast.AST):
    log = logger.bind(func='eval binop')
    log.msg(node)
    left_value = eval_node(node.left)
    right_value = eval_node(node.right)
    apply = BINARY_OPERATIONS[type(node.op)]
    return apply(left_value, right_value)

def eval_unaryop(node: ast.AST):
    log = logger.bind(func='eval unaryop')
    log.msg(node)
    operand_value = eval_node(node.operand, vars)
    apply = UNARY_OPERATIONS[type(node.op)]
    return apply(operand_value)

In [None]:
def eval_node(node: ast.AST):
    log = logger.bind(func='eval node')
    log.msg(node)
    node_type = type(node)
    assert node_type in EVALUATORS.keys()
    
    func = EVALUATORS[node_type]
    return func(node)

In [228]:
from pydantic import BaseModel
from pydantic.validators import Callable, Any
import json

In [229]:
class Node(BaseModel):
    def __repr__(self):
        return f'<{self.__class__.__name__}>'
    
    def eval(self, *args):
        raise NotImplementedError

class TerminaryNode(Node):
    value: Any
        
    def __repr__(self):
        return f'{super().__repr__()} value={self.value}'
    
class UnaryNode(Node):
    value: Node
    func: Callable
        
    def eval(self, datas: pd.DataFrame):
        return self.func(self.value.eval(datas))

class BinaryNode(Node):
    left: Node
    right: Node
    func: Callable
        
    def eval(self, datas: pd.DataFrame):
        return self.func(self.left.eval(datas), self.right.eval(datas))
        
    def __repr__(self):
        return (f"""{super().__repr__()} {self.func}"""
                f"""\n\tleft={self.left.__repr__()}\n\tright={self.right.__repr__()}""")
        
class ConstantNode(TerminaryNode):
    def eval(self, datas: pd.DataFrame):
        return self.value
        
class VariableNode(TerminaryNode):
    def eval(self, datas: pd.DataFrame):
        return datas[self.value]

In [270]:
import math
import numpy as np

In [271]:
np.sin(datas['col_1'])

0    0.909297
1    0.000000
2    0.412118
3    0.841471
4    0.989358
5   -0.756802
6    0.909297
7    0.412118
8    0.412118
9   -0.756802
Name: col_1, dtype: float64

In [242]:
from typing import Optional


def eval_node(ast_node: ast.AST):
    log = logger.bind(func='eval node')
    log.msg(ast_node)
    func = EVALUATORS[type(ast_node)]
    return func(ast_node)
    
def eval_binop(ast_node: ast.AST):
    log = logger.bind(func='eval binop')
    log.msg(ast.dump(ast_node))
    op=BINARY_OPERATIONS[type(ast_node.op)]
    return BinaryNode(left=eval_node(ast_node.left), right=eval_node(ast_node.right), func=op)
    
def eval_constant(ast_node: ast.AST):
    log = logger.bind(func='eval constant')
    log.msg(ast.dump(ast_node))
    return ConstantNode(value=ast_node.value)

def eval_name(ast_node: ast.AST):
    log = logger.bind(func='eval name')
    log.msg(ast.dump(ast_node))
    return VariableNode(value=ast_node.id)
    
def eval_unaryop(ast_node: ast.AST):
    log = logger.bind(func='eval unary')
    log.msg(ast.dump(ast_node))
    op=UNARY_OPERATIONS[type(ast_node.op)]
    return UnaryNode(func=op, value=eval_node(ast_node.operand))

def build_tree(ast_node: ast.AST):
    log = logger.bind(func='build tree')
    node_value = ast_node.body if type(ast_node) == ast.Expression else ast_node
    return eval_node(node_value)
        

In [263]:
EVALUATORS = {
        # ast.Expression: eval_expression,
        ast.Constant: eval_constant,
        ast.Name: eval_name,
        ast.BinOp: eval_binop,
        ast.UnaryOp: eval_unaryop,
    }

BINARY_OPERATIONS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow
}

UNARY_OPERATIONS = {
    ast.USub: operator.neg,
    ast.UAdd: lambda x: x
}

In [264]:
# rule = "(col_1 + 1)**2 - (2 * 3 / -4) + (col_2 * 4) + col_1 ** 2 - col_2 / 4 + col_1**(+0.4)"
# rule = "(col_1 + 2)**2 - col_2"
ast_node = ast.parse(rule, '<string>', mode='eval')

In [265]:
node = build_tree(ast_node)

2022-10-30 21:17.15 [info     ] <ast.BinOp object at 0x0000017B94BA01F0> func=eval node
2022-10-30 21:17.15 [info     ] BinOp(left=BinOp(left=BinOp(left=BinOp(left=BinOp(left=BinOp(left=BinOp(left=Name(id='col_1', ctx=Load()), op=Add(), right=Constant(value=1)), op=Pow(), right=Constant(value=2)), op=Sub(), right=BinOp(left=BinOp(left=Constant(value=2), op=Mult(), right=Constant(value=3)), op=Div(), right=UnaryOp(op=USub(), operand=Constant(value=4)))), op=Add(), right=BinOp(left=Name(id='col_2', ctx=Load()), op=Mult(), right=Constant(value=4))), op=Add(), right=BinOp(left=Name(id='col_1', ctx=Load()), op=Pow(), right=Constant(value=2))), op=Sub(), right=BinOp(left=Name(id='col_2', ctx=Load()), op=Div(), right=Constant(value=4))), op=Add(), right=BinOp(left=Name(id='col_1', ctx=Load()), op=Pow(), right=UnaryOp(op=UAdd(), operand=Constant(value=0.4)))) func=eval binop
2022-10-30 21:17.15 [info     ] <ast.BinOp object at 0x0000017B94BA0E50> func=eval node
2022-10-30 21:17.15 [info     ] 

In [266]:
node.eval(datas)

0     23.319508
1     17.500000
2    196.158225
3     30.000000
4    148.797397
5     47.991101
6     27.069508
7    199.908225
8    211.158225
9     44.241101
dtype: float64

In [267]:
node

<BinaryNode> <built-in function add>
	left=<BinaryNode> <built-in function sub>
	left=<BinaryNode> <built-in function add>
	left=<BinaryNode> <built-in function add>
	left=<BinaryNode> <built-in function sub>
	left=<BinaryNode> <built-in function pow>
	left=<BinaryNode> <built-in function add>
	left=<VariableNode> value=col_1
	right=<ConstantNode> value=1
	right=<ConstantNode> value=2
	right=<BinaryNode> <built-in function truediv>
	left=<BinaryNode> <built-in function mul>
	left=<ConstantNode> value=2
	right=<ConstantNode> value=3
	right=<UnaryNode>
	right=<BinaryNode> <built-in function mul>
	left=<VariableNode> value=col_2
	right=<ConstantNode> value=4
	right=<BinaryNode> <built-in function pow>
	left=<VariableNode> value=col_1
	right=<ConstantNode> value=2
	right=<BinaryNode> <built-in function truediv>
	left=<VariableNode> value=col_2
	right=<ConstantNode> value=4
	right=<BinaryNode> <built-in function pow>
	left=<VariableNode> value=col_1
	right=<UnaryNode>

In [233]:
ast.dump(ast_node)

"Expression(body=UnaryOp(op=USub(), operand=BinOp(left=BinOp(left=Name(id='col_1', ctx=Load()), op=Add(), right=Constant(value=2)), op=Pow(), right=Constant(value=2))))"

In [65]:
node.body.left.id

'col_1'