In [13]:
import ast
import re

from geometry import SO2, SO3, SE2, SE3
import numpy as np

In [22]:
# Define mappings from Python to C++

# Qualified Python function name -> qualified C++ function name.
cpp_function_map = {
    "mylib.add": "mylib::add",
    "mylib.multiply": "mylib::multiply"
}

# (left operand type, right operand type, ast operator class) -> (C++ operator, result type).
# NOTE(review): "RET" looks like a placeholder result type — confirm the intended value.
cpp_operator_map = {
    ("mylib.Number", "mylib.Number", ast.Add): ("operator+", "RET"),
    ("mylib.Number", "float", ast.Add): ("operator+", "RET"),
    ("float", "mylib.Number", ast.Add): ("operator+", "RET"),
    ("mylib.Number", "mylib.Number", ast.Mult): ("operator*", "RET"),
    ("float", "mylib.Number", ast.Mult): ("operator*", "RET")
}

# TODO(review): an npndarrayToEigen(...) helper is not written yet — presumably it
# should convert np.ndarray values to an Eigen type; confirm intent before adding it.

# Python type annotation -> C++ type. The np.ndarray target is still an empty placeholder.
cpp_type_map = {
    "np.ndarray": ""
}

# class Operation(object):
#     TYPE_ASSIGN = 0
#     TYPE_MUTATION = 1
#     TYPE_RETURN = 2
#     def __init__(self):
#         self.rendered = None
#         self.type = None
#         self.left = None
#         self.right = None

