# Matmul Template 111

我们先预热一下两个简单的Case，它们是从一个SIP的视角出发，对L2和L1上的数据搬运的Loop次数进行计数。
```C++
  // Counts how many entries of loop_times are greater than 1.
  // Returns that count, or 1 when no entry is greater than 1, so the result
  // is always at least 1.
  int func(std::array<int, 3> &loop_times) {
    int sum = 0;
    // Range-for avoids comparing a signed index against the unsigned size().
    for (int times : loop_times) {
      if (times > 1) { ++sum; }
    }
    return sum == 0 ? 1 : sum;
  }
```


|Hierarchy|LHS|RHS|OUT|
|---------|---|---|---|
|L3       | 1 | 1 | 1 |
|L2       | 1 | 1 | 1 |
|L1       | 1 | 1 | 1 |

TODO：怎么表达会比较清晰呢？


## Implement1: no split reduction axis
不切分 reduce axis 的实现。

### Implement1 original dataflow


In [None]:
import tvm
from tvm import te 
# ---------------
# op config: the three problem dimensions
# ---------------
M, N, K = 256, 128, 512

# ---------------
# tile sizes
# NOTE: Tunable params
# ---------------
# dtu.sip (tile) = processor = PX = tvm.parallel
PM, PN, PK = M // 2, N // 2, K    # 128, 64, 512

# dtu.kernel
KM, KN, KK = 16, 8, 32


# ---------------
# define compute
# ---------------
def matmul(dump=False):
  """Declare the (M, K) x (K, N) matmul and create a schedule for it.

  Args:
    dump: when true, print the lowered form of the freshly created schedule.

  Returns:
    The tuple (k, l, r, o, s): the reduce axis, the two input placeholders,
    the output tensor and the schedule created from the output's op.
  """
  # reduce axis ranging over (0, K)
  red_axis = te.reduce_axis((0, K), "k")

  # input tensors
  lhs = te.placeholder((M, K), name="l")
  rhs = te.placeholder((K, N), name="r")

  # compute: o[m, n] = sum over k of l[m, k] * r[k, n]
  out = te.compute(
      (M, N),
      lambda m, n: te.sum(lhs[m, red_axis] * rhs[red_axis, n], axis=red_axis),
      name="o",
  )

  # ---------------
  # schedule op
  # ---------------
  sched = te.create_schedule(out.op)
  if dump:
    lowered = tvm.lower(sched, [lhs, rhs, out], simple_mode=True)
    print(lowered)

  return red_axis, lhs, rhs, out, sched

# Build the tensors and the schedule; passing True makes matmul print the
# lowered form of the schedule before any scheduling step is applied.
k, l, r, o, s = matmul(True)



### Implement1: Add DMA Ops


In [None]:
# Insert cache stages around o. Naming used below: the *_l2 stages are created
# with the "shared" scope and the *_l1 stages with the "local" scope.

# add d2c (presumably "device to cache" -- TODO confirm): read both inputs
# through a "shared" stage whose reader is o.
# Alternative kept for reference -- the same stages with the "global" scope:
# l_l2 = s.cache_read(l, "global", [o])
# r_l2 = s.cache_read(r, "global", [o])
l_l2 = s.cache_read(l, "shared", [o])
r_l2 = s.cache_read(r, "shared", [o])

# add c2s (presumably "cache to sip" -- TODO confirm): read the "shared"
# stages through a second stage with the "local" scope.
# XC - X in cache: global, shared, local
l_l1 = s.cache_read(l_l2, "local", [o])
r_l1 = s.cache_read(r_l2, "local", [o])

# add c2d: write o through a "shared" stage
o_l2 = s.cache_write(o, "shared")

# add s2c: write the "shared" output stage through a "local" stage
o_l1 = s.cache_write(o_l2, "local")

# print the lowered form with all cache stages in place
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement1: Tile output-L1

In [None]:
# Tile the "local" output stage o_l1.
# get the axes of o_l1
m, n = o_l1.op.axis # only m and n are unpacked from op.axis; no k among them

# tile m and n for sip: m, n are rebound to the first two loops returned by
# tile, pm, pn to the last two (created with factors PM and PN)
m, n, pm, pn = s[o_l1].tile(m, n, PM, PN)

## tile is equal to two splits followed by a reorder:
# m, pm = s[o_l1].split(m, PM)
# n, pn = s[o_l1].split(n, PN)
# s[o_l1].reorder(m, n, pm, pn)

# print the lowered form after tiling
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement1: Insert Inputs-L1 to Output-L1

In [None]:
# Attach all four input cache stages at o_l1's pn loop. `pn` is the loop
# returned by the o_l1 tile in the previous cell; later cells rebind the name,
# so this cell has to run before them.
# insert l_l1 & r_l1 to o_l1
s[l_l1].compute_at(s[o_l1], pn)
s[r_l1].compute_at(s[o_l1], pn)
# optional intermediate dump, disabled:
# print(tvm.lower(s, [l, r, o], simple_mode=True))
# insert l_l2 & r_l2 to o_l1
s[l_l2].compute_at(s[o_l1], pn)
s[r_l2].compute_at(s[o_l1], pn)
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement1: Tile output-L2

In [None]:
# tile o_l2 (the "shared" output stage) with the same PM x PN factors
# get the axes of o_l2
m, n = o_l2.op.axis # only m and n are unpacked from op.axis; no k among them
# from here on m, n, pm, pn refer to loops of o_l2, not of o_l1
m, n, pm, pn = s[o_l2].tile(m, n, PM, PN)
## equivalent form: split m and n for sip, then reorder
# m, pm = s[o_l2].split(m, PM)
# n, pn = s[o_l2].split(n, PN)
# s[o_l2].reorder(m, n, pm, pn)

