In [1]:
from __future__ import annotations
from exo import proc, DRAM, Procedure
import exo.query_asts as exo_ast
import exo.stdlib.scheduling as sched
import multiprocessing as mp

import sys
from itertools import chain
from functools import reduce
from typing import Any

from train import BASE_EXO, evaluate_out

sys.setrecursionlimit(10000)

CONTROL_STRUCTURE = (exo_ast.For, exo_ast.If, exo_ast.Proc)

In [2]:
from pprint import pprint
import cProfile

def profileit(do_print: bool = False):
    def inner(func):
        def wrapper(*args, **kwargs):
            prof = cProfile.Profile()
            retval = prof.runcall(func, *args, **kwargs)

            if do_print:
                prof.print_stats(sort="cumtime")
            else:
                prof.dump_stats(f"{func.__name__}.profile")

            return retval
        return wrapper
    return inner

### Query AST Classes
---
Exo ast nodes are organized into a hierarchy of dataclasses:

- QueryAST
  - `Proc`  ( name : str, args : list[FnArg], assertions : list[Expr],
                        body : list[Stmt],  instruction : Optional[str] )
  - `Stmt`
    - `Assign`    ( name : str,   lhs_type : Type,    idx : list[Expr],
                              rhs  : Expr )
    - `Reduce`    ( name : str,   lhs_type : Type,    idx : list[Expr],
                              rhs  : Expr )
    - `WriteConfig` ( config : Config, field : str,
                              rhs  : Expr )
    - `Pass`      ()
    - `If`        ( cond : Expr,  body : list[Stmt],  orelse : list[Stmt] )
    - `For`       ( name : str,   lo   : Expr,        hi : Expr,
                              body : list[Stmt],  is_par : bool )
    - `Alloc`     ( name : str,   type : Type,        memory : Optional[Memory] )
    - `Call`      ( proc : str,   args : list[Expr] )
    - `WindowStmt`( name : str,   rhs  : Expr )
  - `Expr`
    - `Read`    ( name : str,   idx  : list[Expr],    type : Type )
    - `Const`   ( val  : Any,   type : Type )
    - `USub`    ( arg  : Expr,  type : Type )
    - `BinOp`   ( op   : str,   lhs  : Expr,
                            rhs  : Expr,          type : Type )
    - `BuiltIn` ( func : str,   args : list[Expr],    type : Type )
    - `WindowExpr`( name : str, idx : list[WAccess],  type : Type )
    - `StrideExpr`( name : str, dim : int,            type : Type )
    - `ReadConfig`( config : Config, field : str,     type : Type )
  - `WAccess`
    - `Interval`( lo : Expr, hi : Expr )
    - `Point`( pt : Expr )
  - `FnArg` ( name : str, type : Type, memory : Optional[Memory] )
  - `Type`
    - `R`()
    - `f16`()
    - `f32`()
    - `f64`()
    - `i8`()
    - `i32`()
    - `bool`()
    - `int`()
    - `index`()
    - `size`()
    - `stride`()
    - `tensor`( hi : list[Expr], is_window : bool, type : Type )

### Code

In [3]:
def get_exo_proc() -> Procedure:
    @proc
    def generated_operation(
        In: i8[16, 16] @ DRAM,
        Weights: i8[16, 16] @ DRAM,
        Out: i8[16, 16] @ DRAM,
    ):
        for i in seq(0, 16):
            for j in seq(0, 16):
                for k in seq(0, 16):
                    Out[i, j] += In[i, k] * Weights[k, j]

    return generated_operation

In [4]:
def extract_c_code(proc: Procedure) -> str:
    return proc.c_code_str().split("#include <stdlib.h>")[-1]

In [5]:
PROC_AST: exo_ast.Proc = get_exo_proc().get_ast() # pyright: ignore[reportAssignmentType]

In [6]:
pprint(PROC_AST.body)