class FunctionTransformer(ast.NodeVisitor):
    """Translate annotated Python function definitions into C++-like source fragments.

    Each ``def`` reached by the visitor is recorded in ``self.functions`` as
    ``{"name", "args", "body", "return_type"}``. Defs nested inside a visited
    function are not descended into. Expressions are rendered with the module-level
    ``cpp_operator_map`` (typed binary operators) and ``cpp_function_map``
    (qualified function names). Constructs the translator does not understand are
    rendered as ``UNKNOWN_*`` placeholders or ``// UNSUPPORTED`` comments, so they
    stay visible in the output instead of being silently dropped.
    """

    # Pulls the symbol out of names like "operator+" so the op can be rendered infix.
    _OPERATOR_RE = re.compile(r"operator(.*)")

    def __init__(self):
        # One record per visited function, in visiting order.
        self.functions = []
        # Variable name -> annotation string, for the function currently being parsed.
        self.variables = {}

    def visit_FunctionDef(self, node):
        """Record the name, typed arguments, translated body and return type of ``node``."""
        func_name = node.name
        args = [(arg.arg, self.get_annotation(arg.annotation)) for arg in node.args.args]
        # Fresh scope per function: initially only the parameters' types are known.
        self.variables = dict(args)
        return_type = self.get_annotation(node.returns)
        body = self.parse_body(node.body)
        self.functions.append({"name": func_name, "args": args, "body": body, "return_type": return_type})

    @staticmethod
    def _dotted_name(node):
        """Return ``"a.b.c"`` for a Name/Attribute chain, or None if ``node`` is not one."""
        parts = []
        while isinstance(node, ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            return None
        parts.append(node.id)
        return ".".join(reversed(parts))

    def get_annotation(self, annotation):
        """Return a type annotation as a string.

        Handles bare names (``float``), dotted names of any depth (``mylib.Number``,
        ``pkg.sub.T``) and string annotations (``"mylib.Number"``). An
        ``ast.Constant`` value is returned as-is, so ``-> None`` yields None.
        A missing or unsupported annotation yields "unknown".
        """
        if isinstance(annotation, (ast.Name, ast.Attribute)):
            dotted = self._dotted_name(annotation)
            return dotted if dotted is not None else "unknown"
        if isinstance(annotation, ast.Constant):
            return annotation.value
        return "unknown"

    def parse_body(self, body):
        """Translate a list of statements into C++ lines, joined with 4-space indentation.

        Supported statements:

        - ``return`` / ``return expr``
        - annotated assignment ``x: T = expr``, which also records ``x``'s type
        - plain assignment to simple names, including chained ``a = b = expr``

        Any other statement is emitted as a ``// UNSUPPORTED: <NodeType>`` line.
        """
        lines = []
        for stmt in body:
            if isinstance(stmt, ast.Return):
                if stmt.value is None:
                    lines.append("return;")
                else:
                    lines.append("return " + self.parse_expr(stmt.value)[0] + ";")
            elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
                self.variables[stmt.target.id] = self.get_annotation(stmt.annotation)
                # A bare "x: T" only declares the type (no runtime effect), so emit nothing.
                if stmt.value is not None:
                    lines.append(stmt.target.id + " = " + self.parse_expr(stmt.value)[0] + ";")
            elif isinstance(stmt, ast.Assign) and all(isinstance(t, ast.Name) for t in stmt.targets):
                targets = " = ".join(t.id for t in stmt.targets)
                lines.append(targets + " = " + self.parse_expr(stmt.value)[0] + ";")
            else:
                lines.append(f"// UNSUPPORTED: {type(stmt).__name__}")
        return "\n    ".join(lines)

    def parse_expr(self, node):
        """Translate an expression node into ``(cpp_source, result_type)``.

        ``result_type`` is the operator map's output type for binary operations,
        and None for every other node. A binary operation with no
        ``cpp_operator_map`` entry renders as ``left.UNKNOWN_OP(right)``. Any other
        unsupported node renders as ``UNKNOWN_EXPR``.
        """
        if isinstance(node, ast.BinOp):  # a + b, a * b, ...
            left = self.parse_expr(node.left)[0]
            right = self.parse_expr(node.right)[0]
            key = (self.get_expr_type(node.left), self.get_expr_type(node.right), type(node.op))
            # The default must also be an (op, type) pair so the unpacking cannot fail.
            cpp_op, out_type = cpp_operator_map.get(key, ("UNKNOWN_OP", "unknown"))
            opmatch = self._OPERATOR_RE.search(cpp_op)
            if opmatch:
                # "operator+" -> infix "(left + right)"
                return f"({left} {opmatch.group(1)} {right})", out_type
            # Anything else is rendered as a method call on the left operand.
            return f"{left}.{cpp_op}({right})", out_type
        if isinstance(node, ast.Call):
            func_name = self.get_func_name(node.func)
            args = ", ".join(self.parse_expr(arg)[0] for arg in node.args)
            return f"{cpp_function_map.get(func_name, func_name)}({args})", None
        if isinstance(node, ast.Name):
            return node.id, None
        if isinstance(node, ast.Constant):
            if isinstance(node.value, bool):
                # Python's True/False are spelled true/false in C++.
                return ("true" if node.value else "false"), None
            return str(node.value), None
        return "UNKNOWN_EXPR", None

    def get_func_name(self, node):
        """Return the dotted name of a called function (e.g. ``mylib.add``), else "UNKNOWN_FUNC"."""
        dotted = self._dotted_name(node)
        return dotted if dotted is not None else "UNKNOWN_FUNC"

    def get_expr_type(self, node):
        """Best-effort Python type of an operand, used as part of a ``cpp_operator_map`` key.

        Names are looked up in the current function's known variables. Booleans
        are typed "bool", and other numeric constants (ints included) "float".
        Everything else is "unknown".
        """
        if isinstance(node, ast.Name):
            return self.variables.get(node.id, "unknown")
        if isinstance(node, ast.Constant):
            value = node.value
            if isinstance(value, bool):
                return "bool"
            return "float" if isinstance(value, (int, float)) else "unknown"
        return "unknown"

# Sample input: a function written against the geometry types.
# It is only parsed from a string here, so SO3 and np never need to resolve.
python_code = """
def quat_functor(q1: SO3, q2: SO3) -> SO3:
    v: np.ndarray = q2 - q1
    return q1 + v
"""

# Walk the parsed module; each def reached ends up in transformer.functions.
transformer = FunctionTransformer()
tree = ast.parse(python_code)
transformer.visit(tree)

print(transformer.functions)

[{'name': 'quat_functor', 'args': [('q1', 'SO3'), ('q2', 'SO3')], 'body': 'v = q2.UNKNOWN_OP(q1);\n    return q1.UNKNOWN_OP(v);', 'return_type': 'SO3'}]


In [7]:
def quat_functor(q1: SO3, q2: SO3) -> np.ndarray:
    """Return ``q2 - q1`` for two SO3 elements.

    NOTE(review): the np.ndarray return annotation suggests SO3 subtraction yields
    a tangent-space vector (relative rotation). Confirm against the geometry
    package's SO3 implementation.
    """
    difference = q2 - q1
    return difference