# Matmul Dataflow

## Op config, tiling parameters and compute definition


In [None]:
import tvm
from tvm import te 
# ---------------
# op config
# ---------------
# Problem size: o[M, N] = sum_k l[M, K] * r[K, N]
M = 256
N = 128
K = 512
# Small configuration, handy when reading the lowered IR by eye:
# M = 8
# N = 8
# K = 4


def _ceil_half(x):
    """Return ceil(x / 2) for an int x; each tiling level halves the one above it."""
    return (x + 1) // 2


# ---------------
# tile size
# NOTE: Tunable params
# ---------------
# dtu.sip (tile) = processor = PX = tvm.parallel
# Each SIP gets a PM x PN output tile and walks the full K extent.
# Previously hard-coded values for the default problem size:
# PM = 128
# PN = 64
# PK = K
PM = _ceil_half(M)
PN = _ceil_half(N)
PK = K
# dtu.csb (tile) = CX; the DMA cell caches CSB tiles in the "shared" scope
CM = _ceil_half(PM)
CN = _ceil_half(PN)
CK = _ceil_half(PK)
# dtu.sdb (tile) = SX; cached in the "local" scope
SM = _ceil_half(CM)
SN = _ceil_half(CN)
SK = _ceil_half(CK)
# dtu.kernel = KX, the innermost tile
KM = 2
KN = 2
KK = 2

# ---------------
# define compute
# ---------------
# reduction axis over the shared K dimension
k = te.reduce_axis((0, K), "k")

# input tensors
l = te.placeholder((M, K), name="l")
r = te.placeholder((K, N), name="r")
# o[m, n] = sum over k of l[m, k] * r[k, n]
o = te.compute((M, N), lambda m, n: te.sum(l[m, k] * r[k, n], axis=k), name="o")

# ---------------
# schedule op
# ---------------
# create the default (untransformed) schedule and show its loop nest
s = te.create_schedule(o.op)
print(tvm.lower(s, [l, r, o], simple_mode=True))

## Add DMA

In [None]:

# DMA stages, modelled as TVM cache stages ("d" is the original tensor,
# "c" the CSB-level cache, "s" the SDB-level cache).
# XC - X in cache: global, shared, local

# d2c: copy each input into a CSB-level cache ("shared" scope;
# "global" is the alternative scope for this level).
l_l2, r_l2 = [s.cache_read(src, "shared", [o]) for src in (l, r)]

# c2s: copy each CSB-level cache into an SDB-level cache ("local" scope).
l_l1, r_l1 = [s.cache_read(src, "local", [o]) for src in (l_l2, r_l2)]

# c2d: o is now written back from a CSB-level ("shared") result buffer ...
o_l2 = s.cache_write(o, "shared")

# s2c: ... which in turn is filled from an SDB-level ("local") buffer;
# the reduction itself is therefore computed in o_l1.
o_l1 = s.cache_write(o_l2, "local")

print(tvm.lower(s, [l, r, o], simple_mode=True))

## Split output-L3 and output-L2

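_PM_, _PN_ 是指每个 SIP 并行处理的负载量。从单个 SIP 的视角看，这并不是切分，而是它要处理的全部数据。(_PM_, _PN_ are the workload each SIP processes in parallel; from a single SIP's point of view this is not a split at all, but the entire tile it has to process.)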

_CM_, _CN_ 是 CSB 上的 tiling size（分块大小）。(_CM_, _CN_ are the tile sizes used on the CSB.)

In [None]:

# -------- dependency graph --------
# l -d2c- l_l2 -c2s- l_l1 \
#                          o_l1 (compute) -s2c- o_l2 -c2d- o
# r -d2c- r_l2 -c2s- r_l1 /
# Stages are scheduled starting from the graph's leaf (o) and working
# back towards the inputs.
# ----------------------------------
# tile - sip - o
# (m, n) select one SIP's PM x PN tile, (pm, pn) select a CM x CN CSB tile
# inside it, and (cm, cn) index elements inside that CSB tile.
m, n = o.op.axis
# m -> m / pm / cm
m, pm = s[o].split(m, factor=PM)
pm, cm = s[o].split(pm, factor=CM)
# n -> n / pn / cn
n, pn = s[o].split(n, factor=PN)
pn, cn = s[o].split(pn, factor=CN)
# SIP loops outermost, then CSB-tile loops, then intra-tile loops
s[o].reorder(m, n, pm, pn, cm, cn)
# Attach the CSB-level output buffer at pn, so one CSB tile of o_l2 is
# produced per (pm, pn) iteration. The remaining stages are attached to
# deeper loops in the following cells.
s[o_l2].compute_at(s[o], pn)
print(tvm.lower(s, [l, r, o], simple_mode=True))


## 

## Split output-L1

In [None]:

# tile - csb - o_l2
# Break the CSB-level output tile into SM x SN SDB tiles and compute the
# SDB-level buffer o_l1 once per (cm, cn) SDB tile.
cm, cn = o_l2.op.axis
cm, sm = s[o_l2].split(cm, factor=SM)  # rows:    SDB tile index / row inside it
cn, sn = s[o_l2].split(cn, factor=SN)  # columns: SDB tile index / column inside it
s[o_l2].reorder(cm, cn, sm, sn)        # tile loops outside, element loops inside
s[o_l1].compute_at(s[o_l2], cn)        # insert tensors
print(tvm.lower(s, [l, r, o], simple_mode=True))


## Split output-L0 and the reduction axis, attach input caches

In [None]:

# tile - sdb & kernel - o_l1
# o_l1 is the stage that performs the reduction. Take its reduce axis from
# its own op instead of reusing the module-level `k`; the 1-tuple unpacking
# fails loudly if the op ever has more than one reduce axis.
sm, sn = o_l1.op.axis
(red_k,) = o_l1.op.reduce_axis
# m / n: SDB tile -> KM x KN kernel tiles
sm, km = s[o_l1].split(sm, factor=KM)
sn, kn = s[o_l1].split(sn, factor=KN)
# k: pk iterates CK-sized chunks of K, ck iterates SK-sized sub-chunks,
# sk iterates KK-sized sub-chunks, kk is the innermost kernel step
pk, ck = s[o_l1].split(red_k, factor=CK)
ck, sk = s[o_l1].split(ck, factor=SK)
sk, kk = s[o_l1].split(sk, factor=KK)
# reduction chunk loops (pk, ck) outermost, so the input tiles can be
# staged once per chunk, then the kernel-level loop nest
s[o_l1].reorder(pk, ck, sm, sn, sk, km, kn, kk)

# insert tensors:
# SDB-level input copies are made inside every ck iteration,
# CSB-level input copies inside every pk iteration.
s[l_l1].compute_at(s[o_l1], ck)
s[r_l1].compute_at(s[o_l1], ck)
s[l_l2].compute_at(s[o_l1], pk)
s[r_l2].compute_at(s[o_l1], pk)
print(tvm.lower(s, [l, r, o], simple_mode=True))


## Parallelize

In [None]:


# parallelize
# Fuse the two outer per-SIP loops (m, n) into one loop and parallelize it.
# TVM's CPU codegen does not support nested parallel loops and asks for
# them to be fused instead. Calling parallel() on both m and n therefore
# lowers fine but fails at build time. m and n are adjacent (m outer)
# thanks to the earlier reorder, which fuse() requires.
mn = s[o].fuse(m, n)
s[o].parallel(mn)
print('\n 6 parallelize outer m, n loops \n', tvm.lower(s, [l, r, o], simple_mode=True))

# summary of the tiling hierarchy: problem, SIP, CSB, SDB and kernel tiles
print('X:', [M, N, K])
print('PX:', [PM, PN, PK])
print('CX:', [CM, CN, CK])
print('SX:', [SM, SN, SK])
print('KX:', [KM, KN, KK])