print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement1: Insert output-L1 to output-L2

In [None]:
# insert o_l1 to o_l2: attach o_l1 at o_l2's n loop. `n` was rebound to a
# loop of o_l2 by the tile in the previous cell.
s[o_l1].compute_at(s[o_l2], n)
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement1: Tile output-L3

In [None]:

# tile o (the final output stage) with the same PM x PN factors
# get axes of o
m, n = o.op.axis # only m and n are unpacked from op.axis; no k among them

# from here on m, n, pm, pn refer to loops of o
m, n, pm, pn = s[o].tile(m, n, PM, PN)
### equivalent form: split m and n for sip, then reorder
# m, pm = s[o].split(m, PM)
# n, pn = s[o].split(n, PN)
# s[o].reorder(m, n, pm, pn)
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement1: Insert output-L2 to output-L3

In [None]:
# insert o_l2 to o: attach o_l2 at o's n loop (`n` was rebound to a loop of o
# by the tile in the previous cell)
s[o_l2].compute_at(s[o], n)
print(tvm.lower(s, [l, r, o], simple_mode=True))


### Implement1: Parallelize

In [None]:
# parallelize outer m, n loops of o (the m, n returned by the tile of o)
# NOTE(review): both loops are marked parallel, which makes the parallel
# loops nested -- confirm the intended target accepts that.
s[o].parallel(m)
s[o].parallel(n)
print(tvm.lower(s, [l, r, o], simple_mode=True))
# marker printed at the end of the Implement1 walkthrough
print('Implement1 Finished')

## Implement2

### Implement2: Split output-L1

在Implement1基础上添加了Kernel层，并且在Kernel层实现了对reduce axis的切分。（注：下面可见的代码单元尚未对 reduce axis 做切分，KM/KN/KK 也尚未被使用，Kernel 层的切分有待补充。）

### Implement2: Insert Inputs-L1 and Inputs-L2 to output-L1

In [None]:
# Start over from a fresh schedule; matmul() is called with its default
# dump=False, so nothing is printed here.
k, l, r, o, s = matmul()

# read both inputs through a "shared" stage whose reader is o
l_l2 = s.cache_read(l, "shared", [o])
r_l2 = s.cache_read(r, "shared", [o])

# add c2s: read the "shared" stages through a "local" stage
# XC - X in cache: global, shared, local
l_l1 = s.cache_read(l_l2, "local", [o])
r_l1 = s.cache_read(r_l2, "local", [o])

# add c2d: write o through a "shared" stage
o_l2 = s.cache_write(o, "shared")

# add s2c: write the "shared" output stage through a "local" stage
o_l1 = s.cache_write(o_l2, "local")

# o_l1 is not tiled in this variant; m, n are taken straight from op.axis
m, n = o_l1.op.axis # only m and n are unpacked from op.axis; no k among them

# insert l_l1 & r_l1 to o_l1 (at its n axis)
s[l_l1].compute_at(s[o_l1], n)
s[r_l1].compute_at(s[o_l1], n)
# insert l_l2 & r_l2 to o_l1 (at the same n axis)
s[l_l2].compute_at(s[o_l1], n)
s[r_l2].compute_at(s[o_l1], n)
print(tvm.lower(s, [l, r, o], simple_mode=True))



### Implement2: Split output-L2

In [None]:

# tile o_l2 (the "shared" output stage) using explicit splits
# get the axes of o_l2
m, n = o_l2.op.axis # only m and n are unpacked from op.axis; no k among them

# split m and n for sip: m, n are rebound to the first loop returned by each
# split, pm, pn to the second (factors PM and PN)
m, pm = s[o_l2].split(m, PM)
n, pn = s[o_l2].split(n, PN)
# reorder so the loops nest as m, n, pm, pn
s[o_l2].reorder(m, n, pm, pn)
print(tvm.lower(s, [l, r, o], simple_mode=True))


### Implement2: Insert output-L1 to output-L2

In [None]:

# insert o_l1 to o_l2: attach o_l1 at o_l2's n loop. `n` was rebound to a
# loop of o_l2 by the split in the previous cell.
s[o_l1].compute_at(s[o_l2], n)
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement2: Split output-L3

In [None]:
# tile o (the final output stage) using explicit splits
# get axes of o
m, n = o.op.axis # only m and n are unpacked from op.axis; no k among them

# split m and n for sip: from here on m, n, pm, pn refer to loops of o
m, pm = s[o].split(m, PM)
n, pn = s[o].split(n, PN)
# reorder so the loops nest as m, n, pm, pn
s[o].reorder(m, n, pm, pn)
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement2: Insert output-L2 to output-L3

In [None]:
# insert o_l2 to o: attach o_l2 at o's n loop (`n` was rebound to a loop of o
# by the split in the previous cell)
s[o_l2].compute_at(s[o], n)
print(tvm.lower(s, [l, r, o], simple_mode=True))

### Implement2: Parallel

In [None]:
# parallelize the m, n loops of o returned by the splits above
# NOTE(review): both loops are marked parallel, which makes the parallel
# loops nested -- confirm the intended target accepts that.
s[o].parallel(m)
s[o].parallel(n)
print(tvm.lower(s, [l, r, o], simple_mode=True))