In [103]:
from __future__ import annotations
from exo import proc, DRAM#, i8, i32, f32
from exo.platforms.gemmini import GEMM_ACCUM, GEMM_SCRATCH
from exo.platforms.gemmini import(
    ld_i8,
    ld_i8_vector,
    ld_i8_block,
    matmul_i8,
    st_i8,
    zero_i8,
    zero_acc_i32,
    ld_acc_i8,
    ld_acc_i32,
    st_acc_i8,
    st_acc_i32,
)

In [104]:
def get_exo_proc() -> Procedure:
    """Build and return the Exo procedure that the rest of the notebook schedules.

    ``generated_operation`` multiplies ``In`` (``i8[16, 19]``) by ``Weights``
    (``i8[19, 16]``) and reduces the products into ``Out`` (``i8[16, 16]``);
    all three tensors live in ``DRAM``. ``Out`` is only ever updated with
    ``+=`` and is never zeroed here, so the product is added on top of whatever
    ``Out`` already contains.

    The loop nest is written in ``i, j, k`` order. Later cells select loops by
    these names (``"j k"``, ``"i k"``, ``"for i in _ : _"``), so the iterator
    names are part of the contract with the scheduling code.

    Returns:
        The object produced by applying ``@proc`` to ``generated_operation``.
    """
    # NOTE(review): `Procedure` is not imported in the visible cells (only
    # `proc` and `DRAM` come from `exo`); the return annotation presumably
    # stays harmless because `from __future__ import annotations` leaves it
    # unevaluated — confirm.
    @proc
    def generated_operation(
        In: i8[16, 19] @ DRAM,
        Weights: i8[19, 16] @ DRAM,
        Out: i8[16, 16] @ DRAM,
    ):
        for i in seq(0, 16):
            for j in seq(0, 16):
                for k in seq(0, 19):
                    Out[i, j] += In[i, k] * Weights[k, j]

    return generated_operation

In [105]:
# Unscheduled reference procedure; later cells derive every schedule `p` from
# it via rename(PROC, ...) rather than modifying it.
PROC = get_exo_proc()

In [106]:
from exo.stdlib.scheduling import (
    stage_mem,
    rename,
    reorder_loops,
    simplify,
    set_memory,
    replace_all,
    set_precision,
    #
)

In [107]:
# Start from the pristine PROC, give the copy its own name, then apply the two
# loop swaps "j k" followed by "i k". The rendered result below shows the
# resulting nest order k, i, j.
p = reorder_loops(
    reorder_loops(rename(PROC, "testing_proc"), "j k"),
    "i k",
)
p

```python
def testing_proc(In: i8[16, 19] @ DRAM, Weights: i8[19, 16] @ DRAM,
                 Out: i8[16, 16] @ DRAM):
    for k in seq(0, 19):
        for i in seq(0, 16):
            for j in seq(0, 16):
                Out[i, j] += In[i, k] * Weights[k, j]
```

In [114]:
import exo.query_asts as exo_ast
def get_presc(node: exo_ast.Read | exo_ast.Reduce | exo_ast.Alloc | exo_ast.Assign | exo_ast.Const):
    """Return the ``type`` attribute of a supported query-AST node.

    Args:
        node: An instance of ``exo_ast.Read``, ``exo_ast.Reduce``,
            ``exo_ast.Alloc``, ``exo_ast.Assign`` or ``exo_ast.Const``.

    Returns:
        ``node.type``.

    Raises:
        NotImplementedError: ``node`` is not an instance of one of the five
            classes above, or it is but exposes no ``type`` attribute. The
            offending node is passed as the exception argument.
    """
    # NOTE(review): whether every one of these classes really carries a field
    # named `type` is not visible from this notebook — verify against
    # exo.query_asts.
    supported_nodes = (
        exo_ast.Read,
        exo_ast.Reduce,
        exo_ast.Alloc,
        exo_ast.Assign,
        exo_ast.Const,
    )
    if isinstance(node, supported_nodes) and hasattr(node, "type"):
        return node.type
    raise NotImplementedError(node)

In [127]:
from exo.pattern_match import match_pattern
# obj: exo_ast.Proc = PROC.get_ast()
# match_pattern(
#     PROC.body()._impl,
#     "Out"
# )

