In [299]:
#importing the necessary libraries and setup
import google.generativeai as genai
import sys
import re
import uuid
from textwrap import dedent
from typing import Dict
from typing import List as pyList
from typing import Optional
from typing import Tuple as pyTuple
from typing import Type, Union, cast, get_args

from mypy import build
from mypy.defaults import PYTHON3_VERSION
from mypy.modulefinder import BuildSource
from mypy.nodes import (
    AssignmentStmt,
    Block,
    CallExpr,
    ComparisonExpr,
    ConditionalExpr,
    FuncDef,
    IndexExpr,
    IntExpr,
    LambdaExpr,
    ListExpr,
    MemberExpr,
    MypyFile,
    NameExpr,
    Node,
    OpExpr,
    ReturnStmt,
    SliceExpr,
)
from mypy.options import Options
from mypy.types import AnyType, CallableType, Instance
from mypy.types import Type as MypyType
from mypy.types import TypeList, UnboundType

from metalift.ir import (
    Add,
    And,
    Bool,
    Div,
    Eq,
    Expr,
    Fn,
    FnDeclRecursive,
    Ge,
    Gt,
    Int,
    Le,
    List,
    Lt,
    Mod,
    Mul,
    Object,
    ObjectT,
    Or,
    Sub,
    Matrix,
    call,
    create_object,
    fn_decl_recursive,
    is_fn_decl_type,
    is_list_type,
    is_nested_list_type,
    ite,
    Not
)

from dotenv import load_dotenv
load_dotenv()
import os
genai.configure(api_key=os.getenv('GEMINI_API_KEY'))

#Add project files to be imported
script_dir = os.getcwd()

project_root = os.path.join(script_dir, '..')

sys.path.append(os.path.abspath(project_root))

In [410]:
# file_path = "../benchmark_transpilation/linear-algebra/kernels/2mm/2mm.c"
file_path = "test.c"
operator_path = "./dsl_operators.py"
file_content = ""
python_operators = ""
with open(operator_path, "r") as file:
    # Read the file content into a string
    python_operators = file.read()
# Open the file using the path
with open(file_path, "r") as file:
    # Read the file content into a string
    file_content = file.read()

# Now `file_content` holds the content of the file as a string
    def remove_c_comments(code: str) -> str:
        # Remove single-line comments
        code = re.sub(r'//.*?\n', '', code)
        # Remove multi-line comments
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
        return code

def prepare_prompt(c_code: str) -> str:
    prompt = f"""
Task: Your task is to rewrite the given C benchmark code in Python. You need to use only the set of 
provided functions and constants to achieve this. The rewritten program should be semantically 
equivalent to the benchmark.
#Instructions
1. The rewritten program should just be a single return statement of the form return_var = provided_function()
2. Do not use for/while loops for rewriting the function
3. Inline all the expressions, Do not use intermediate variables
4. The rewritten program should only use built-in Python functions and the provided functions and constants

----------------------------Defined Functions---------------------------- 
{python_operators}


----------------------------Function to Rewrite----------------------------
{c_code}
"""
    return prompt

print(prepare_prompt(remove_c_comments(file_content)))


Task: Your task is to rewrite the given C benchmark code in Python. You need to use only the set of 
provided functions and constants to achieve this. The rewritten program should be semantically 
equivalent to the benchmark.
#Instructions
1. The rewritten program should just be a single return statement of the form return_var = provided_function()
2. Do not use for/while loops for rewriting the function
3. Inline all the expressions, Do not use intermediate variables
4. The rewritten program should only use built-in Python functions and the provided functions and constants

----------------------------Defined Functions---------------------------- 
from typing import Callable, List