[For(name='i',
     lo=Const(val=0, type=int()),
     hi=Const(val=16, type=int()),
     body=[For(name='j',
               lo=Const(val=0, type=int()),
               hi=Const(val=16, type=int()),
               body=[For(name='k',
                         lo=Const(val=0, type=int()),
                         hi=Const(val=16, type=int()),
                         body=[Reduce(name='Out',
                                      lhs_type=i8(),
                                      idx=[Read(name='i', idx=[], type=index()),
                                           Read(name='j',
                                                idx=[],
                                                type=index())],
                                      rhs=BinOp(op='*',
                                                lhs=Read(name='In',
                                                         idx=[Read(name='i',
                                                                   idx=[],
                    

In [7]:
def extract_fors(mod: exo_ast.QueryAST) -> list[exo_ast.For]:
    """Produce a list of the For nodes in the ast.Module in depth-wise order."""
    if not isinstance(mod, CONTROL_STRUCTURE):
        return [ ]

    match mod:
        case exo_ast.For(_):
            with mp.Pool() as pool:
                return list(chain([mod], 
                    *pool.starmap(extract_fors, [(item, ) for item in mod.body])
                ))
        case exo_ast.Proc(_) | exo_ast.If(_):
            with mp.Pool() as pool:
                return list(chain(
                    *pool.starmap(extract_fors, [(item, ) for item in mod.body]
                )))
        case _:
            raise ValueError

extract_fors(PROC_AST)

AssertionError: daemonic processes are not allowed to have children

In [None]:
union = lambda iter_of_sets: reduce(lambda a, b: a | b, iter_of_sets,)

In [None]:
def get_swappable_loops(ast):
    """Return a set of all swappable for loops"""
    if not isinstance(ast, CONTROL_STRUCTURE):
        return set([])

    match ast:
        case exo_ast.For(name=name, body=[exo_ast.For(name=inner_name)]):
            with mp.Pool() as pool:
                return set([(name, inner_name)]) \
                    | union(pool.starmap(get_swappable_loops, 
                        [ (node,) for node in ast.body ]
                    ))
        case exo_ast.Proc(body=b) | exo_ast.If(body=b) | exo_ast.For(body=b):
            with mp.Pool() as pool:
                return union(pool.starmap(get_swappable_loops, 
                    [ (node,) for node in b ] 
                ))
        case _:
            raise ValueError

get_swappable_loops(PROC_AST)

{('i', 'j'), ('j', 'k')}

In [None]:
class ScheduledProc:
    def __init__(self, exo_sched_func, args: list = [], kwargs: dict[str, Any] = {}):
        self.sched_func = exo_sched_func
        self.func_args = args
        self.func_kwargs = kwargs

    def _apply_transform(self, proc: Procedure) -> Procedure:
        return self.sched_func(
            proc, 
            *self.func_args,
            **self.func_kwargs
        )

    def __repr__(self):
        return (
            f"{self.sched_func.__name__}(<proc>, "
          + ", ".join(self.func_args)
          + ", ".join([f"{key} = {val}" for key, val in self.func_kwargs.items()])
          + ")"
        )

    def __call__(self, proc: Procedure) -> Procedure:
        return self._apply_transform(proc)

In [None]:
def get_loop_reorders(ast: exo_ast.Proc) -> list[ScheduledProc]:
    loop_pairs = get_swappable_loops(ast)
    return [
        ScheduledProc(
            sched.reorder_loops,
            [" ".join(list(loop_pair))], 
            {}
        ) for loop_pair in loop_pairs
    ]

get_loop_reorders(PROC_AST)

[reorder_loops(<proc>, i j), reorder_loops(<proc>, j k)]

In [None]:
def get_loop_unrolls(ast: exo_ast.Proc) -> list[ScheduledProc]:

    def find_innermost_loops(mod: exo_ast.QueryAST) -> list[str]:
        if not isinstance(mod, CONTROL_STRUCTURE):
            return [ ]
        match mod:
            case exo_ast.For(body=[exo_ast.For(_)]) | exo_ast.Proc(_) | exo_ast.If(_):
                return list(chain(*[
                    find_innermost_loops(item)
                    for item in mod.body
                ]))
            case exo_ast.For(name=name):
                return [name]
            case _:
                raise ValueError

    return [
        ScheduledProc(sched.unroll_loop, [loop_name])
        for loop_name in find_innermost_loops(ast)
    ]

get_loop_unrolls(PROC_AST)

[unroll_loop(<proc>, k)]

In [None]:
def get_loop_splits(ast: exo_ast.Proc) -> list[ScheduledProc]:
    ...

In [None]:
def get_loop_fusions(ast: exo_ast.Proc):
    ...

In [None]:
def get_all_exo_tsfs(proc: Procedure):
    root_ast: exo_ast.Proc = proc.get_ast() # pyright: ignore[reportAssignmentType]
    tsfs: list[ScheduledProc] = [ ]

    for expansion in [
        get_loop_reorders,
        get_loop_unrolls,
    ]:
        print(f"\t\t> Doing expansion: {expansion.__name__}")
        tsfs.extend(expansion(root_ast))

    tsfs.append(ScheduledProc(sched.simplify))

    return tsfs

p = get_exo_proc()
tsf = get_all_exo_tsfs(get_exo_proc())[0]
tsf(p)

		> Doing expansion: get_loop_reorders
		> Doing expansion: get_loop_unrolls


```python
def generated_operation(In: i8[16, 16] @ DRAM, Weights: i8[16, 16] @ DRAM,
                        Out: i8[16, 16] @ DRAM):
    for j in seq(0, 16):
        for i in seq(0, 16):
            for k in seq(0, 16):
                Out[i, j] += In[i, k] * Weights[k, j]
```

In [None]:
def evaluate(c_code: str) -> float:
    return 1.

@profileit(do_print=True)
def search_procs(proc: Procedure):
    og_c_code = extract_c_code(proc)
    seen_procs = {
        og_c_code : (evaluate(og_c_code), proc)
    }

    def deeper(p: Procedure):
        nonlocal seen_procs

        new_procs = [tsf(p) for tsf in get_all_exo_tsfs(p)]
        print(f"Found {len(new_procs)} potential procs")
        actually_new_procs = []

        # here is where we would add the secret sauce:
        #new_procs = sort_by_fancy_heuristic(new_procs)
        
        for new_proc in new_procs:

            c_code = extract_c_code(new_proc)
        
            if c_code in seen_procs.keys():
                continue

            print(f"- Adding a newly found proc")
            seen_procs[c_code] = (evaluate(c_code), new_proc)
            actually_new_procs.append(new_proc)
        
        print(f"branching into {len(actually_new_procs)} new procs")
        for new_proc in actually_new_procs:
            print(f"  - entering a recursive call for \n{str(new_proc)[:int(1e3)]}")
            deeper(new_proc)
            print(f"  - exiting a recursive call")
    
    deeper(proc)

    return list(seen_procs.values())

procs = search_procs(get_exo_proc())
procs, len(procs)

		> Doing expansion: get_loop_reorders
		> Doing expansion: get_loop_unrolls
Found 4 potential procs
- Adding a newly found proc
- Adding a newly found proc
- Adding a newly found proc
branching into 3 new procs
  - entering a recursive call for 
def generated_operation(In: i8[16, 16] @ DRAM, Weights: i8[16, 16] @ DRAM,
                        Out: i8[16, 16] @ DRAM):
    for j in seq(0, 16):
        for i in seq(0, 16):
            for k in seq(0, 16):
                Out[i, j] += In[i, k] * Weights[k, j]
		> Doing expansion: get_loop_reorders
		> Doing expansion: get_loop_unrolls
Found 4 potential procs
- Adding a newly found proc
- Adding a newly found proc
branching into 2 new procs
  - entering a recursive call for 
def generated_operation(In: i8[16, 16] @ DRAM, Weights: i8[16, 16] @ DRAM,
                        Out: i8[16, 16] @ DRAM):
    for j in seq(0, 16):
        for k in seq(0, 16):
            for i in seq(0, 16):
                Out[i, j] += In[i, k] * Weights[k, j]
		> 

KeyboardInterrupt: 

# Exo's scheduling API


## Top-level Python function decorator

1. `@proc` - decorates a Python function which is parsed and compiled as Exo. Replaces
   the function with a `Procedure` object.
2. `@instr` - same as `@proc`, but accepts a hardware instruction as a format string.
3. `@config` - decorates a Python class which is parsed and compiled as an Exo
   configuration object



## Procedure object methods



**Introspection operations**

- `.name()` returns the procedure name.
- `.check_effects()` forces Exo to run effect checking on the procedure.
- `.show_effects()` prints the effects of the procedure.
- `.show_effect(stmt)` prints the effect of the `stmt` in the procedure.
- `.is_instr()` returns `true` if the procedure has a hardware instruction string.
- `.get_instr()` returns the hardware instruction string.
- `.get_ast()` returns a `QAST`, which is an AST representation suitable for
  introspection.



**Execution / interpretation operations**

- `.compile_c(directory, filename)` compiles the procedure into C and stores
  in `filename` in the `directory`.
- `.interpret(**args)` runs Exo interpreter on the procedure.



## Scheduling operations on Procedure objects



### Loop related operations



- `.split(loop, split_const, iter_vars, tail='guard', perfect=False)`
  - Splits `loop` into an outer and an inner loop. The inner loop bound is `split_const` and the outer and inner loop names are specified by a list of strings `iter_vars`. If `perfect` is True, it will not introduce a tail case. `tail` specifies the tail strategies, where the options are `guard`, `cut`, and `cut_and_guard`.
- `.fuse_loop(loop1, loop2)`
  - Fuses two adjacent loops with a common iteration variable.
- `.partition_loop(loop, num)`
  - Partitions `loop` into two loops, the first running between `0` and `num` and the second between `num+1` and `loop`'s original bound.
- `.reorder(loop1, loop2)`
  - Reorders two nested loops. `loop2` should be nested directly inside `loop1`. `loop1` will be nested inside `loop2` after this operation.
- `.unroll(loop)`
  - Unrolls the loop. The loop needs to have a constant bound.
- `.fission_after(stmt, n_lifts=1)`
  - Fissions the `n_lifts` number of loops around the `stmt`. The fissioned loops around the `stmt` need to be directly nested with each other and the statements before and after the `stmt` should not have any allocation dependencies.
- `.remove_loop(loop)`
  - Replaces the loop with its body if the body is idempotent. The system must be able to prove that the loop runs at least once.



### Buffer related operations



- `.reuse_buffer(buf1, buf2` 
  - Reuses a buffer `buf1` in the use site of `buf2` and removes the allocation of `buf2`
- `.inline_window(win_stmt)`
  - Removes the window statement `win_stmt`, which is an alias to the window, and inlines the windowing in its use site
- `.expand_dim(stmt, alloc_dim, indexing)`
  - Expands the dimension of the allocation statement `stmt` with dimension `alloc_dim` of indexing `indexing`
- `.bind_expr(new_name, expr)`
  - Binds the right hand side expression `expr` to a newly allocated buffer named `new_name`
- `.stage_mem(win_expr, new_name, stmt_start, stmt_end=None)`
  - Stages the buffer `win_expr` to the new window expression `new_name` in statement block (`stmt_start` to `stmt_end`), and adds an initialization loop and a write-back loop
- `.rearrange_dim(alloc, dimensions)`
  - Takes an allocation statement and a list of integers to map the dimension. It rearranges the dimensions of `alloc` in `dimension` order. E.g., if `alloc` were `foo[N,M,K]` and the `dimension` were `[2,0,1]`, it would become `foo[K,N,M]` after this operation.
- `.lift_alloc(alloc, n_lifts=1, keep_dims=False)`
  - Lifts the allocation statement `alloc` out of `n_lifts` number of scopes. If and For statements are the only statements in Exo which introduce a scope. When lifting the allocation out of a for loop, it will expand its dimension to the loop bound if `keep_dims` is True.



### Config related operations



- `.bind_config(expr, config, field)`
  - Binds the right hand side `expr` to `config.field`. It will replace the use site of `expr` with `config.field` and introduces a config statement of `config.field = expr`.
- `.configwrite_root(config, field, expr)`
  - Inserts the config statement `config.field = expr` in the beginning of the procedure.
- `.configwrite_after(stmt, config, field, expr)`
  - Inserts the config statement `config.field = expr` after `stmt`.
- `.delete_config(stmt)`
  - Deletes the configuration statement.



### Other scheduling operations



- `.add_assertion(assertion)`
  - Asserts the truth of the expression `assertion` at the beginning of the procedure.
- `.lift_if(if, n_lifts=1)`
  - Lifts the if statement `if` out of `n_lifts` number of scopes. This is similar to `reorder()`, but for if statements.
- `.eliminate_dead_code(stmt)`
  - Eliminates `if` statement if condition is always True or False. Eliminates `for` statement if condition is always False.
- `.delete_pass()`
  - Deletes a `Pass` statement in the procedure.
- `.reorder_stmts(stmt1, stmt2)`
  - Reorder two adjacent statements `stmt1` and `stmt2`. After this operation, the order will be `stmt2` `stmt1`.
  - `.reorder_before(stmt)`
    - Move the statement `stmt` before the previous statement. This is a shorthand for `reorder_stmts()`.
- `.replace(subproc, stmt)`
  - Replace the statement with a call to `subproc`. This operation is one of our contributions and is explained in detail in the paper.
- `.replace_all(subproc)`
  - Eagerly replace every matching statement with a call to `subproc`.
- `.inline(call_site)`
  - Inline the function call.
- `.is_eq(another_proc)`
  - Returns True if `another_proc` is equivalent to the procedure.
- `.call_eqv(eqv_proc, call_site)`
  - Replace the function call statement of `call_site` with a call to an equivalent procedure `eqv_proc`.
- `.repeat(directive, *args)`
  - Continue to run the directive until it fails. The directive and its arguments are given separately, e.g. `proc.repeat(Procedure.inline, "proc_to_inline(_)")`
- `.simplify()`
  - Simplify the code in the procedure body. Tries to reduce expressions to constants and eliminate dead branches and loops. Uses branch conditions to simplify expressions inside the branches.
- `.rename(new_name)`
  - Rename this procedure to `new_name`.
- `.make_instr(instr_string)`
  - Converts this procedure to an instruction procedure with instruction `instr_string`.
- `.partial_eval(*args, **kwargs)`
  - Specializes this procedure to the given argument values.
- `.set_precision(name, type)`
  - Sets the precision type of `name` to `type`.
- `.set_window(name, is_window)`
  - If `is_window` is True, it sets the buffer `name` to window type, instead of a tensor type.
- `.set_memory(name, mem_type)`
  - Sets a buffer `name`'s memory type to `mem_type`.