Block(_root=LoopIR.proc(name='generated_operation', args=[LoopIR.fnarg(name=In_11680, type=LoopIR.Tensor(hi=[LoopIR.Const(val=16, type=LoopIR.Int(), srcinfo=<exo.prelude.SrcInfo object at 0x7fffbc563df0>), LoopIR.Const(val=19, type=LoopIR.Int(), srcinfo=<exo.prelude.SrcInfo object at 0x7fffbc560250>)], is_window=False, type=LoopIR.INT8()), mem=<class 'exo.memory.DRAM'>, srcinfo=<exo.prelude.SrcInfo object at 0x7fffbf7e4700>), LoopIR.fnarg(name=Weights_11681, type=LoopIR.Tensor(hi=[LoopIR.Const(val=19, type=LoopIR.Int(), srcinfo=<exo.prelude.SrcInfo object at 0x7fffbc348610>), LoopIR.Const(val=16, type=LoopIR.Int(), srcinfo=<exo.prelude.SrcInfo object at 0x7fffbc2b1cf0>)], is_window=False, type=LoopIR.INT8()), mem=<class 'exo.memory.DRAM'>, srcinfo=<exo.prelude.SrcInfo object at 0x7fffbc348040>), LoopIR.fnarg(name=Out_11682, type=LoopIR.Tensor(hi=[LoopIR.Const(val=16, type=LoopIR.Int(), srcinfo=<exo.prelude.SrcInfo object at 0x7fffbd81c940>), LoopIR.Const(val=16, type=LoopIR.Int(), srci

In [111]:
from exo import *
from exo.stdlib.scheduling import *
def move_mem_to_gemmini(proc, block_cursor, buff_name: str, window: str, accum: bool = False):
    """Stage a window of a buffer into Gemmini memory and map the copies to instructions.

    Steps, in order:

    1. ``stage_mem`` stages ``f"{buff_name}{window}"`` into a new buffer named
       ``f"{buff_name}Staged"`` around ``block_cursor``, forwarding ``accum``.
    2. ``set_memory`` places the staged buffer in ``GEMM_ACCUM`` when ``accum``
       is true, otherwise in ``GEMM_SCRATCH``.
    3. ``set_precision`` sets the staged buffer to ``"i32"`` when ``accum`` is
       true, otherwise ``"i8"``.
    4. ``replace_all`` is applied once per Gemmini zero / load / store
       instruction, in a fixed order.

    Args:
        proc: Procedure to reschedule.
        block_cursor: Forwarded unchanged to ``stage_mem`` as the block to
            stage around; the caller in this notebook passes a loop pattern
            string such as ``"for i in _ : _"``.
        buff_name: Name of the buffer to stage.
        window: Window expression appended directly to ``buff_name``, e.g.
            ``"[0:16, 0:16]"``.
        accum: Selects the accumulator memory / ``i32`` precision and is passed
            through to ``stage_mem``.

    Returns:
        The rescheduled procedure.
    """
    staged_name = f"{buff_name}Staged"
    if accum:
        target_memory, target_precision = GEMM_ACCUM, "i32"
    else:
        target_memory, target_precision = GEMM_SCRATCH, "i8"

    # Introduce the staged copy around the requested block.
    proc = stage_mem(proc, block_cursor, f"{buff_name}{window}", staged_name, accum=accum)

    # Move the staged copy into Gemmini memory with the matching precision.
    proc = set_memory(proc, staged_name, target_memory)
    proc = set_precision(proc, staged_name, target_precision)

    # Offer every zero / load / store instruction to replace_all, keeping the
    # order zero -> load -> store.
    gemmini_instructions = (
        zero_i8,
        zero_acc_i32,
        ld_i8,
        ld_acc_i8,
        ld_acc_i32,
        st_i8,
        st_acc_i8,
        st_acc_i32,
    )
    for instruction in gemmini_instructions:
        proc = replace_all(proc, instruction)

    return proc

# Rebuild the schedule from the pristine PROC so this cell does not depend on
# whatever `p` held before: rename, then swap "j k" and "i k".
p = reorder_loops(
    reorder_loops(rename(PROC, "testing_proc"), "j k"),
    "i k",
)

# Stage the whole of Out around the loop named by LOOP_NAME, as an accumulator.
LOOP_NAME = "i"
p = move_mem_to_gemmini(
    proc=p,
    block_cursor=f"for {LOOP_NAME} in _ : _",
    buff_name="Out",
    window="[0:16, 0:16]",
    accum=True,
)

p

In [113]:
# LOOP_NAME = "k"
# p = move_mem_to_gemmini(p,
#     f"for {LOOP_NAME} in _ : _",
#     "In",
#     "[0:16, 0:19]",
#     accum=False,
# )

# LOOP_NAME = "k"
# p = move_mem_to_gemmini(p,
#     f"for {LOOP_NAME} in _ : _",
#     "Weights",
#     "[0:19, 0:16]",
#     accum=False,
# )

# Disabled scratch version of the staging steps: the `if False:` guard means
# nothing inside it ever executes.
# NOTE(review): BUFF_NAME and MEMORY are not defined in any visible cell, so
# this branch would presumably raise NameError if the guard were flipped —
# confirm, or delete the block.
if False:
    p = stage_mem(p, 
        f"for {LOOP_NAME} in _ : _",   # Loop marker
        f"{BUFF_NAME}[0:16, 0:16]",    # Buffer to stage
        f"{BUFF_NAME}Staged",          # NewBuff to stage *into*
        accum=True
    )

    #              NewBuff v     NewMem v
    p = set_memory(p, 
        f"{BUFF_NAME}Staged", # NewBuff
        MEMORY                # NewMem
    )


    # p = replace_all(p, ld_i8)
    # p = replace_all(p, st_i8)
# Display the schedule held in `p`; this line sits outside the disabled branch.
p 

```python
def testing_proc(In: i8[16, 19] @ DRAM, Weights: i8[19, 16] @ DRAM,
                 Out: i8[16, 16] @ DRAM):
    for k in seq(0, 19):
        OutStaged: i32[16 - 0, 16 - 0] @ GEMM_ACCUM
        zero_acc_i32(16, 16, OutStaged[0:16, 0:16])
        for i in seq(0, 16):
            for j in seq(0, 16):
                OutStaged[i - 0, j - 0] += In[i, k] * Weights[k, j]
        for i0 in seq(0, 16 - 0):
            for i1 in seq(0, 16 - 0):
                Out[i0 + 0, i1 + 0] += OutStaged[i0, i1]
```

In [None]:
# p = replace_all(p, ld_i8_vector)
# p = replace_all(p, ld_i8_block)

In [None]:
# Display the current schedule held in `p`.
p

```python
def testing_proc(In: i8[16, 19] @ DRAM, Weights: i8[19, 16] @ DRAM,
                 Out: i8[16, 16] @ DRAM):
    for k in seq(0, 19):
        OutStaged: i8[16 - 0, 16 - 0] @ GEMM_ACCUM
        zero_acc_i32(16, 16, OutStaged[0:16, 0:16])
        for i in seq(0, 16):
            for j in seq(0, 16):
                OutStaged[i - 0, j - 0] += In[i, k] * Weights[k, j]
        for i0 in seq(0, 16 - 0):
            for i1 in seq(0, 16 - 0):
                Out[i0 + 0, i1 + 0] += OutStaged[i0, i1]
```

In [None]:
# Apply replace_all with the Gemmini matmul_i8 instruction to the schedule.
# NOTE(review): the output recorded after this cell still shows the plain
# i/j loop nest, so the instruction apparently did not match — confirm.
p = replace_all(p, matmul_i8)

In [None]:
# Display the schedule after the matmul_i8 replacement attempt.
p

```python
def testing_proc(In: i8[16, 19] @ DRAM, Weights: i8[19, 16] @ DRAM,
                 Out: i8[16, 16] @ DRAM):
    for k in seq(0, 19):
        OutStaged: i8[16 - 0, 16 - 0] @ GEMM_ACCUM
        zero_acc_i32(16, 16, OutStaged[0:16, 0:16])
        for i in seq(0, 16):
            for j in seq(0, 16):
                OutStaged[i - 0, j - 0] += In[i, k] * Weights[k, j]
        for i0 in seq(0, 16 - 0):
            for i1 in seq(0, 16 - 0):
                Out[i0 + 0, i1 + 0] += OutStaged[i0, i1]
```