# Matmul Dataflow

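In this example we assume the following device characteristics:
* (a) the processor's matrix-multiply instruction computes a _16x16x8_ (M x N x K) block
* (b) an L1 tile takes 5x5 iterations of KernelFunctionCall
* (c) an L2 tile takes 4x4 iterations of DMA C2S
* (d) an L3 tile takes 3x3 iterations of DMA D2C
* (e) the full feature map takes 2x2 iterations of DMA D2D over L3 tiles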
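The repeat counts above determine the tile extent at every level. The following is a minimal pure-Python sketch that rebuilds the per-level M/N/K extents. The K repeat counts (5 at L1, 4 at L2, none above) are not in the list; they are taken from the op config in the first code cell. The helper `tile_extents` and the `LEVELS` table are illustrative names, not part of TVM.

```python
# Hypothetical helper: derive per-level tile extents from the device description.
# Each level's extent is the extent of the level below times its repeat count.
KERNEL = (16, 16, 8)  # (a) one matmul instruction: M x N x K

# (level name, repeat count along M and N, repeat count along K)
LEVELS = [
    ("L1 tile (sdb)", 5, 5),  # (b) KernelFunctionCall loop
    ("L2 tile (csb)", 4, 4),  # (c) DMA C2S loop
    ("L3 tile (sip)", 3, 1),  # (d) DMA D2C loop, K is not split here
    ("feature map", 2, 1),    # (e) DMA D2D loop
]

def tile_extents(kernel, levels):
    """Return [(name, (M, N, K))] from the kernel up to the full feature map."""
    m, n, k = kernel
    result = [("kernel", (m, n, k))]
    for name, mn_repeat, k_repeat in levels:
        m, n, k = m * mn_repeat, n * mn_repeat, k * k_repeat
        result.append((name, (m, n, k)))
    return result

extents = tile_extents(KERNEL, LEVELS)
for name, shape in extents:
    print(f"{name:>14}: {list(shape)}")

# total number of 16x16x8 instructions needed for the whole output
M, N, K = extents[-1][1]
calls = (M // KERNEL[0]) * (N // KERNEL[1]) * (K // KERNEL[2])
print("kernel calls:", calls)
```

The last extent matches `M`, `N`, `K` printed by the first code cell.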


```python
import tvm
from tvm import te

# ---------------
# op config
# ---------------

# dtu.kernel: one matmul instruction, see (a)
KM = 16
KN = 16
KK = 8

# dtu.sdb (tile) = local = SX, see (b)
SM = KM * 5
SN = KN * 5
SK = KK * 5

# dtu.csb (tile) = global/shared (shared is used below) = CX, see (c)
CM = SM * 4
CN = SN * 4
CK = SK * 4

# dtu.sip (tile) = processor = PX = tvm.parallel, see (d); K is not split here
PM = CM * 3
PN = CN * 3
PK = CK

# feature map, see (e)
M = PM * 2
N = PN * 2
K = PK

print('*'*64)
print('X: {}'.format([M, N, K]))
print('PX:{}'.format([PM, PN, PK]))
print('CX:{}'.format([CM, CN, CK]))
print('SX:{}'.format([SM, SN, SK]))
print('KX:{}'.format([KM, KN, KK]))
print('*'*64 + '\n')

# ---------------
# define compute
# ---------------
# define a reduce axis
k = te.reduce_axis((0, K), "k")

# input tensors
l = te.placeholder((M, K), name="l")
r = te.placeholder((K, N), name="r")
# compute: o[m, n] = sum_k l[m, k] * r[k, n]
o = te.compute((M, N), lambda m, n: te.sum(l[m, k] * r[k, n], axis=k), name="o")

# ---------------
# schedule op
# ---------------
# create a schedule
s = te.create_schedule(o.op)
print(tvm.lower(s, [l, r, o], simple_mode=True))
```
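For the default schedule, `tvm.lower` prints a single loop nest over `m`, `n` and the reduction axis `k`, initializing each output element before accumulating into it. A rough pure-Python equivalent of that nest, using row-major flattened buffers and tiny hypothetical sizes so it runs instantly (it mirrors the semantics, not TVM's exact IR text):

```python
# Reference semantics of o[m, n] = sum_k l[m, k] * r[k, n] with row-major
# flattened buffers, i.e. the loop nest the default schedule corresponds to.
def matmul_flat(l, r, M, N, K):
    o = [0.0] * (M * N)
    for m in range(M):
        for n in range(N):
            o[m * N + n] = 0.0   # reduction init
            for k in range(K):   # reduction axis k
                o[m * N + n] += l[m * K + k] * r[k * N + n]
    return o

# tiny hypothetical sizes, not the device configuration above
M, N, K = 2, 3, 4
l = [float(i) for i in range(M * K)]
r = [float(i % 5) for i in range(K * N)]
o = matmul_flat(l, r, M, N, K)
print(o)
```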



## Add New Storage

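* `cache_read`: This will mutate the body of the readers. A new cache stage will be created for the tensor. ___Call this before doing any split/fuse schedule___.
* `cache_write`: This will mutate the body of the tensor. A new cache stage in the given scope computes the original body, and the original stage becomes a copy from that cache into the tensor.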

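```python
# add d2c: stage the inputs in the L2 (csb) buffer
# l_l2 = s.cache_read(l, "global", [o])
# r_l2 = s.cache_read(r, "global", [o])
l_l2 = s.cache_read(l, "shared", [o])
r_l2 = s.cache_read(r, "shared", [o])

# add c2s: stage the L2 buffers in the L1 (sdb, local) buffer
l_l1 = s.cache_read(l_l2, "local", [o])
r_l1 = s.cache_read(r_l2, "local", [o])

# add c2d: o_l2 (shared) now computes the body, o becomes a copy of o_l2
o_l2 = s.cache_write(o, "shared")

# add s2c: o_l1 (local) now computes the body, o_l2 becomes a copy of o_l1
o_l1 = s.cache_write(o_l2, "local")

print(tvm.lower(s, [l, r, o], simple_mode=True))
```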
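The four cache stages wrap the single compute in a chain of copies. Before any tiling or `compute_at`, each cache stage still covers its whole tensor. A toy pure-Python model of that dataflow; the buffer names mirror the tensors above, while `copy` and `matmul_with_caches` are hypothetical stand-ins for DMA transfers and the scheduled program:

```python
# Toy model of the dataflow created by the cache_read / cache_write calls:
#   l -> l_l2 (shared) -> l_l1 (local) -+
#                                       +-> o_l1 (local) -> o_l2 (shared) -> o
#   r -> r_l2 (shared) -> r_l1 (local) -+
def copy(buf):
    """Stand-in for a DMA transfer between memory levels."""
    return list(buf)

def matmul_with_caches(l, r, M, N, K):
    l_l2, r_l2 = copy(l), copy(r)        # cache_read(l / r, "shared")
    l_l1, r_l1 = copy(l_l2), copy(r_l2)  # cache_read(l_l2 / r_l2, "local")
    o_l1 = [0.0] * (M * N)               # o_l1 computes the matmul body
    for m in range(M):
        for n in range(N):
            for k in range(K):
                o_l1[m * N + n] += l_l1[m * K + k] * r_l1[k * N + n]
    o_l2 = copy(o_l1)                    # o_l2 is a copy of o_l1
    return copy(o_l2)                    # o is a copy of o_l2

M, N, K = 2, 2, 2
result = matmul_with_caches([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], M, N, K)
print(result)
```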

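## Tile output and insert output-L2 into output-L3

_PM_, _PN_, _CM_ and _CN_ are only tiling parameters for the output; by themselves they do not say which storage level or parallelism level each tile maps to. The code below performs two main operations:
* tile `s[o]` twice, producing the loops `m`, `n`, `pm`, `pn`, `cm`, `cn`
* attach `s[o_l2]` at the `pn` loop of `s[o]` with `compute_at`. Bound inference then restricts `s[o_l2]` to the CM x CN region used by one `pn` iteration, so output-L2 follows the same two-level tiling without being split explicitly.
* output tile x2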

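```python
# axis
m, n = o.op.axis
# tile output firstly: 2x2 L3 tiles of PM x PN
m, n, pm, pn = s[o].tile(m, n, PM, PN)
# tile output secondly: 3x3 L2 tiles of CM x CN inside each L3 tile
pm, pn, cm, cn = s[o].tile(pm, pn, CM, CN)

# compute one CM x CN output-L2 tile per pn iteration
s[o_l2].compute_at(s[o], pn)
print(tvm.lower(s, [l, r, o], simple_mode=True))
```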
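After the two `tile` calls, each output row index is decomposed as `m * PM + pm * CM + cm`, where `m` has extent `M // PM == 2`, `pm` has extent `PM // CM == 3` and `cm` has extent `CM` (columns work the same way). A small pure-Python check of that decomposition; the helper `tiled_indices` is hypothetical:

```python
# Two-level tiling of one output axis, as produced by
# s[o].tile(m, n, PM, PN) followed by s[o].tile(pm, pn, CM, CN).
CM = 16 * 5 * 4  # 320
PM = CM * 3      # 960
M = PM * 2       # 1920

def tiled_indices(M, PM, CM):
    """Yield (m, pm, cm, row) in loop order m -> pm -> cm."""
    for m in range(M // PM):        # extent 2: L3 tiles
        for pm in range(PM // CM):  # extent 3: L2 tiles per L3 tile
            for cm in range(CM):    # extent CM: rows inside one L2 tile
                yield m, pm, cm, m * PM + pm * CM + cm

rows = [row for (_, _, _, row) in tiled_indices(M, PM, CM)]
# every original row is visited exactly once, in the original order
print(rows == list(range(M)))
```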


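## Tile output-L2 and insert output-L1 into output-L2

The previous step tiled `s[o]` and attached `s[o_l2]` to it, so `s[o_l2]` covers one CM x CN tile of that tiling. Building on this, the code below tiles `s[o_l2]` once more and then attaches `s[o_l1]` to `s[o_l2]`, which in the same way restricts `s[o_l1]` to one SM x SN tile.

These operations do not affect `s[o]`, which is how different stages end up with different tilings.

* output tile x3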


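```python
# the data-parallel axes of the output-L2 stage
cm, cn = o_l2.op.axis
print(cm)
print(cn)
# split one CM x CN output-L2 tile into 4x4 SM x SN output-L1 tiles
cm, cn, sm, sn = s[o_l2].tile(cm, cn, SM, SN)

# insert tensors: compute one SM x SN output-L1 tile per outer cn iteration
s[o_l1].compute_at(s[o_l2], cn)
print(tvm.lower(s, [l, r, o], simple_mode=True))
```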
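With `s[o_l2]` attached at `pn` of `s[o]` and `s[o_l1]` attached at the outer `cn` loop of `s[o_l2]`, each stage runs once per iteration of the loops that enclose its attach point. A back-of-the-envelope count in pure Python (the variable names are illustrative; N uses the same factors as M):

```python
# How often each output stage runs, given the attach points chosen so far.
SM = 16 * 5
CM = SM * 4
PM = CM * 3
M = PM * 2  # N uses the same factors

l3_tiles = (M // PM) ** 2                # loops m, n of s[o]
o_l2_runs = l3_tiles * (PM // CM) ** 2   # + loops pm, pn: one CM x CN tile each
o_l1_runs = o_l2_runs * (CM // SM) ** 2  # + outer cm, cn of s[o_l2]: one SM x SN tile each

print("o_l2 runs:", o_l2_runs, "tile", (CM, CM))
print("o_l1 runs:", o_l1_runs, "tile", (SM, SM))
# the output-L1 tiles exactly cover the output
print(o_l1_runs * SM * SM == M * M)
```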


## Tile output-L0

Tile `s[o_l1]` further, and attach `s[l_l1]`, `s[l_l2]`, `s[r_l1]` and `s[r_l2]` to `s[o_l1]`: the L1 input copies at `ck` and the L2 input copies at `pk`.

* output tile x4

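```python
# tile - sdb & kernel - o_l1
sm, sn = o_l1.op.axis
print(sm)
print(sn)

# split one SM x SN output-L1 tile into 5x5 KM x KN kernel blocks
sm, sn, km, kn = s[o_l1].tile(sm, sn, KM, KN)

# k: take the reduction axis from o_l1, which now holds the reduction
k = o_l1.op.reduce_axis[0]
pk, ck = s[o_l1].split(k, CK)   # pk has extent K // CK == 1
ck, sk = s[o_l1].split(ck, SK)  # ck has extent CK // SK == 4
sk, kk = s[o_l1].split(sk, KK)  # sk has extent SK // KK == 5
# reorder: reduction tiles outside, one kernel block (km, kn, kk) innermost
s[o_l1].reorder(pk, ck, sm, sn, sk, km, kn, kk)

# insert tensors
s[l_l1].compute_at(s[o_l1], ck)
s[r_l1].compute_at(s[o_l1], ck)
s[l_l2].compute_at(s[o_l1], pk)
s[r_l2].compute_at(s[o_l1], pk)
print(tvm.lower(s, [l, r, o], simple_mode=True))
```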
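Within one output-L1 tile, the reordered nest `pk, ck, sm, sn, sk, km, kn, kk` accumulates kernel-sized blocks. A pure-Python sketch of that nest for a single SM x SN tile, checked against a naive matmul. To keep it fast, it uses tiny hypothetical factors (kernel 2x2x2 and every repeat count set to 2) instead of the device configuration, and it models only loop order and index arithmetic, not the cache stages:

```python
# Tiled loop nest of s[o_l1] after
#   sm, sn, km, kn = tile(sm, sn, KM, KN)
#   pk, ck = split(k, CK); ck, sk = split(ck, SK); sk, kk = split(sk, KK)
#   reorder(pk, ck, sm, sn, sk, km, kn, kk)
import random

KM, KN, KK = 2, 2, 2               # hypothetical small kernel
SM, SN, SK = KM * 2, KN * 2, KK * 2
CK = SK * 2
K = CK                             # PK = CK, so pk has extent 1

def o_l1_tile(l, r):
    """Compute one SM x SN tile; l is SM x K, r is K x SN (lists of lists)."""
    o = [[0.0] * SN for _ in range(SM)]
    for pk in range(K // CK):
        for ck in range(CK // SK):
            for sm in range(SM // KM):
                for sn in range(SN // KN):
                    for sk in range(SK // KK):
                        # one KM x KN x KK kernel block
                        for km in range(KM):
                            for kn in range(KN):
                                for kk in range(KK):
                                    i = sm * KM + km
                                    j = sn * KN + kn
                                    k = pk * CK + ck * SK + sk * KK + kk
                                    o[i][j] += l[i][k] * r[k][j]
    return o

def naive(l, r):
    return [[sum(l[i][k] * r[k][j] for k in range(K)) for j in range(SN)]
            for i in range(SM)]

random.seed(0)
l = [[float(random.randint(-3, 3)) for _ in range(K)] for _ in range(SM)]
r = [[float(random.randint(-3, 3)) for _ in range(SN)] for _ in range(K)]
print(o_l1_tile(l, r) == naive(l, r))
```

Small integer-valued inputs keep the floating-point sums exact, so the reordered accumulation can be compared with `==`.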


## Parallelize

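```python
print(m)
print(n)

# parallelize the 2x2 grid of L3 (processor) tiles.
# TVM's CPU threadpool codegen rejects nested parallel loops,
# so fuse m and n into one loop and parallelize that loop.
mn = s[o].fuse(m, n)
s[o].parallel(mn)
print(tvm.lower(s, [l, r, o], simple_mode=True))
```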
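Fusing `m` and `n` turns the 2x2 nest into one loop of extent 4 whose index `i` maps back as `m = i // 2`, `n = i % 2`. A pure-Python sketch of that mapping, with a thread pool standing in for TVM's parallel runtime; the pool and the `run_l3_tile` body are illustrative only:

```python
# Fused parallel loop over the 2x2 grid of L3 (processor) tiles.
from concurrent.futures import ThreadPoolExecutor

M_TILES, N_TILES = 2, 2  # extents of m and n after the first tile()

def run_l3_tile(i):
    """Hypothetical per-tile work: recover (m, n) from the fused index i."""
    m, n = i // N_TILES, i % N_TILES
    return (m, n)

with ThreadPoolExecutor(max_workers=M_TILES * N_TILES) as pool:
    tiles = list(pool.map(run_l3_tile, range(M_TILES * N_TILES)))
print(tiles)
```

`pool.map` returns results in input order, so `tiles` lists the L3 tiles in the same order as the original `m`, `n` nest.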
