In [1]:
from __future__ import annotations
import os
import sys
from exo import proc
from exo.libs.memories import *
from exo.platforms.x86 import *
from exo.stdlib.scheduling import *
from exo.stdlib.stdlib import *
from helpers import *

from hw_lib import *

# Problem size. Exo substitutes this Python constant into the proc signature
# (the printed procs below show the arrays as i32[128, 128]).
N = 128
# Reference tiled kernel: C[r, c] += sum_t A[r, t] * B[c, t], i.e. C += A @ B^T
# (B is indexed [column of C, reduction index], not [reduction index, column]).
# Later cells schedule this proc by pattern-matching on the loop names
# (`for ti in _:_`, `for k in _:_`) and on index expressions such as
# `A[i*4:i*4+4, k*4:k*4+4]`, so the loop variables must keep these names.
# NOTE(review): every tile loop runs 16 iterations of 4-wide tiles, i.e. 64
# indices per dimension, while the arrays are declared N x N = 128 x 128.
# Only the leading 64x64 block of C is updated, using only the first 64
# reduction columns of A and B. Confirm whether N should be 64 or the tile
# loop trip counts should be N // 4 = 32 (either change alters every
# printed proc below, so it is only flagged here).
@proc
def matmul(
    A: i32[N, N] @ DRAM, 
    B: i32[N, N] @ DRAM, 
    C: i32[N, N] @ DRAM
):
    # i, j, k select a 4x4 tile (row tile, column tile, reduction tile);
    # ti, tj, tk index inside that tile, so each full index is tile*4 + offset.
    for i in seq(0, 16):
        for j in seq(0, 16):
            for k in seq(0, 16):
                for ti in seq(0, 4):
                    for tj in seq(0, 4):
                        for tk in seq(0, 4):
                            C[i*4+ti,j*4+tj] += A[i*4+ti,k*4+tk] * B[j*4+tj,k*4+tk]

In [2]:
# Stage the 4x4 tile of A read by the `ti` loop nest into a local buffer
# named A_tile. stage_mem allocates the buffer and inserts a copy loop
# (`for i0 ...`) in front of that nest; simplify() then tidies the indices.
a_tile_window = 'A[i*4:i*4+4, k*4:k*4+4]'
matmul = stage_mem(matmul, 'for ti in _:_', a_tile_window, "A_tile")
matmul = simplify(matmul)
print(matmul)

def matmul(A: i32[128, 128] @ DRAM, B: i32[128, 128] @ DRAM,
           C: i32[128, 128] @ DRAM):
    for i in seq(0, 16):
        for j in seq(0, 16):
            for k in seq(0, 16):
                A_tile: i32[4, 4] @ DRAM
                for i0 in seq(0, 4):
                    for i1 in seq(0, 4):
                        A_tile[i0, i1] = A[i0 + 4 * i, i1 + 4 * k]
                for ti in seq(0, 4):
                    for tj in seq(0, 4):
                        for tk in seq(0, 4):
                            C[ti + 4 * i,
                              tj + 4 * j] += A_tile[ti, tk] * B[tj + 4 * j,
                                                                tk + 4 * k]


In [3]:
# Re-home the A_tile buffer from DRAM into the RVM_TILE memory space
# (the printed proc shows `A_tile: i32[4, 4] @ RVM_TILE`), then simplify.
matmul = set_memory(matmul, "A_tile", RVM_TILE)
matmul = simplify(matmul)
print(matmul)

def matmul(A: i32[128, 128] @ DRAM, B: i32[128, 128] @ DRAM,
           C: i32[128, 128] @ DRAM):
    for i in seq(0, 16):
        for j in seq(0, 16):
            for k in seq(0, 16):
                A_tile: i32[4, 4] @ RVM_TILE
                for i0 in seq(0, 4):
                    for i1 in seq(0, 4):
                        A_tile[i0, i1] = A[i0 + 4 * i, i1 + 4 * k]
                for ti in seq(0, 4):
                    for tj in seq(0, 4):
                        for tk in seq(0, 4):
                            C[ti + 4 * i,
                              tj + 4 * j] += A_tile[ti, tk] * B[tj + 4 * j,
                                                                tk + 4 * k]


