In [1]:
from __future__ import annotations

from typing import *

from exo import proc, DRAM, Memory
from exo import Procedure

import exo.stdlib.scheduling as s
import exo.platforms.rvv as r

In [2]:
def get_exo_proc() -> Procedure:
    """Build the unscheduled Exo procedure for a 16x16 f32 matmul-accumulate.

    Returns:
        The ``generated_operation`` procedure. Over all ``i, j, k`` in
        ``[0, 16)`` it performs ``Out[i, j] += In[i, k] * Weights[k, j]``.
        The procedure only accumulates into ``Out``; it never zero-initialises
        it. All three buffers are declared ``f32[16, 16] @ DRAM``.
    """
    # NOTE(review): `seq` and `f32` are not imported as Python names in this
    # notebook; presumably the @proc decorator interprets the body itself.
    @proc
    def generated_operation(
        In: f32[16, 16] @ DRAM,
        Weights: f32[16, 16] @ DRAM,
        Out: f32[16, 16] @ DRAM,
    ):
        for i in seq(0, 16):
            for j in seq(0, 16):
                for k in seq(0, 16):
                    Out[i, j] += In[i, k] * Weights[k, j]

    return generated_operation

# Notebook-wide handle on the unscheduled kernel; the scheduling cell below
# renames and transforms a copy of it.
PROC = get_exo_proc()

In [3]:
# Reference only (never executed): instruction names from exo.platforms.rvv,
# reachable in this notebook as r.<name> through the `import ... as r` above.
# from exo.platforms.rvv import (
#     RVV,
#     rvv_broadcast_4xf32,    # send reg to 4x
#     rvv_broadcast_4xf32_0,  # send 0 to 4x
#     rvv_broadcast_4xf32_scalar,  # send [0] to 4x
#     rvv_vfmacc_1xf32_4xf32,  # macc 1x4
#     rvv_vfmacc_4xf32_1xf32,  # macc 4x1
#     rvv_vfmacc_4xf32_4xf32,  # macc 4x4
#     rvv_vld_4xf32,            # load 4x
#     rvv_vst_4xf32             # store 4x
# )

In [31]:
import re
def move_symbol_to_mem_for_block(
        proc: Procedure,
        block_cursor,
        symbol: str,
        window: str,
        new_mem: type[Memory],
        post_process: Callable[[Procedure], Procedure]
    ) -> Procedure:
    """Stage a window of ``symbol`` into ``new_mem`` and reshape it to ``[_, 4]``.

    Steps, in order:
      1. ``stage_mem`` the window ``f"{symbol}{window}"`` around
         ``block_cursor`` into a new buffer named ``f"{symbol}Staged"``.
      2. Move that buffer to ``new_mem`` and simplify.
      3. Read the staged buffer's constant shape back out of the printed
         procedure (the allocation line looks like ``NameStaged: f32[16, 16]``).
      4. Split the trailing dimension by 4 and merge all leading dimensions,
         so the buffer ends up with shape ``[_, 4]``.
      5. Call ``unroll_buffer`` on dimension 0 of the staged buffer.

    Args:
        proc: Procedure to schedule.
        block_cursor: Cursor or pattern string selecting the block to stage
            around; forwarded unchanged to ``stage_mem``.
        symbol: Name of the buffer to stage.
        window: Window expression appended to ``symbol``, e.g. ``"[0:16, 0:16]"``.
        new_mem: Memory class the staged buffer is placed in.
        post_process: Hook meant to run on the result (e.g. instruction
            replacement). It is accepted for interface stability but is
            currently NOT invoked -- see the note at the end of the body.

    Returns:
        The rescheduled procedure.

    Raises:
        ValueError: If the staged allocation cannot be found in the printed
            procedure, or if its shape is not made of integer literals.
    """
    lanes = 4  # trailing-dimension width the staged buffer is reshaped to
    new_symbol = f"{symbol}Staged"
    p = s.stage_mem(proc, block_cursor, f"{symbol}{window}", new_symbol)
    p = s.set_memory(p, new_symbol, new_mem)
    p = s.simplify(p)

    raw_src_code = str(p)

    # Locate the allocation of the staged buffer. A raw string keeps `\[` a
    # regex escape rather than an (invalid) Python string escape, the symbol
    # is escaped so it is matched literally, and `[^\]]*` stops at the first
    # closing bracket instead of greedily running to the last one on the line.
    ans = re.search(
        rf'\b{re.escape(new_symbol)}: f32\[([^\]]*)\]', raw_src_code
    )
    if ans is None:
        raise ValueError(
            f"could not find an f32 allocation of {new_symbol!r} in:\n"
            f"{raw_src_code}"
        )

    shape_text = ans.group(1)
    try:
        dims = [int(d) for d in shape_text.split(",")]  # int() tolerates spaces
    except ValueError as err:
        raise ValueError(
            f"{new_symbol!r} has a non-constant shape [{shape_text}]; "
            "only integer-literal dimensions are supported"
        ) from err

    # Make the trailing dimension `lanes` wide: [..., d] -> [..., d / lanes, lanes].
    p = s.divide_dim(p, f"{new_symbol} : _", len(dims) - 1, lanes)
    # Fold every leading dimension into one so the final shape is [_, lanes].
    # After divide_dim there are len(dims) + 1 dimensions, hence len(dims) - 1 merges.
    for _ in range(len(dims) - 1):
        p = s.mult_dim(p, new_symbol, 0, 1)

    p = s.unroll_buffer(p, new_symbol, 0)

    # NOTE: `post_process` is intentionally not applied yet. The planned final
    # step is `p = post_process(p)` (e.g. swapping in r.rvv_vld_4xf32 /
    # r.rvv_vst_4xf32), to be enabled once the staged buffer has the right shape.
    return p

