In [1]:
from __future__ import annotations
from exo import proc, DRAM, Procedure, SchedulingError
import exo.query_asts as exo_ast
import multiprocessing as mp
import concurrent.futures

import sys
from itertools import chain
from functools import reduce
from typing import Any

from psutil import Process, cpu_count
Process().cpu_affinity(list(range(cpu_count())))

from evaluate import PersistentQueue, QUEUE_FILE


sys.setrecursionlimit(10000)

BRANCHING_THRESHOLD: int = 3
CONTROL_STRUCTURE = (exo_ast.For, exo_ast.If, exo_ast.Proc)
LOOP_SPLIT_SIZES: int = 4

In [2]:
from pprint import pprint
import cProfile

def profileit(do_print: bool = False):
    def inner(func):
        def wrapper(*args, **kwargs):
            prof = cProfile.Profile()
            retval = prof.runcall(func, *args, **kwargs)

            if do_print:
                prof.print_stats(sort="cumtime")
            else:
                prof.dump_stats(f"{func.__name__}.profile")

            return retval
        return wrapper
    return inner

### Query AST Classes
---
Exo ast nodes are organized into a hierarchy of dataclasses:

- QueryAST
  - `Proc`  ( name : str, args : list[FnArg], assertions : list[Expr],
                        body : list[Stmt],  instruction : Optional[str] )
  - `Stmt`
    - `Assign`    ( name : str,   lhs_type : Type,    idx : list[Expr],
                              rhs  : Expr )
    - `Reduce`    ( name : str,   lhs_type : Type,    idx : list[Expr],
                              rhs  : Expr )
    - `WriteConfig` ( config : Config, field : str,
                              rhs  : Expr )
    - `Pass`      ()
    - `If`        ( cond : Expr,  body : list[Stmt],  orelse : list[Stmt] )
    - `For`       ( name : str,   lo   : Expr,        hi : Expr,
                              body : list[Stmt],  is_par : bool )
    - `Alloc`     ( name : str,   type : Type,        memory : Optional[Memory] )
    - `Call`      ( proc : str,   args : list[Expr] )
    - `WindowStmt`( name : str,   rhs  : Expr )
  - `Expr`
    - `Read`    ( name : str,   idx  : list[Expr],    type : Type )
    - `Const`   ( val  : Any,   type : Type )
    - `USub`    ( arg  : Expr,  type : Type )
    - `BinOp`   ( op   : str,   lhs  : Expr,
                            rhs  : Expr,          type : Type )
    - `BuiltIn` ( func : str,   args : list[Expr],    type : Type )
    - `WindowExpr`( name : str, idx : list[WAccess],  type : Type )
    - `StrideExpr`( name : str, dim : int,            type : Type )
    - `ReadConfig`( config : Config, field : str,     type : Type )
  - `WAccess`
    - `Interval`( lo : Expr, hi : Expr )
    - `Point`( pt : Expr )
  - `FnArg` ( name : str, type : Type, memory : Optional[Memory] )
  - `Type`
    - `R`()
    - `f16`()
    - `f32`()
    - `f64`()
    - `i8`()
    - `i32`()
    - `bool`()
    - `int`()
    - `index`()
    - `size`()
    - `stride`()
    - `tensor`( hi : list[Expr], is_window : bool, type : Type )

## Code

### Infrastructure

In [3]:
def get_exo_proc() -> Procedure:
    @proc
    def generated_operation(
        In: i8[16, 16] @ DRAM,
        Weights: i8[16, 16] @ DRAM,
        Out: i8[16, 16] @ DRAM,
    ):
        for i in seq(0, 16):
            for j in seq(0, 16):
                for k in seq(0, 16):
                    Out[i, j] += In[i, k] * Weights[k, j]

    return generated_operation

In [4]:
def extract_c_code(proc: Procedure) -> str:
    return proc.c_code_str().split("#include <stdlib.h>")[-1]

In [5]:
evaluation_queue = PersistentQueue(persistence_file=QUEUE_FILE, overwrite=True)

In [6]:
PROC_AST: exo_ast.Proc = get_exo_proc().get_ast() # pyright: ignore[reportAssignmentType]

In [7]:
pprint(PROC_AST.body)