def matrix_elemwise_add(
    matrix_x: List[List[int]], matrix_y: List[List[int]]
) -> List[List[int]]:
    return (
        []
        if len(matrix_x) < 1 or not len(matrix_x) == len(matrix_y)
        else [
            vec_elemwise_add(matrix_x[0], matrix_y[0]),
            *matrix_elemwise_add(matrix_x

In [433]:
def call_llm(prompt: str) -> str:
    # Call the Gemini API with the given prompt
    model = genai.GenerativeModel("gemini-1.5-pro")
    response = model.generate_content(prompt)
    print(response.text)
    return response.text

call_llm(prepare_prompt(remove_c_comments(file_content)))

```python
from typing import Callable, List

def matrix_elemwise_add(
    matrix_x: List[List[int]], matrix_y: List[List[int]]
) -> List[List[int]]:
    return (
        []
        if len(matrix_x) < 1 or not len(matrix_x) == len(matrix_y)
        else [
            vec_elemwise_add(matrix_x[0], matrix_y[0]),
            *matrix_elemwise_add(matrix_x[1:], matrix_y[1:]),
        ]
    )

def matrix_elemwise_sub(
    matrix_x: List[List[int]], matrix_y: List[List[int]]
) -> List[List[int]]:
    return (
        []
        if len(matrix_x) < 1 or not len(matrix_x) == len(matrix_y)
        else [
            vec_elemwise_sub(matrix_x[0], matrix_y[0]),
            *matrix_elemwise_sub(matrix_x[1:], matrix_y[1:]),
        ]
    )


def matrix_elemwise_mul(
    matrix_x: List[List[int]], matrix_y: List[List[int]]
) -> List[List[int]]:
    return (
        []
        if len(matrix_x) < 1 or not len(matrix_x) == len(matrix_y)
        else [
            vec_elemwise_mul(matrix_x[0], matrix_y[0])

'```python\nfrom typing import Callable, List\n\ndef matrix_elemwise_add(\n    matrix_x: List[List[int]], matrix_y: List[List[int]]\n) -> List[List[int]]:\n    return (\n        []\n        if len(matrix_x) < 1 or not len(matrix_x) == len(matrix_y)\n        else [\n            vec_elemwise_add(matrix_x[0], matrix_y[0]),\n            *matrix_elemwise_add(matrix_x[1:], matrix_y[1:]),\n        ]\n    )\n\ndef matrix_elemwise_sub(\n    matrix_x: List[List[int]], matrix_y: List[List[int]]\n) -> List[List[int]]:\n    return (\n        []\n        if len(matrix_x) < 1 or not len(matrix_x) == len(matrix_y)\n        else [\n            vec_elemwise_sub(matrix_x[0], matrix_y[0]),\n            *matrix_elemwise_sub(matrix_x[1:], matrix_y[1:]),\n        ]\n    )\n\n\ndef matrix_elemwise_mul(\n    matrix_x: List[List[int]], matrix_y: List[List[int]]\n) -> List[List[int]]:\n    return (\n        []\n        if len(matrix_x) < 1 or not len(matrix_x) == len(matrix_y)\n        else [\n            vec_el

In [391]:
def get_ps_sols(n, source_code, incorrect_ps_sols):
    prompt = prepare_prompt(remove_c_comments(source_code))
    ps_sols = []
    for _ in range(n):
        response = call_llm(prompt)
        ps_sol = response.text.strip()
        if ps_sol not in incorrect_ps_sols:
            ps_sols.append(ps_sol)

    return ps_sols
    


In [427]:
parse_path = "./testResult.py"
parse_result = ""
with open(parse_path, "r") as file:
    # Read the file content into a string
    parse_result = file.read()
def remove_comments(python_code):
    # Regex pattern to match docstrings (triple quotes)
    docstring_pattern = r'("""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\')'
    
    # Regex pattern to match single-line comments, ensuring it's not inside a string
    comment_pattern = r'(?<!["\'])#.*'

    # Remove docstrings
    source_no_docstrings = re.sub(docstring_pattern, '', python_code, flags=re.MULTILINE)

    # Remove comments
    source_no_comments = re.sub(comment_pattern, '', source_no_docstrings)

    return source_no_comments

In [435]:
def mypy_type_to_ir_type(mypy_type: Optional[MypyType]) -> Optional[ObjectT]:
    """Convert mypy type to python type"""
    if mypy_type is None:
        # Handle missing type annotation
        # Return None or a default type to represent 'Any'
            return None  
    if isinstance(mypy_type, UnboundType):
        if mypy_type.optional:
            raise Exception("Optional type not supported")
        if mypy_type.name == "int":
            return Int
        if mypy_type.name == "bool":
            return Bool
        if mypy_type.name == "Any":
            # This means that no types are enforced
            return None
        elif mypy_type.name == "List" and len(mypy_type.args) == 1:
            return List[mypy_type_to_ir_type(mypy_type.args[0])]
        elif mypy_type.name == "Callable":
            if len(mypy_type.args) != 2:
                raise Exception("Callable type must have two arguments")
            if not isinstance(mypy_type.args[0], TypeList):
                raise Exception("First argument of Callable type must be a list")
            ret_type = mypy_type_to_ir_type(mypy_type.args[1])
            arg_types = (mypy_type_to_ir_type(arg) for arg in mypy_type.args[0].items)
            return Fn[pyTuple[(ret_type, *arg_types)]]
    elif isinstance(mypy_type, Instance):
        if mypy_type.type.fullname == "builtins.int":
            return Int
        elif mypy_type.type.fullname == "builtins.bool":
            return Bool
        elif mypy_type.type.fullname == "builtins.list":
            return List[mypy_type_to_ir_type(mypy_type.args[0])]
    elif isinstance(mypy_type, CallableType):
        arg_types = (mypy_type_to_ir_type(arg) for arg in mypy_type.arg_types)
        ret_type = mypy_type_to_ir_type(mypy_type.ret_type)
        return Fn[pyTuple[(ret_type, *arg_types)]]
    elif isinstance(mypy_type, AnyType):
        return None
    raise Exception(f"Unsupported type {mypy_type}")

def _get_func_def_ir_type(func_def: FuncDef) -> Type[Fn]:
    arg_types = [
        mypy_type_to_ir_type(arg.type_annotation if arg.type_annotation is not None else None)
        for arg in func_def.arguments
    ]
    # Check if func_def.type and func_def.type.ret_type exist
    ret_type_annotation = func_def.type.ret_type if func_def.type and func_def.type.ret_type else None
    ret_type = mypy_type_to_ir_type(ret_type_annotation)
    return Fn[pyTuple[(ret_type, *arg_types)]]


def _get_func_def_arg_names(func_def: FuncDef) -> pyList[str]:
    return [arg.variable.name for arg in func_def.arguments]

In [436]:
universal_imports = f"""
    from dsl_operators import *
    from typing import Any, Callable, List
    """
full_prog = dedent(remove_comments(dedent(universal_imports) + dedent(parse_result)))
def mypy_parse(full_prog, expected_num_funcs):
    # Parse the given Python code using mypy
    # Return the parsed MypyFile
    options = Options()
    options.incremental = False  # turn off caching of previously typed results
    options.export_types = True
    options.show_traceback = True
    options.python_version = PYTHON3_VERSION
    options.preserve_asts = True
    mypy_build = build.build(
        sources=[
            BuildSource(path=None, module="target_code", text=full_prog),
            BuildSource(path=None, module="dsl_operators")
        ],
        options=options,
    )
    target_tree: MypyFile = cast(MypyFile, mypy_build.graph["target_code"].tree)
    python_dsl_tree: MypyFile = cast(
        MypyFile, mypy_build.graph["dsl_operators"].tree
    )  
    
    target_func_defs = [
        func_def for func_def in target_tree.defs if isinstance(func_def, FuncDef)
    ]

    if len(target_func_defs) != expected_num_funcs:
        raise Exception(
            f"{expected_num_funcs} function definition expected but found {len(target_func_defs)}"
        )
    dsl_func_defs = [
        func_def for func_def in python_dsl_tree.defs if isinstance(func_def, FuncDef)
    ]
    func_sign = {
        func_def.name: (
            _get_func_def_ir_type(func_def),
            _get_func_def_arg_names(func_def),
        )
        for func_def in [*target_func_defs, *dsl_func_defs]
    }
    
    return target_func_defs, func_sign, mypy_build.types
    
x, y, z = mypy_parse(parse_result, 1)
print(x)
# print(y)
# print(z)

[<mypy.nodes.FuncDef object at 0x3855716c0>]


In [437]:
def mypy_node_to_ir(
    node: Node,
    func_sign: Dict[str, pyList[Union[Type, type]]],
    types: Dict[Node, MypyType],
    fn_decls: pyList[FnDeclRecursive],
    in_calls: pyList[pyTuple[str, str]],
) -> Expr:
    def parse_node(node: Node) -> Expr:
        print(node)
        if isinstance(node, FuncDef) or isinstance(node, LambdaExpr):
            if isinstance(node, FuncDef):
                func_ir_type, _ = func_sign[node.name]
            else:
                func_ir_type = mypy_type_to_ir_type(types[node])
            arg_ir_types = func_ir_type.argument_types(get_args(func_ir_type))
            ret_ir_type = func_ir_type.return_type(get_args(func_ir_type))
            variables: pyList[Object] = []
            for arg, ir_type in zip(node.arguments, arg_ir_types):
                variables.append(create_object(ir_type, arg.variable.name))
            return fn_decl_recursive(
                node.name,
                ret_ir_type,
                parse_node(node.body),
                *variables,
            )
        elif isinstance(node, Block):
            if len(node.body) == 2:
                if not isinstance(node.body[0], AssignmentStmt):
                    raise Exception(
                        "If there are two statements, the statement must be an assignment"
                    )
                if not isinstance(node.body[1], ReturnStmt):
                    raise Exception(
                        "If there are two statements, the second statement must be a return statement"
                    )

                first_stmt = cast(AssignmentStmt, node.body[0])
                second_stmt = cast(ReturnStmt, node.body[1])
                if len(first_stmt.lvalues) != 1:
                    raise Exception("Only one lvalue supported")
                if first_stmt.lvalues[0].name != second_stmt.expr.name:
                    raise Exception(
                        "Return variable must be the same as the variable assigned"
                    )

                return parse_node(first_stmt.rvalue)
            elif len(node.body) == 1:
                return parse_node(node.body[0])
            else:
                raise Exception("Only one or two statements supported")
        elif isinstance(node, ReturnStmt):
            return parse_node(node.expr)
        elif isinstance(node, CallExpr):
            if isinstance(node.callee, MemberExpr):
                raise Exception("Method calls not supported")
                # mypy_node_to_ir(node.callee, func_sign, types)
            elif isinstance(node.callee, NameExpr):
                func_name = cast(NameExpr, node.callee).name

                # First we check if this function is a python built-in function
                if func_name == "len":
                    if len(node.args) != 1:
                        raise Exception("len() takes exactly one argument")
                    arg_expr = parse_node(node.args[0])
                    if arg_expr.type.is_nested:
                        list_func_name = "matrix_length"
                    else:
                        list_func_name = "list_length"
                    return call(list_func_name, Int, arg_expr).src
                
                elif func_name == "list":
                    # Handle the 'list' built-in function
                    if len(node.args) == 0:
                        # Create an empty list
                        # Assuming default element type is Int; adjust as needed
                        list_type = List[Int]
                        return List.empty(Int).src
                    elif len(node.args) == 1:
                        # Create a list from an iterable
                        arg_expr = parse_node(node.args[0])
                        if is_list_type(arg_expr.type):
                            # If the argument is already a list, return it
                            return arg_expr.src
                        else:
                            # If the argument is an iterable, handle accordingly
                            # For simplicity, raise an exception
                            raise Exception("Creating a list from non-list iterables is not supported")
                    else:
                        raise Exception("list() takes at most one argument")
                elif func_name == "range":
                # Handle range()
                    num_args = len(node.args)
                    if num_args == 1:
                        # range(stop)
                        stop_expr = parse_node(node.args[0])
                        start_expr = Int(0)
                        step_expr = Int(1)
                    elif num_args == 2:
                        # range(start, stop)
                        start_expr = parse_node(node.args[0])
                        stop_expr = parse_node(node.args[1])
                        step_expr = Int(1)
                    elif num_args == 3:
                        # range(start, stop, step)
                        start_expr = parse_node(node.args[0])
                        stop_expr = parse_node(node.args[1])
                        stop_expr = parse_node(node.args[2])
                    else:
                        raise Exception("range() takes 1 to 3 arguments")

                    return List(call("range_list", List[Int], start_expr, stop_expr, stop_expr))
                
                
                if func_name not in func_sign.keys():
                    raise Exception(f"Unknown function {func_name}")

                func_ir_type, arg_names = func_sign[func_name]
                ret_ir_type = func_ir_type.return_type(get_args(func_ir_type))
                arguments_ir_types = func_ir_type.argument_types(get_args(func_ir_type))

                # Check number of arguments
                if len(node.args) != len(arguments_ir_types):
                    raise Exception(
                        f"Incorrect number of arguments. Required {len(func_sign[func_name])} but got {len(node.args)}"
                    )

                # Check argument types and make argument objects
                arg_exprs: List[Expr] = []
                for idx, (arg, expected_ir_type) in enumerate(
                    zip(node.args, arguments_ir_types)
                ):
                    arg_expr = parse_node(arg)
                    if arg_expr.type != expected_ir_type:
                        raise Exception(
                            f"Expected type {expected_ir_type} but got {arg_expr.type} for {idx}th argument of {func_name}"
                        )

                    # If the argument is a function, then we need to define another function for it
                    if is_fn_decl_type(arg_expr.type):
                        if isinstance(arg, LambdaExpr):
                            arg_fn_name = f"{arg_names[idx]}_{str(uuid.uuid4())[:8]}"
                            arg_expr.set_name(arg_fn_name)
                        elif isinstance(arg, NameExpr):
                            arg_fn_name = cast(NameExpr, arg).name
                        else:
                            raise Exception(
                                "Function argument must be a lambda expression or a function name"
                            )
                        fn_decls.append(arg_expr)
                        in_calls.append((func_name, arg_fn_name))
                    arg_exprs.append(arg_expr)

                # Check return type
                actual_ret_ir_type = mypy_type_to_ir_type(types.get(node))
                print(types.get(node))
                if actual_ret_ir_type != ret_ir_type:
                    raise Exception(
                        f"Expected return type {ret_ir_type} but got {actual_ret_ir_type} for {func_name}"
                    )
                return call(func_name, ret_ir_type, *arg_exprs).src
        elif isinstance(node, NameExpr):
            # Nothing can go wrong with a name expression (which are basically variables)
            ir_type = mypy_type_to_ir_type(types[node])
            return create_object(ir_type, node.name).src
        # TODO: check not
        elif isinstance(node, OpExpr):
            left = parse_node(node.left)
            right = parse_node(node.right)
            op = node.op
            if left.type is Int and right.type is Int:
                if op == "+":
                    return Add(left, right)
                elif op == "-":
                    return Sub(left, right)
                elif op == "*":
                    return Mul(left, right)
                elif op == "//":
                    return Div(left, right)
                elif op == "%":
                    return Mod(left, right)
                else:
                    raise Exception(f"Unsupported operator {op} on integers")
            elif (
                is_list_type(left.type)
                and not is_nested_list_type(left.type)
                and is_list_type(right.type)
                and not is_nested_list_type(right.type)
            ):
                if op == "+":
                    return call("list_concat", List[Int], left, right)
                else:
                    raise Exception(
                        f"Unsupported binary operation {op} on types {left.type} and {right.type}"
                    )
            elif left.type is Bool and right.type is Bool:
                if op == "and":
                    return And(left, right)
                elif op == "or":
                    return Or(left, right)
                else:
                    raise Exception(f"Unsupported operator {op} on booleans")
            else:
                raise Exception(
                    f"Unsupported binary operation {op} on types {left.type} and {right.type}"
                )
        elif isinstance(node, IntExpr):
            return create_object(Int, node.value).src
        elif isinstance(node, ListExpr):
            node_type = mypy_type_to_ir_type(types[node])
            return node_type.empty(get_args(node_type)[0]).src
        elif isinstance(node, ConditionalExpr):
            return ite(
                parse_node(node.cond),
                parse_node(node.if_expr),
                parse_node(node.else_expr),
            ).src
        elif isinstance(node, ComparisonExpr):
            if len(node.operators) != 1:
                raise Exception("Multiple operators not supported")
            if len(node.operands) != 2:
                raise Exception("Operation must be performed on exactly two operands")
            op = node.operators[0]
            left_expr = parse_node(node.operands[0])
            right_expr = parse_node(node.operands[1])
            if op == "==":
                if left_expr.type != right_expr.type:
                    raise Exception(
                        f"Comparison operator {op} only supported on objects of the same type"
                    )
                return Eq(left_expr, right_expr)
            else:
                if left_expr.type is not Int or right_expr.type is not Int:
                    raise Exception(
                        f"Comparison operator {op} only supported on integers"
                    )
                if op == ">":
                    return Gt(left_expr, right_expr)
                elif op == "<":
                    return Lt(left_expr, right_expr)
                elif op == ">=":
                    return Ge(left_expr, right_expr)
                elif op == "<=":
                    return Le(left_expr, right_expr)
                elif op == "!=":
                    return Not(Eq(left_expr, right_expr))
                else:
                    raise Exception(f"Unsupported operator {op}")
        elif isinstance(node, IndexExpr):
            base_expr = parse_node(node.base)
            base_object = create_object(base_expr.type, base_expr)
            if isinstance(node.index, SliceExpr):
                # Parse begin index
                if node.index.begin_index is None:
                    begin_index = None
                else:
                    begin_index = parse_node(node.index.begin_index)

                # Parse end index
                if node.index.end_index is None:
                    end_index = None
                else:
                    end_index = parse_node(node.index.end_index)

                if begin_index is None and end_index is not None:
                    return base_object[:end_index].src
                elif begin_index is not None and end_index is None:
                    return base_object[begin_index:].src
                elif begin_index is not None and end_index is not None:
                    return base_object[begin_index:end_index].src
                else:
                    raise Exception(f"Unsupported slice {node.index}")
            else:
                index_expr = parse_node(node.index)
                return base_object[index_expr].src
        else:
            raise Exception(f"Unsupported node {node}")

    ps_fn_decl = parse_node(node)
    fn_decls.append(ps_fn_decl)
    print(in_calls)

In [438]:
def check_solution(solution: str, expected_num_funcs: int) -> None:
    universal_imports = f"""
    from dsl_operators import *
    from typing import Any, Callable, List
    """
    full_prog = dedent(remove_comments(dedent(universal_imports) + dedent(solution)))
    target_func_defs, func_sigs, types = mypy_parse(full_prog, expected_num_funcs)
    fn_decls: pyList[FnDeclRecursive] = []
    in_calls: pyList[pyTuple[str, str]] = []
    for target_func_def in target_func_defs:
        mypy_node_to_ir(target_func_def, func_sigs, types, fn_decls, in_calls)
    return fn_decls, in_calls

In [439]:
# from tenspiler.llm.parser import check_solution
def parse(ps_sol):
    x, y = check_solution(ps_sol, 1)
    print (x)
    print(y)   
    return x, y

fn_d, in_calls = parse(parse_result)


FuncDef:4(
  kernel_2mm
  Args(
    Var(ni)
    Var(nj)
    Var(nk)
    Var(nl)
    Var(alpha)
    Var(beta)
    Var(tmp)
    Var(A)
    Var(B)
    Var(C)
    Var(D))
  def (ni: builtins.int, nj: builtins.int, nk: builtins.int, nl: builtins.int, alpha: builtins.int, beta: builtins.int, tmp: builtins.list[builtins.list[builtins.int]], A: builtins.list[builtins.list[builtins.int]], B: builtins.list[builtins.list[builtins.int]], C: builtins.list[builtins.list[builtins.int]], D: builtins.list[builtins.list[builtins.int]]) -> builtins.list[builtins.list[builtins.int]]
  Block:12(
    AssignmentStmt:12(
      NameExpr(return_var* [l])
      CallExpr:12(
        NameExpr(matrix_elemwise_add [dsl_operators.matrix_elemwise_add])
        Args(
          CallExpr:12(
            NameExpr(matrix_scalar_mul [dsl_operators.matrix_scalar_mul])
            Args(
              NameExpr(beta [l])
              NameExpr(D [l])))
          CallExpr:12(
            NameExpr(matrix_elemwise_mul [dsl_operato

TypeError: list expected at most 1 argument, got 2

In [417]:
def get_inv_sols_for_ps(n, ps_sol):
    prompt = prepare_prompt(remove_c_comments(ps_sol))
    ps_sols = []
    for _ in range(n):
        response = call_llm(prompt)
        ps_sol = response.text.strip()
        if ps_sol not in incorrect_ps_sols:
            ps_sols.append(ps_sol)

    return ps_sols

In [418]:
def verify(ps_sol, inv_sol):
    pass

In [419]:
from typing import List, Set

def transpile_code(source_code: str, num_iters: int, n: int) -> str:
    incorrect_ps_sols = set()
    seen_ps_sols = set()
    for _ in range(num_iters):
        ps_sols = get_ps_sols(n, source_code, incorrect_ps_sols)
        for ps_sol in ps_sols:
            if ps_sol in seen_ps_sols:
                continue
            if not parse(ps_sol):
                continue

            #start invarient for the PS
            seen_inv_sols_for_ps: set[str] = set()
            inv_sols: list[str] = get_inv_sols_for_ps(n, ps_sol)
            for inv_sol in inv_sols:
            # We have processed this INV for this PS before
                if inv_sol in seen_inv_sols_for_ps:
                    continue
                if not parse(inv_sol):
                    continue
                #verify the INV for the PS
                if verify(ps_sol, inv_sol):
                    return ps_sol, inv_sol
                seen_inv_sols_for_ps.add(inv_sol)

            incorrect_ps_sols.add(ps_sol)
            seen_ps_sols.add(ps_sol)

    return None
                

In [420]:
import textwrap
from typing import Any, Dict, Tuple, Union

from metalift.ir import (
    Add,
    And,
    Bool,
    Call,
    Div,
    Eq,
    Expr,
    FnDecl,
    FnDeclRecursive,
    Ge,
    Gt,
    Int,
    Ite,
    Le,
)
from metalift.ir import List as mlList
from metalift.ir import Lit, Lt, Matrix, Mod, Mul, Not, ObjectT, Or, Sub, Var
from tenspiler.codegen.utils import DataType
from tenspiler.tenspiler_common import (
    MAP_INT_TO_INT,
    MATRIX_ELEMWISE_ADD,
    MATRIX_ELEMWISE_DIV,
    MATRIX_ELEMWISE_MUL,
    MATRIX_ELEMWISE_SUB,
    MATRIX_SCALAR_ADD,
    MATRIX_SCALAR_DIV,
    MATRIX_SCALAR_MUL,
    MATRIX_SCALAR_SUB,
    SCALAR_MATRIX_DIV,
    SCALAR_MATRIX_SUB,
    SCALAR_VEC_DIV,
    SCALAR_VEC_SUB,
    VEC_ELEMWISE_ADD,
    VEC_ELEMWISE_DIV,
    VEC_ELEMWISE_MUL,
    VEC_ELEMWISE_SUB,
    VEC_SCALAR_ADD,
    VEC_SCALAR_DIV,
    VEC_SCALAR_MUL,
    VEC_SCALAR_SUB,
)

# Indentation is 4 spaces
INDENTATION = " " * 4


   
translations = {
    # VEC_ELEMWISE_ADD: lambda processed_args: f"torch.add({processed_args[0]}, {processed_args[1]})",
    # MATRIX_ELEMWISE_ADD: lambda processed_args: f"torch.add({processed_args[0]}, {processed_args[1]})",
    # VEC_SCALAR_ADD: lambda processed_args: f"torch.add({processed_args[0]}, {processed_args[1]})",
    # MATRIX_SCALAR_ADD: lambda processed_args: f"torch.add({processed_args[0]}, {processed_args[1]})",
    VEC_ELEMWISE_ADD: lambda processed_args: f"({processed_args[0]}) + ({processed_args[1]})",
    MATRIX_ELEMWISE_ADD: lambda processed_args: f"({processed_args[0]}) + ({processed_args[1]})",
    VEC_SCALAR_ADD: lambda processed_args: f"({processed_args[0]}) + ({processed_args[1]})",
    MATRIX_SCALAR_ADD: lambda processed_args: f"({processed_args[0]}) + ({processed_args[1]})",
    # VEC_ELEMWISE_SUB: lambda processed_args: f"torch.sub({processed_args[0]}, {processed_args[1]})",
    # MATRIX_ELEMWISE_SUB: lambda processed_args: f"torch.sub({processed_args[0]}, {processed_args[1]})",
    # SCALAR_VEC_SUB: lambda processed_args: f"torch.sub({processed_args[0]}, {processed_args[1]})",
    # SCALAR_MATRIX_SUB: lambda processed_args: f"torch.sub({processed_args[0]}, {processed_args[1]})",
    VEC_ELEMWISE_SUB: lambda processed_args: f"({processed_args[0]}) - ({processed_args[1]})",
    MATRIX_ELEMWISE_SUB: lambda processed_args: f"({processed_args[0]}) - ({processed_args[1]})",
    SCALAR_VEC_SUB: lambda processed_args: f"({processed_args[0]}) - ({processed_args[1]})",
    SCALAR_MATRIX_SUB: lambda processed_args: f"({processed_args[0]}) - ({processed_args[1]})",
    # VEC_SCALAR_SUB: lambda processed_args: f"torch.sub({processed_args[1]}, {processed_args[0]})",
    # MATRIX_SCALAR_SUB: lambda processed_args: f"torch.sub({processed_args[1]}, {processed_args[0]})",
    VEC_SCALAR_SUB: lambda processed_args: f"({processed_args[1]}) - ({processed_args[0]})",
    MATRIX_SCALAR_SUB: lambda processed_args: f"({processed_args[1]}) - ({processed_args[0]})",
    # VEC_ELEMWISE_MUL: lambda processed_args: f"torch.multiply({processed_args[0]}, {processed_args[1]})",
    # MATRIX_ELEMWISE_MUL: lambda processed_args: f"torch.multiply({processed_args[0]}, {processed_args[1]})",
    # VEC_SCALAR_MUL: lambda processed_args: f"torch.multiply({processed_args[0]}, {processed_args[1]})",
    # MATRIX_SCALAR_MUL: lambda processed_args: f"torch.multiply({processed_args[0]}, {processed_args[1]})",
    VEC_ELEMWISE_MUL: lambda processed_args: f"({processed_args[0]}) * ({processed_args[1]})",
    MATRIX_ELEMWISE_MUL: lambda processed_args: f"({processed_args[0]}) * ({processed_args[1]})",
    VEC_SCALAR_MUL: lambda processed_args: f"({processed_args[0]}) * ({processed_args[1]})",
    MATRIX_SCALAR_MUL: lambda processed_args: f"({processed_args[0]}) * ({processed_args[1]})",
    # VEC_ELEMWISE_DIV: lambda processed_args: f"torch.divide({processed_args[0]}, {processed_args[1]})",
    # MATRIX_ELEMWISE_DIV: lambda processed_args: f"torch.divide({processed_args[0]}, {processed_args[1]})",
    # SCALAR_VEC_DIV: lambda processed_args: f"torch.divide({processed_args[0]}, {processed_args[1]})",
    # SCALAR_MATRIX_DIV: lambda processed_args: f"torch.divide({processed_args[0]}, {processed_args[1]})",
    VEC_ELEMWISE_DIV: lambda processed_args, is_floor: f"({processed_args[0]}) // ({processed_args[1]})"
    if is_floor
    else f"({processed_args[0]}) / ({processed_args[1]})",
    MATRIX_ELEMWISE_DIV: lambda processed_args, is_floor: f"({processed_args[0]}) // ({processed_args[1]})"
    if is_floor
    else f"({processed_args[0]}) / ({processed_args[1]})",
    SCALAR_VEC_DIV: lambda processed_args, is_floor: f"({processed_args[0]}) // ({processed_args[1]})"
    if is_floor
    else f"({processed_args[0]}) / ({processed_args[1]})",
    SCALAR_MATRIX_DIV: lambda processed_args, is_floor: f"({processed_args[0]}) // ({processed_args[1]})"
    if is_floor
    else f"({processed_args[0]}) / ({processed_args[1]})",
    # VEC_SCALAR_DIV: lambda processed_args: f"torch.divide({processed_args[1]}, {processed_args[0]})",
    # MATRIX_SCALAR_DIV: lambda processed_args: f"torch.divide({processed_args[1]}, {processed_args[0]})",
    VEC_SCALAR_DIV: lambda processed_args, is_floor: f"({processed_args[1]}) // ({processed_args[0]})"
    if is_floor
    else f"({processed_args[1]}) / ({processed_args[0]})",
    MATRIX_SCALAR_DIV: lambda processed_args, is_floor: f"({processed_args[1]}) // ({processed_args[0]})"
    if is_floor
    else f"({processed_args[1]}) / ({processed_args[0]})",
    "matrix_vec_mul": lambda processed_args: f"torch.matmul({processed_args[0]}, {processed_args[1]})",
    "list_eq": lambda processed_args: f"torch.all(torch.eq({processed_args[0]}, {processed_args[1]}))",
    "list_empty": lambda processed_args: f"torch.empty(0)",
    "matrix_empty": lambda processed_args: f"torch.empty(0, 0)",
    "list_get": lambda processed_args: f"{processed_args[0]}[{processed_args[1]}]",
    "matrix_get": lambda processed_args: f"{processed_args[0]}[{processed_args[1]}]",
    "list_append": lambda processed_args: f"torch.cat([{processed_args[0]}, {processed_args[1]}.unsqueeze(0)], dim=0)",
    "matrix_append": lambda processed_args: f"torch.cat([{processed_args[0]}, {processed_args[1]}.unsqueeze(0)], dim=0)",
    "list_prepend": lambda processed_args: f"torch.cat([{processed_args[1].unsqueeze(0)}, {processed_args[0]}], dim=0)",
    "matrix_prepand": lambda processed_args: f"torch.cat([{processed_args[1].unsqueeze(0)}, {processed_args[0]}], dim=0)",
    "list_concat": lambda processed_args: f"torch.cat([{processed_args[0]}, {processed_args[1]}], dim=0)",
    "list_tail": lambda processed_args: f"{processed_args[0]}[:{processed_args[1]}]",
    "matrix_tail": lambda processed_args: f"{processed_args[0]}[{processed_args[1]}:]",
    "list_take": lambda processed_args: f"{processed_args[0]}[:{processed_args[1]}]",
    "matrix_take": lambda processed_args: f"{processed_args[0]}[:{processed_args[1]}]",
    "vec_slice": lambda processed_args: f"{processed_args[0]}[{processed_args[1]}:{processed_args[2]}]",
    "matrix_row_slice": lambda processed_args: f"{processed_args[0]}[{processed_args[1]}:{processed_args[2]}]",
    "vec_slice_with_length": lambda processed_args: f"{processed_args[0]}[{processed_args[1]}:{processed_args[1]} + {processed_args[2]}]",
    "matrix_row_slice_with_length": lambda processed_args: f"{processed_args[0]}[{processed_args[1]}:{processed_args[1]} + {processed_args[2]}]",
    "matrix_col_slice": lambda processed_args: f"{processed_args[0]}[:, {processed_args[1]}:{processed_args[2]}]",
    "matrix_col_slice_with_length": lambda processed_args: f"{processed_args[0]}[:, {processed_args[1]}:{processed_args[1]} + {processed_args[2]}]",
    "list_length": lambda processed_args: f"{processed_args[0]}.size(dim=0)",
    "matrix_length": lambda processed_args: f"{processed_args[0]}.size(dim=0)",
    "matrix_transpose": lambda processed_args: f"torch.transpose({processed_args[0]}, 0, 1)",
    "reduce_max": lambda processed_args: f"torch.max({processed_args[0]})",
    "reduce_sum": lambda processed_args: f"torch.sum({processed_args[0]})",
    "reduce_mul": lambda processed_args: f"torch.prod({processed_args[0]})",
    "integer_sqrt": lambda processed_args, is_list=False: f"torch.sqrt(torch.as_tensor({processed_args[0]}))",
    "integer_exp": lambda processed_args, is_list=False: f"torch.exp(torch.as_tensor({processed_args[0]}))",
    # Add: lambda processed_args, is_int: f"{processed_args[0]} + {processed_args[1]}" if is_int else f"torch.add({processed_args[0]}, {processed_args[1]})",
    # Sub: lambda processed_args, is_int: f"{processed_args[0]} - {processed_args[1]}" if is_int else f"torch.sub({processed_args[0]}, {processed_args[1]})",
    # Mul: lambda processed_args, is_int: f"{processed_args[0]} * {processed_args[1]}" if is_int else f"torch.multiply({processed_args[0]}, {processed_args[1]})",
    # Div: lambda processed_args, is_int: f"{processed_args[0]} // {processed_args[1]}" if is_int else f"torch.divide({processed_args[0]}, {processed_args[1]})",
    Add: lambda processed_args, is_int: f"({processed_args[0]}) + ({processed_args[1]})",
    Sub: lambda processed_args, is_int: f"({processed_args[0]}) - ({processed_args[1]})",
    Mul: lambda processed_args, is_int: f"({processed_args[0]}) * ({processed_args[1]})",
    Div: lambda processed_args, is_int: f"({processed_args[0]}) // ({processed_args[1]})",
    "float_div": lambda processed_args: f"({processed_args[0]}) / ({processed_args[1]})",
    Mod: lambda processed_args, is_int: f"({processed_args[0]}) % ({processed_args[1]})",
    Eq: lambda processed_args, is_int: f"{processed_args[0]} == {processed_args[1]}"
    if is_int
    else f"torch.eq({processed_args[0]}, {processed_args[1]})",
    Gt: lambda processed_args, is_int: f"{processed_args[0]} > {processed_args[1]}"
    if is_int
    else f"torch.greater({processed_args[0]}, {processed_args[1]})",
    Ge: lambda processed_args, is_int: f"{processed_args[0]} >= {processed_args[1]}"
    if is_int
    else f"torch.greater_equal({processed_args[0]}, {processed_args[1]})",
    Lt: lambda processed_args, is_int: f"{processed_args[0]} < {processed_args[1]}"
    if is_int
    else f"torch.less({processed_args[0]}, {processed_args[1]})",
    Le: lambda processed_args, is_int: f"{processed_args[0]} <= {processed_args[1]}"
    if is_int
    else f"torch.less_equal({processed_args[0]}, {processed_args[1]})",
    Not: lambda processed_args, is_prim: f"not {processed_args[0]}"
    if is_prim
    else f"torch.logical_not({processed_args[0]})",
    And: lambda processed_args, is_prim: f"({processed_args[0]}) and ({processed_args[1]})"
    if is_prim
    else f"torch.logical_and({processed_args[0]}, {processed_args[1]})",
    Or: lambda processed_args, is_prim: f"({processed_args[0]}) or ({processed_args[1]})"
    if is_prim
    else f"torch.logical_or({processed_args[0]}, {processed_args[1]})",
}


def pytorch_codegen(
    ps_fn_decl: Union[FnDecl, FnDeclRecursive],
    all_synthesized_fns: Dict[str, Expr],
    d_type: DataType = DataType.FLOAT,
) -> str:
    has_matmul = False

    def helper(expr: Any, vars_to_replace: Dict[str, Expr] = {}) -> Tuple[str, ObjectT]:
        nonlocal has_matmul
        if not isinstance(expr, Expr):
            return str(expr), None
        if isinstance(expr, Call):
            processed_args = [
                helper(arg, vars_to_replace)[0] for arg in expr.arguments()
            ]
            fn_name = expr.name()

            if fn_name == "matrix_vec_mul":
                has_matmul = True

            if fn_name.endswith("matrix_selection_two_args"):
                for name, fn in all_synthesized_fns.items():
                    if name.endswith("select_two_args"):
                        select_two_args_fn_decl = fn
                if select_two_args_fn_decl is None:
                    raise ValueError("select_two_args not found")
                select_two_args_body = select_two_args_fn_decl.body()
                cond, if_then, if_else = (
                    select_two_args_body.c(),
                    select_two_args_body.e1(),
                    select_two_args_body.e2(),
                )
                select_args = select_two_args_fn_decl.arguments()[:2]
                matrix_args = expr.arguments()[:2]
                vars_to_replace: Dict[str, Expr] = {}
                for i in range(2):
                    vars_to_replace[select_args[i].name()] = matrix_args[i]

                cond_res, cond_type = helper(cond, vars_to_replace)
                if_then_res = helper(if_then, vars_to_replace)[0]
                if_else_res = helper(if_else, vars_to_replace)[0]
                res = f"torch.where({cond_res}, {if_then_res}, {if_else_res})"
                if cond_type == Bool:
                    res = f"torch.where(torch.tensor({cond_res}), {if_then_res}, {if_else_res})"
                return (
                    res,
                    expr.type,
                )
            elif fn_name == MAP_INT_TO_INT or fn_name == "vec_map":
                map_fn_name = all_synthesized_fns[MAP_INT_TO_INT].body().name()
                if map_fn_name in {"integer_sqrt", "integer_exp"}:
                    return (
                        translations[map_fn_name](processed_args, fn_name == "vec_map"),
                        expr.type,
                    )
                else:
                    raise ValueError(f"Unknown map function name: {map_fn_name}")
            elif fn_name in translations.keys():
                if fn_name in {
                    VEC_ELEMWISE_DIV,
                    MATRIX_ELEMWISE_DIV,
                    SCALAR_VEC_DIV,
                    SCALAR_MATRIX_DIV,
                    VEC_SCALAR_DIV,
                    MATRIX_SCALAR_DIV,
                }:
                    return (
                        translations[fn_name](processed_args, d_type != DataType.FLOAT),
                        expr.type,
                    )
                return translations[fn_name](processed_args), expr.type
            elif fn_name in all_synthesized_fns.keys():
                return helper(all_synthesized_fns[fn_name].body())

            raise Exception(f"Unknown function name: {fn_name}")

        # Ite expression. Some condition are constants
        if isinstance(expr, Ite):
            cond = helper(expr.c())[0]

            if cond == "True":
                return helper(expr.e1(), vars_to_replace)
            elif cond == "False":
                return helper(expr.e2(), vars_to_replace)
            else:
                return (
                    f"{helper(expr.e1(), vars_to_replace)[0]} if {cond} else {helper(expr.e2(), vars_to_replace)[0]}",
                    expr.e1().type,
                )

        # Arithmetic operations
        processed_args = [helper(arg, vars_to_replace) for arg in expr.args]
        processed_args_types = [a[1] for a in processed_args]
        processed_args = [a[0] for a in processed_args]
        if any(isinstance(expr, cls) for cls in [Add, Sub, Mul, Div, Mod]):
            is_arg_type_int = all([a_type is Int for a_type in processed_args_types])
            ret_type = (
                Int
                if is_arg_type_int
                else [
                    a_type
                    for a_type in processed_args_types
                    if a_type is not Int and a_type is not None
                ][0]
            )
            if isinstance(expr, Div) and d_type == DataType.FLOAT:
                return translations["float_div"](processed_args), ret_type
            return translations[type(expr)](processed_args, is_arg_type_int), ret_type

        # Relational operations
        elif any(isinstance(expr, cls) for cls in [Gt, Ge, Eq, Lt, Le]):
            is_arg_type_int = all([a_type is Int for a_type in processed_args_types])
            ret_type = Bool if is_arg_type_int else mlList[Bool]
            return translations[type(expr)](processed_args, is_arg_type_int), ret_type
        elif any(isinstance(expr, cls) for cls in [And, Or, Not]):
            is_arg_type_prim = all(
                [a_type is Int or a_type is Bool for a_type in processed_args_types]
            )
            ret_type = Bool if is_arg_type_prim else mlList[Bool]
            return translations[type(expr)](processed_args, is_arg_type_prim), ret_type

        # Other
        elif isinstance(expr, Lit):
            return f"{expr.val()}", expr.type
        elif isinstance(expr, Var):
            if expr.name() in vars_to_replace:
                return helper(vars_to_replace[expr.name()], vars_to_replace)
            return expr.name(), expr.type
        return str(expr)

    ###############################
    # Begins actual code generation
    ###############################
    import_stmt = """
####### import statements ########
import torch
"""
    print(import_stmt)

    fn_name = f"{ps_fn_decl.name()[:-3]}"
    arguments = [arg.name() for arg in ps_fn_decl.arguments()]
    arguments_str = ", ".join(arguments)
    kernel_name = f"{fn_name}_torch"
    print("####### kernel code ########")
    kernel_fn = f"""
    def {kernel_name} ({arguments_str}):
        return {helper(ps_fn_decl.body())[0]}
    """
    kernel_fn = textwrap.dedent(kernel_fn)
    print(kernel_fn)

    print("####### glued code ########")
    glued_name = f"{fn_name}_torch_glued "
    argument_types = [arg.type for arg in ps_fn_decl.arguments()]

    conversions = []
    for i in range(len(arguments)):
        if argument_types[i] == List[List[Int]] or argument_types[i] == mlList[Int]:
            lib_dtype = "torch.uint8"
            if d_type == DataType.FLOAT:
                lib_dtype = "torch.float32"

            if d_type == DataType.INT32:
                lib_dtype = "torch.int32"

            # matmul require float
            if has_matmul:
                lib_dtype = "torch.float32"

            conversions.append(
                f"{arguments[i]} = torch.tensor({arguments[i]}, dtype={lib_dtype})"
            )

    arg_processing = f"\n{INDENTATION * 2}".join(conversions)
    glued_fn = f"""
    def {glued_name}({arguments_str}):
        {arg_processing}
        return {kernel_name}({arguments_str})
    """
    glued_fn = textwrap.dedent(glued_fn)
    print(glued_fn)

    return import_stmt + kernel_fn + glued_fn

all_synthesized_fns = {fn.name(): fn for fn in fn_d}

print(all_synthesized_fns)
print(pytorch_codegen(fn_d[0], all_synthesized_fns, DataType.FLOAT))

{'kernel_2mm': (FnDeclRecursive:Function  kernel_2mm (matrix_elemwise_add:List List Int (matrix_scalar_mul:List List Int beta D) (matrix_elemwise_mul:List List Int (matrix_elemwise_mul:List List Int (matrix_scalar_mul:List List Int alpha A) B) C)) ni nj nk nl alpha beta tmp A B C D)}

####### import statements ########
import torch

####### kernel code ########

def kernel__torch (ni, nj, nk, nl, alpha, beta, tmp, A, B, C, D):
    return ((beta) * (D)) + ((((alpha) * (A)) * (B)) * (C))

####### glued code ########

def kernel__torch_glued (ni, nj, nk, nl, alpha, beta, tmp, A, B, C, D):

    return kernel__torch(ni, nj, nk, nl, alpha, beta, tmp, A, B, C, D)


####### import statements ########
import torch

def kernel__torch (ni, nj, nk, nl, alpha, beta, tmp, A, B, C, D):
    return ((beta) * (D)) + ((((alpha) * (A)) * (B)) * (C))

def kernel__torch_glued (ni, nj, nk, nl, alpha, beta, tmp, A, B, C, D):

    return kernel__torch(ni, nj, nk, nl, alpha, beta, tmp, A, B, C, D)