In [4]:
# Swap the scalar copy loop that fills A_tile (the `for i0` nest created by
# stage_mem) for a call to the rvm_mld instruction from hw_lib. The printed
# result shows the call with the A_tile and A windows Exo inferred.
a_copy_loop = "for i0 in _:_"
matmul = replace(matmul, a_copy_loop, rvm_mld)
print(matmul)

def matmul(A: i32[128, 128] @ DRAM, B: i32[128, 128] @ DRAM,
           C: i32[128, 128] @ DRAM):
    for i in seq(0, 16):
        for j in seq(0, 16):
            for k in seq(0, 16):
                A_tile: i32[4, 4] @ RVM_TILE
                rvm_mld(A_tile[0:4, 0:4], A[4 * i + 0:4 * i + 4,
                                            4 * k + 0:4 * k + 4])
                for ti in seq(0, 4):
                    for tj in seq(0, 4):
                        for tk in seq(0, 4):
                            C[ti + 4 * i,
                              tj + 4 * j] += A_tile[ti, tk] * B[tj + 4 * j,
                                                                tk + 4 * k]


In [5]:
# --- B operand ("kernel" tile) ---
# Stage the 4x4 tile of B read by the `ti` nest into B_tile, place it in
# RVM_TILE memory, and turn its copy loop into rvm_mld. A's copy loop was
# already replaced in the previous cell, so B's copy loop is now the only
# `for i0` loop in the proc.
b_tile_window = 'B[j*4:j*4+4,k*4:k*4+4]'
matmul = stage_mem(matmul, 'for ti in _:_', b_tile_window, "B_tile")
matmul = set_memory(matmul, "B_tile", RVM_TILE)
matmul = replace(matmul, "for i0 in _:_", rvm_mld)

# --- C accumulator (output tile) ---
# Stage the 4x4 tile of C around the whole `k` loop, so it is loaded once
# before the reduction and written back once after it.
c_tile_window = 'C[i*4:i*4+4,j*4:j*4+4]'
matmul = stage_mem(matmul, 'for k in _:_', c_tile_window, "C_tile")
matmul = set_memory(matmul, "C_tile", RVM_TILE)
matmul = simplify(matmul)

# stage_mem emitted two `for i0` loops: the C load before `k` and the C
# store after it. Each replace() acts on the first remaining match (as the
# printed result confirms), so this order matters: load, compute, store.
for loop_pattern, rvm_instr in (
    ("for i0 in _:_", rvm_mld),    # C_tile load
    ("for ti in _:_", rvm_mmasa),  # tile multiply-accumulate nest
    ("for i0 in _:_", rvm_mst),    # C_tile store-back
):
    matmul = replace(matmul, loop_pattern, rvm_instr)

print(matmul)

def matmul(A: i32[128, 128] @ DRAM, B: i32[128, 128] @ DRAM,
           C: i32[128, 128] @ DRAM):
    for i in seq(0, 16):
        for j in seq(0, 16):
            C_tile: i32[4, 4] @ RVM_TILE
            rvm_mld(C_tile[0:4, 0:4], C[4 * i + 0:4 * i + 4,
                                        4 * j + 0:4 * j + 4])
            for k in seq(0, 16):
                A_tile: i32[4, 4] @ RVM_TILE
                rvm_mld(A_tile[0:4, 0:4], A[4 * i:4 + 4 * i, 4 * k:4 + 4 * k])
                B_tile: i32[4, 4] @ RVM_TILE
                rvm_mld(B_tile[0:4, 0:4], B[4 * j:4 + 4 * j, 4 * k:4 + 4 * k])
                rvm_mmasa(C_tile[0:4, 0:4], B_tile[0:4, 0:4], A_tile[0:4, 0:4])
            rvm_mst(C_tile[0:4, 0:4], C[4 * i + 0:4 * i + 4,
                                        4 * j + 0:4 * j + 4])