[For(name='i',
     lo=Const(val=0, type=int()),
     hi=Const(val=16, type=int()),
     body=[For(name='j',
               lo=Const(val=0, type=int()),
               hi=Const(val=16, type=int()),
               body=[For(name='k',
                         lo=Const(val=0, type=int()),
                         hi=Const(val=16, type=int()),
                         body=[Reduce(name='Out',
                                      lhs_type=i8(),
                                      idx=[Read(name='i', idx=[], type=index()),
                                           Read(name='j',
                                                idx=[],
                                                type=index())],
                                      rhs=BinOp(op='*',
                                                lhs=Read(name='In',
                                                         idx=[Read(name='i',
                                                                   idx=[],
                    

In [8]:
def extract_fors(mod: exo_ast.QueryAST, can_split: bool = True) -> list[exo_ast.For]:
    """Produce a list of the For nodes in the ast.Module in depth-wise order."""
    if not isinstance(mod, CONTROL_STRUCTURE):
        return [ ]

    match mod:
        case exo_ast.For(_):
            if can_split and len(mod.body) > BRANCHING_THRESHOLD:
                with mp.Pool() as pool:
                    return list(chain([mod], 
                        *pool.starmap(extract_fors, [(item, False) for item in mod.body])
                    ))
            return list(chain([mod], *[extract_fors(item, can_split) for item in mod.body]))
        case exo_ast.Proc(_) | exo_ast.If(_):
            if can_split and len(mod.body) > BRANCHING_THRESHOLD:
                with mp.Pool() as pool:
                    return list(chain(
                        *pool.starmap(extract_fors, [(item, False) for item in mod.body]
                    )))
            return list(chain(*[extract_fors(item, can_split) for item in mod.body]))
        case _:
            raise ValueError

extract_fors(PROC_AST)

[For(name='i', lo=Const(val=0, type=int()), hi=Const(val=16, type=int()), body=[For(name='j', lo=Const(val=0, type=int()), hi=Const(val=16, type=int()), body=[For(name='k', lo=Const(val=0, type=int()), hi=Const(val=16, type=int()), body=[Reduce(name='Out', lhs_type=i8(), idx=[Read(name='i', idx=[], type=index()), Read(name='j', idx=[], type=index())], rhs=BinOp(op='*', lhs=Read(name='In', idx=[Read(name='i', idx=[], type=index()), Read(name='k', idx=[], type=index())], type=i8()), rhs=Read(name='Weights', idx=[Read(name='k', idx=[], type=index()), Read(name='j', idx=[], type=index())], type=i8()), type=i8()))], is_par=False)], is_par=False)], is_par=False),
 For(name='j', lo=Const(val=0, type=int()), hi=Const(val=16, type=int()), body=[For(name='k', lo=Const(val=0, type=int()), hi=Const(val=16, type=int()), body=[Reduce(name='Out', lhs_type=i8(), idx=[Read(name='i', idx=[], type=index()), Read(name='j', idx=[], type=index())], rhs=BinOp(op='*', lhs=Read(name='In', idx=[Read(name='i', i

In [9]:
union = lambda iter_of_sets: reduce(lambda a, b: a | b, iter_of_sets, set())

In [10]:
def get_swappable_loops(ast: exo_ast.QueryAST, can_split: bool = True) -> set[tuple[str, str]]:
    """Return a set of all swappable for loops"""
    if not isinstance(ast, CONTROL_STRUCTURE):
        return set([])

    match ast:
        case exo_ast.For(name=name, body=[exo_ast.For(name=inner_name)]):
            return set([(name, inner_name)]) \
                 | get_swappable_loops(ast.body[0], can_split)
        case exo_ast.Proc(body=b) | exo_ast.If(body=b) | exo_ast.For(body=b):
            if can_split and len(b) > BRANCHING_THRESHOLD:
                with mp.Pool() as pool:
                    return union(pool.starmap(get_swappable_loops, 
                        [ (node, False) for node in b ] 
                    ))

            return union([
                get_swappable_loops(node, can_split)
                for node in b
            ])
        case _:
            raise ValueError

get_swappable_loops(PROC_AST)

{('i', 'j'), ('j', 'k')}

In [11]:
def find_innermost_loops(mod: exo_ast.QueryAST, can_split: bool = True) -> list[exo_ast.For]:
    if not isinstance(mod, CONTROL_STRUCTURE):
        return [ ]

    match mod:
        case exo_ast.For(body=[exo_ast.For(_)]) | exo_ast.Proc(_) | exo_ast.If(_):
            if can_split and len(mod.body) > BRANCHING_THRESHOLD:
                with mp.Pool() as pool:
                    return list(chain(
                        *pool.starmap(find_innermost_loops, [(item, False) for item in mod.body]
                    )))
            return list(chain(*[find_innermost_loops(item, can_split) for item in mod.body]))
        case exo_ast.For(_):
            return [mod]
        case _:
            raise ValueError

find_innermost_loops(PROC_AST)

[For(name='k', lo=Const(val=0, type=int()), hi=Const(val=16, type=int()), body=[Reduce(name='Out', lhs_type=i8(), idx=[Read(name='i', idx=[], type=index()), Read(name='j', idx=[], type=index())], rhs=BinOp(op='*', lhs=Read(name='In', idx=[Read(name='i', idx=[], type=index()), Read(name='k', idx=[], type=index())], type=i8()), rhs=Read(name='Weights', idx=[Read(name='k', idx=[], type=index()), Read(name='j', idx=[], type=index())], type=i8()), type=i8()))], is_par=False)]

In [12]:
from exo.stdlib.scheduling import (
    simplify,
    unroll_loop,
    reorder_loops,
    divide_loop,
    replace_all,
    
    stage_mem,
    set_memory,
)

class ScheduledProc:
    def __init__(self, exo_sched_func, args: list = [], kwargs: dict[str, Any] = {}):
        self.sched_func = exo_sched_func
        self.func_args = args
        self.func_kwargs = kwargs

    def _apply_transform(self, proc: Procedure) -> Procedure:
        func = self.sched_func
        # match self.sched_func:
        #     case "simplify":
        #         from exo.stdlib.scheduling import simplify
        #         func = simplify
        #     case "unroll_loop":
        #         from exo.stdlib.scheduling import unroll_loop
        #         func = unroll_loop
        #     case "reorder_loops":
        #         from exo.stdlib.scheduling import reorder_loops
        #         func = reorder_loops
        #     case "divide_loop":
        #         from exo.stdlib.scheduling import divide_loop
        #         func = divide_loop
        #     case "replace_all":
        #         from exo.stdlib.scheduling import replace_all
        #         func = replace_all
        #     case _:
        #         raise NotImplementedError

        # if func is None:
        #     raise ValueError
        return func(
            proc, 
            *self.func_args,
            **self.func_kwargs
        )

    def __repr__(self):
        return (
            f"{self.sched_func.__name__}(<proc>, "
          + ", ".join(map(str, self.func_args))
          + ", "
          + ", ".join([f"{key} = {val}" for key, val in self.func_kwargs.items()])
          + ")"
        )

    def __call__(self, proc: Procedure) -> Procedure:
        try:
            return self._apply_transform(proc)
        except SchedulingError as e:
            print(f"WARNING: Encountered {e}")
            return proc

### Transforms

In [20]:
p = ScheduledProc(unroll_loop, ["k"])(get_exo_proc())
p = ScheduledProc(unroll_loop, ['j'])(p)
# p = ScheduledProc("unroll_loop", ['i'])(p)
# find_innermost_loops(p.get_ast())
s = ScheduledProc(unroll_loop, ['i'])
# with concurrent.futures.ProcessPoolExecutor() as executor:
#     out = list(executor.map(s, [p]))
# out

In [15]:
# p = ScheduledProc(exo.API_scheduling.simplify, [])(p)x

In [16]:
def get_loop_reorders(ast: exo_ast.Proc) -> list[ScheduledProc]:
    loop_pairs = get_swappable_loops(ast)
    return [
        ScheduledProc(
            reorder_loops,
            [" ".join(list(loop_pair))], 
            {}
        ) for loop_pair in loop_pairs
    ]

get_loop_reorders(PROC_AST)

[reorder_loops(<proc>, j k, ), reorder_loops(<proc>, i j, )]

In [17]:
def get_loop_unrolls(ast: exo_ast.Proc) -> list[ScheduledProc]:

    lowest_loops = find_innermost_loops(ast)
    return [
        ScheduledProc(unroll_loop, [loop.name])
        for loop in lowest_loops
    ]

get_loop_unrolls(PROC_AST)

[unroll_loop(<proc>, k, )]

In [18]:
def get_loop_splits(ast: exo_ast.Proc) -> list[ScheduledProc]:
    lowest_loops = find_innermost_loops(ast)

    splits: list[ScheduledProc] = [ ]
    for loop in lowest_loops:
        if not (isinstance(loop.hi, exo_ast.Const) and isinstance(loop.lo, exo_ast.Const)):
            continue # we can't split loops without constant bounds

        for split_const in range(LOOP_SPLIT_SIZES, loop.hi.val - loop.lo.val, LOOP_SPLIT_SIZES):

            splits.append(
                ScheduledProc(
                    divide_loop,
                    [
                        loop.name,
                        split_const,
                        [f"{loop.name}_out", f"{loop.name}_in"]
                    ], {
                        "tail" : "cut",
                    })
            )

    return splits

get_loop_splits(PROC_AST)

[divide_loop(<proc>, k, 4, ['k_out', 'k_in'], tail = cut),
 divide_loop(<proc>, k, 8, ['k_out', 'k_in'], tail = cut),
 divide_loop(<proc>, k, 12, ['k_out', 'k_in'], tail = cut)]

In [19]:
from exo.platforms.gemmini import (
    do_ld_i8_block_id1,
    do_ld_i8_vector,
    do_ld_i8,

    do_matmul_i8,

    st_acc_i32,
    do_st_acc_i8,

    config_ld_i8,
    config_ld_i8_id1,
    config_matmul,

)

from exo.API import check_eqv_proc

def find_eq_gemmini_subtrees(ast: exo_ast.QueryAST) -> list[exo_ast.QueryAST]:
    ...

def get_gemmini_insts(_: exo_ast.Proc) -> list[ScheduledProc]:
    return [
        ScheduledProc(replace_all, [gemmini_inst])
        for gemmini_inst in [
            do_ld_i8_block_id1,
            do_matmul_i8

        ]
    ]

ALL_INSTS = [
    "acc_scale",
] + [ # i8 loading config insts
    "config_ld_i8" + suff
    for suff in [
        "", "_id1", "_id2", "_s2_id1"
    ]
] + [ # i8 loads
    "do_ld_i8" + suff # ld_i8_prototype
    for suff in [
        "", "_id1", "_id2", "_s2_id1", 
        "_block_id1", 
        "_block_id2", 
        "_vector",
    ]
] + [ # i8 block loads (not "do")
    "ld_i8_block" + suff
    for suff in [
        "",
        "_id1", "_id1_v2", "_id1_s2_v2",
        "_id2", "_id2_v2", "_id2_s2_v2",
    ]
] + [ # remaining i8 loads (not "do")
    "ld_i8" + suff # this one is the same as do_ld_i8, but with a pass statement
    for suff in [
        "", "_v2", "_s2",
        "_id1", "_id1_v2", "_id1_s2_v2",
        "_id2", "_id2_v2", "_id2_s2_v2",
        "_vector",
    ]
] + [
    "do_ld_acc_i32",
    "do_ld_acc_i32_vector",

    "config_ld_acc_i32_vector",
    
    "ld_acc_i8",
    "ld_acc_i32",
    "ld_acc_i32_vector",
    "ld_acc_i32_vector_v2",


    "zero_block_id2",
    
    "st_i8",

    "clamp",

    "config_st_acc_i8",
    "do_st_acc_i8",

    "st_acc_i8",
    "st_acc_i8_v2",
    "st_acc_i8_s2_v2",
    
    "st_acc_i32",
    
    "config_zero",

    "do_zero_i8",
    "do_zero_i8_vector",

    "do_zero_acc_i32",

    "zero_acc_i32_v2",
    "zero_acc_i32",
    
    
    "zero_i8",
    "zero_i8_v2",
    "zero_i8_vector",
    "zero_i8_vector_v2",
    

    "config_matmul",

    "do_matmul_i8",
    # "matmul_i8_v2",
    # "matmul_i8",

    # "do_matmul_acc_i8",
    # "matmul_acc_i8_v2",
    # "matmul_acc_i8",
]

ImportError: cannot import name 'stride' from 'exo' (/workspaces/project/chipyard/.conda-env/lib/python3.10/site-packages/exo/__init__.py)

In [None]:
def get_loop_fusions(ast: exo_ast.Proc):
    ...

In [None]:
def get_all_exo_tsfs(proc: Procedure, include_simplify: bool = True, debug: bool = False):
    root_ast: exo_ast.Proc = proc.get_ast() # pyright: ignore[reportAssignmentType]
    tsfs: list[ScheduledProc] = [ ]

    for expansion in [
        get_loop_reorders,
        get_loop_unrolls,
        # get_loop_splits,
        get_gemmini_insts,
    ]:
        if debug:
            print(f"\t> Doing expansion: {expansion.__name__}")

        new_tsfs = expansion(root_ast)

        if debug:
            print(f"\t  Found {len(new_tsfs)}")

        tsfs.extend(new_tsfs)

    if include_simplify:
        tsfs.append(ScheduledProc(simplify))

    return tsfs

p = get_exo_proc()
tsf = get_all_exo_tsfs(get_exo_proc())[0]
tsf(p)

```python
def generated_operation(In: i8[16, 16] @ DRAM, Weights: i8[16, 16] @ DRAM,
                        Out: i8[16, 16] @ DRAM):
    for i in seq(0, 16):
        for k in seq(0, 16):
            for j in seq(0, 16):
                Out[i, j] += In[i, k] * Weights[k, j]
```

In [None]:
def apply_scheduled_proc(sp: ScheduledProc, p: Procedure):
    return sp(p), p

@profileit(do_print=True)
def search_procs(proc: Procedure, debug: bool = False):
    og_c_code = extract_c_code(proc)
    evaluation_queue.put(og_c_code)
    seen_procs = {
        og_c_code : proc
    }

    def deeper(p: Procedure, can_split: bool = False, can_use_simplify: bool = False, debug_prefix: str = ""):
        nonlocal seen_procs

        all_tsfs = get_all_exo_tsfs(
            p, include_simplify=can_use_simplify
        )

        if can_split and len(all_tsfs) > 1:
            print("splitting")
            with concurrent.futures.ProcessPoolExecutor() as executor:
                futures = [
                    executor.submit(apply_scheduled_proc, t, p)
                    for t in all_tsfs
                ]
                new_procs = [
                    f.result() # the un-pickle-ing of the return values doesn't work:
                               #    - When pickling, the symbols get packed in as they 
                               #        understand themselves (i.e. <thing>.__getstate__)
                               #    - The problem is that old objects don't expect to be a 
                               #        part of a more complex construct and so return a
                               #        __main__.<thing> value for that
                               #    - However, for us, it's often more complicated -- including
                               #        imports and submodules, e.g. exo.<thing> due to exo 
                               #        importing <thing> all the way up to its root __init__
                               #    - So when pickle tries to write to what <thing> thinks it is
                               #        and finds that that name doesn't actually exist in our
                               #        namespace, it errors
                    for f in futures
                ]
                # new_procs = list( # this is a much shorter version of the above (if it worked lol)
                #     executor.map(apply_scheduled_proc, [[t, p] for t in all_tsfs])
                # )
        else:
            new_procs = [
                (tsf(p), tsf) for tsf in 
                all_tsfs
            ]
        if debug: print(f"{debug_prefix}Found {len(new_procs)} potential procs")
        actually_new_procs = []

        # here is where we would add the secret sauce:
        #new_procs = sort_by_fancy_heuristic(new_procs)
        
        for new_proc, tsf in new_procs:

            c_code = extract_c_code(new_proc)
        
            if c_code in seen_procs.keys():
                continue

            seen_procs[c_code] = new_proc
            evaluation_queue.put(c_code)
            actually_new_procs.append((new_proc, tsf))
        
        if debug: print(f"{debug_prefix}branching into {len(actually_new_procs)} new procs")

        for new_proc, tsf in actually_new_procs:
            deeper(
                new_proc,
                debug_prefix = debug_prefix+"\t"
            )
    
    deeper(proc)

    return list(seen_procs.values())

procs = search_procs(get_exo_proc())
procs, len(procs)

         67650855 function calls (65360679 primitive calls) in 28.773 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   30.823   30.823 1352577613.py:4(search_procs)
     24/1    0.031    0.001   30.807   30.807 1352577613.py:12(deeper)
       37    0.000    0.000   14.626    0.395 3974458107.py:1(extract_c_code)
       37    0.007    0.000   14.625    0.395 API.py:339(c_code_str)
       37    0.005    0.000   14.619    0.395 LoopIR_compiler.py:363(compile_to_strings)
105338/101750    0.215    0.000    8.510    0.000 LoopIR.py:682(map_s)
       24    0.000    0.000    8.227    0.343 1352577613.py:34(<listcomp>)
       36    0.000    0.000    8.227    0.229 3531590211.py:45(__call__)
       36    0.001    0.000    8.227    0.229 3531590211.py:7(_apply_transform)
       36    0.002    0.000    8.225    0.228 API_scheduling.py:79(__call__)
       37    0.012    0.000    8.145    0.220 LoopIR_compi

([<exo.API.Procedure at 0x7fffbf90a110>,
  <exo.API.Procedure at 0x7fffbe82cca0>,
  <exo.API.Procedure at 0x7fffd138d180>,
  <exo.API.Procedure at 0x7fffd138cf70>,
  <exo.API.Procedure at 0x7fffd127b580>,
  <exo.API.Procedure at 0x7fffd11115a0>,
  <exo.API.Procedure at 0x7fffd119c670>,
  <exo.API.Procedure at 0x7fffd102c850>,
  <exo.API.Procedure at 0x7fffd1083e50>,
  <exo.API.Procedure at 0x7fffd0f3dae0>,
  <exo.API.Procedure at 0x7fffd0e27b20>,
  <exo.API.Procedure at 0x7fffd0cffdc0>,
  <exo.API.Procedure at 0x7fffd0ea60e0>,
  <exo.API.Procedure at 0x7fffd0cfd780>,
  <exo.API.Procedure at 0x7fffd0c6c580>,
  <exo.API.Procedure at 0x7fffd0db9990>,
  <exo.API.Procedure at 0x7fffd1012560>,
  <exo.API.Procedure at 0x7fffd00bf910>,
  <exo.API.Procedure at 0x7fffbe838be0>,
  <exo.API.Procedure at 0x7fffd0c6eda0>,
  <exo.API.Procedure at 0x7fffd0ccc850>,
  <exo.API.Procedure at 0x7fffbe00ffd0>,
  <exo.API.Procedure at 0x7fffd0db9090>,
  <exo.API.Procedure at 0x7fffd10dbd00>],
 24)

# Exo's scheduling API


## Top-level Python function decorator

1. `@proc` - decorates a Python function which is parsed and compiled as Exo. Replaces
   the function with a `Procedure` object.
2. `@instr` - same as `@proc`, but accepts a hardware instruction as a format string.
3. `@config` - decorates a Python class which is parsed and compiled as an Exo
   configuration object



## Procedure object methods



**Introspection operations**

- `.name()` returns the procedure name.
- `.check_effects()` forces Exo to run effect checking on the procedure.
- `.show_effects()` prints the effects of the procedure.
- `.show_effect(stmt)` prints the effect of the `stmt` in the procedure.
- `.is_instr()` returns `true` if the procedure has a hardware instruction string.
- `.get_instr()` returns the hardware instruction string.
- `.get_ast()` returns a `QAST`, which is an AST representation suitable for
  introspection.



**Execution / interpretation operations**

- `.compile_c(directory, filename)` compiles the procedure into C and stores
  in `filename` in the `directory`.
- `.interpret(**args)` runs Exo interpreter on the procedure.



## Scheduling operations on Procedure objects



### Loop related operations



- `.split(loop, split_const, iter_vars, tail='guard', perfect=False)`
  - Splits `loop` into an outer and an inner loop. The inner loop bound is `split_const` and the outer and inner loop names are specified by a list of strings `iter_vars`. If `perfect` is True, it will not introduce a tail case. `tail` specifies the tail strategies, where the options are `guard`, `cut`, and `cut_and_guard`.
- `.fuse_loop(loop1, loop2)`
  - Fuses two adjacent loops with a common iteration variable.
- `.partition_loop(loop, num)`
  - Partitions `loop` into two loops, the first running between `0` and `num` and the second between `num+1` and `loop`'s original bound.
- `.reorder(loop1, loop2)`
  - Reorders two nested loops. `loop2` should be nested directly inside `loop1`. `loop1` will be nested inside `loop2` after this operation.
- `.unroll(loop)`
  - Unrolls the loop. The loop needs to have a constant bound.
- `.fission_after(stmt, n_lifts=1)`
  - Fissions the `n_lifts` number of loops around the `stmt`. The fissioned loops around the `stmt` need to be directly nested with each other and the statements before and after the `stmt` should not have any allocation dependencies.
- `.remove_loop(loop)`
  - Replaces the loop with its body if the body is idempotent. The system must be able to prove that the loop runs at least once.



### Buffer related operations



- `.reuse_buffer(buf1, buf2` 
  - Reuses a buffer `buf1` in the use site of `buf2` and removes the allocation of `buf2`
- `.inline_window(win_stmt)`
  - Removes the window statement `win_stmt`, which is an alias to the window, and inlines the windowing in its use site
- `.expand_dim(stmt, alloc_dim, indexing)`
  - Expands the dimension of the allocation statement `stmt` with dimension `alloc_dim` of indexing `indexing`
- `.bind_expr(new_name, expr)`
  - Binds the right hand side expression `expr` to a newly allocated buffer named `new_name`
- `.stage_mem(win_expr, new_name, stmt_start, stmt_end=None)`
  - Stages the buffer `win_expr` to the new window expression `new_name` in statement block (`stmt_start` to `stmt_end`), and adds an initialization loop and a write-back loop
- `.rearrange_dim(alloc, dimensions)`
  - Takes an allocation statement and a list of integers to map the dimension. It rearranges the dimensions of `alloc` in `dimension` order. E.g., if `alloc` were `foo[N,M,K]` and the `dimension` were `[2,0,1]`, it would become `foo[K,N,M]` after this operation.
- `.lift_alloc(alloc, n_lifts=1, keep_dims=False)`
  - Lifts the allocation statement `alloc` out of `n_lifts` number of scopes. If and For statements are the only statements in Exo which introduce a scope. When lifting the allocation out of a for loop, it will expand its dimension to the loop bound if `keep_dims` is True.



### Config related operations



- `.bind_config(expr, config, field)`
  - Binds the right hand side `expr` to `config.field`. It will replace the use site of `expr` with `config.field` and introduces a config statement of `config.field = expr`.
- `.configwrite_root(config, field, expr)`
  - Inserts the config statement `config.field = expr` in the beginning of the procedure.
- `.configwrite_after(stmt, config, field, expr)`
  - Inserts the config statement `config.field = expr` after `stmt`.
- `.delete_config(stmt)`
  - Deletes the configuration statement.



### Other scheduling operations



- `.add_assertion(assertion)`
  - Asserts the truth of the expression `assertion` at the beginning of the procedure.
- `.lift_if(if, n_lifts=1)`
  - Lifts the if statement `if` out of `n_lifts` number of scopes. This is similar to `reorder()`, but for if statements.
- `.eliminate_dead_code(stmt)`
  - Eliminates `if` statement if condition is always True or False. Eliminates `for` statement if condition is always False.
- `.delete_pass()`
  - Deletes a `Pass` statement in the procedure.
- `.reorder_stmts(stmt1, stmt2)`
  - Reorder two adjacent statements `stmt1` and `stmt2`. After this operation, the order will be `stmt2` `stmt1`.
  - `.reorder_before(stmt)`
    - Move the statement `stmt` before the previous statement. This is a shorthand for `reorder_stmts()`.
- `.replace(subproc, stmt)`
  - Replace the statement with a call to `subproc`. This operation is one of our contributions and is explained in detail in the paper.
- `.replace_all(subproc)`
  - Eagerly replace every matching statement with a call to `subproc`.
- `.inline(call_site)`
  - Inline the function call.
- `.is_eq(another_proc)`
  - Returns True if `another_proc` is equivalent to the procedure.
- `.call_eqv(eqv_proc, call_site)`
  - Replace the function call statement of `call_site` with a call to an equivalent procedure `eqv_proc`.
- `.repeat(directive, *args)`
  - Continue to run the directive until it fails. The directive and its arguments are given separately, e.g. `proc.repeat(Procedure.inline, "proc_to_inline(_)")`
- `.simplify()`
  - Simplify the code in the procedure body. Tries to reduce expressions to constants and eliminate dead branches and loops. Uses branch conditions to simplify expressions inside the branches.
- `.rename(new_name)`
  - Rename this procedure to `new_name`.
- `.make_instr(instr_string)`
  - Converts this procedure to an instruction procedure with instruction `instr_string`.
- `.partial_eval(*args, **kwargs)`
  - Specializes this procedure to the given argument values.
- `.set_precision(name, type)`
  - Sets the precision type of `name` to `type`.
- `.set_window(name, is_window)`
  - If `is_window` is True, it sets the buffer `name` to window type, instead of a tensor type.
- `.set_memory(name, mem_type)`
  - Sets a buffer `name`'s memory type to `mem_type`.