def replace_rvv_ld_st(proc: Procedure) -> Procedure:
    """Apply ``replace_all`` with the RVV 4xf32 load, then the 4xf32 store.

    Args:
        proc: Procedure to rewrite.

    Returns:
        The procedure after both ``replace_all`` passes, load first.
    """
    with_loads = s.replace_all(proc, r.rvv_vld_4xf32)
    with_stores = s.replace_all(with_loads, r.rvv_vst_4xf32)
    return with_stores

In [32]:
# Work on a renamed copy of the reference kernel.
p = s.rename(PROC, "matmul")

# Stage the full `In` matrix into RVV memory around the outer i-loop.
p = move_symbol_to_mem_for_block(
    p,
    block_cursor="for i in _ : _",
    symbol="In",
    window="[0:16, 0:16]",
    new_mem=r.RVV,
    post_process=replace_rvv_ld_st,
)

# Staging `Out` is switched off for now: range(0) yields no iterations, so
# this body never runs. Change the count to 1 to enable it.
for _ in range(0):
    p = move_symbol_to_mem_for_block(
        p,
        block_cursor="for i in _ : _",
        symbol="Out",
        window="[0:16, 0:16]",
        new_mem=r.RVV,
        post_process=replace_rvv_ld_st,
    )

p = s.simplify(p)
print(p)

64
def matmul(In: f32[16, 16] @ DRAM, Weights: f32[16, 16] @ DRAM,
           Out: f32[16, 16] @ DRAM):
    InStaged: f32[64, 4] @ RVV
    for i0 in seq(0, 16):
        for i1 in seq(0, 16):
            InStaged[4 * i0 + i1 / 4, i1 % 4] = In[i0, i1]
    for i in seq(0, 16):
        for j in seq(0, 16):
            for k in seq(0, 16):
                Out[i, j] += InStaged[4 * i + k / 4, k % 4] * Weights[k, j]


In [21]:
import re
# Scratch cell: capture the printed form of the scheduled procedure and the
# name of the staged buffer for ad-hoc inspection.
# NOTE(review): the saved output under this cell was produced by code that is
# no longer present here -- re-run or clear it.
raw_src_code = str(p)
sym = "InStaged"



['16', '16']

In [6]:
def extract_c_code(proc: Procedure) -> str:
    """Return the generated C of ``proc`` that follows ``#include <stdlib.h>``.

    Everything up to and including the last occurrence of the include line is
    dropped. If the marker does not occur, the whole generated source is
    returned unchanged.
    """
    marker = "#include <stdlib.h>"
    _head, _sep, tail = proc.c_code_str().rpartition(marker)
    return tail
# Show the C generated for the scheduled procedure `p`.
# NOTE(review): the output recorded under this cell is a MemGenError from an
# earlier run (execution count 6); re-run after the scheduling cell to refresh.
extract_c_code(p)

MemGenError: /tmp/ipykernel_54837/4270879120.py:8:8: RVV vectors of type float must be 4-wide, got ['16', '16']