In [78]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [102]:
# TensorIR form of the loop nest A(I) = A(I) + B(J): every element of A
# accumulates all 1024 elements of B in place.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def test1(A: T.Buffer[(1024,), "float32"],
              B: T.Buffer[(1024,), "float32"]):
        T.func_attr({"global_symbol": "test1", "tir.noalias": True})
        for i, j in T.grid(1024, 1024):
            with T.block("Y"):
                # vi selects the output element, so it is a spatial axis.
                vi = T.axis.spatial(1024, i)
                # A[vi] is accumulated across every vj, so vj is a reduction
                # axis; declaring it spatial would wrongly tell the scheduler
                # that different vj iterations are independent.
                vj = T.axis.reduce(1024, j)
                A[vi] = A[vi] + B[vj]

In [111]:
# Wrap the module in a schedule and render its TVMScript before any transformation.
sch = tvm.tir.Schedule(MyModule)
original_script = sch.mod.script()
IPython.display.Code(original_script, language="python")

假设源代码是如下:
```
DO I = 1, N
  DO J = 1, M
    A(I) = A(I) + B(J)  
  END DO
END DO
```
假设cache line大小为b
则原函数的A的cache miss次数为N/b（内层循环中A(I)保持不变，A只被顺序遍历一遍），而B的cache miss次数为N*M/b（假设cache容纳不下整个B，每次外层循环都要重新加载B）.
总miss次数为N/b + N*M/b

In [112]:
# Element type shared by both inputs. It must be defined before the arrays are
# created; otherwise this cell raises NameError on a fresh kernel.
dtype = "float32"

# A starts at zero so that, after one kernel run, every entry equals sum(B).
a_np = np.zeros(1024).astype(dtype)
# B holds uniform random samples drawn from [0, 1).
b_np = np.random.rand(1024).astype(dtype)

In [125]:
# Compile the TensorIR module with the LLVM target and look the kernel up by
# its global symbol.
rt_lib = tvm.build(MyModule, target="llvm")
func_mm_relu = rt_lib["test1"]

# Element type name; nothing in this cell reads it, but the assignment is kept
# because the cell that creates a_np / b_np uses `dtype`.
dtype = "float32"

# Copy the NumPy inputs into TVM NDArrays, then run the kernel, which
# accumulates into a_nd in place.
a_nd, b_nd = tvm.nd.array(a_np), tvm.nd.array(b_np)
func_mm_relu(a_nd, b_nd)

In [126]:
# Show the result buffer. a_np was created as zeros and the kernel adds every
# B[j] to each A[i], so after a single run every entry should equal sum(b_np).
a_nd

<tvm.nd.NDArray shape=(1024,), cpu(0)>
array([510.15686, 510.15686, 510.15686, ..., 510.15686, 510.15686,
       510.15686], dtype=float32)

In [157]:
def lnumpy_test1(A: np.ndarray, B: np.ndarray):
    """Low-level NumPy reference for the ``test1`` kernel.

    Adds every element of ``B`` to every element of ``A`` in place, using the
    same i-outer / j-inner loop order as the TensorIR program.

    Args:
        A: 1-D accumulator array; updated in place.
        B: 1-D array whose elements are added to each entry of ``A``.

    Returns:
        None. The result is written into ``A``.
    """
    # Take the loop bounds from the arrays rather than hard-coding 1024, so the
    # reference also works on small inputs instead of raising IndexError.
    for i in range(A.shape[0]):
        for j in range(B.shape[0]):
            A[i] = A[i] + B[j]

In [158]:
# Benchmark the untransformed kernel on the CPU device and report the mean time.
# NOTE(review): each timed invocation accumulates into a_nd again, so a_nd no
# longer holds the single-run result after this cell.
f_timer_before = rt_lib.time_evaluator("test1", tvm.cpu())
mean_before = f_timer_before(a_nd, b_nd).mean
print("Time cost of MyModule %g sec" % mean_before)

Time cost of MyModule 0.00137204 sec


In [170]:
# Start again from a fresh schedule of the untransformed module.
sch = tvm.tir.Schedule(MyModule)
# Fetch block "Y" of test1 and the two loops (i, j) that surround it.
block_Y = sch.get_block("Y", func_name="test1")
i, j = sch.get_loops(block_Y)
# Strip-mine j into an outer loop j0 and an inner loop j1 of extent 4
# (None lets TVM infer the outer extent).
j0, j1 = sch.split(j, factors=[None, 4])
# Interchange: move j0 outside i, giving the loop order j0 -> i -> j1.
sch.reorder(j0, i)

In [171]:
# Render the TVMScript of the schedule after split + reorder.
transformed_script = sch.mod.script()
IPython.display.Code(transformed_script, language="python")

```
DO J = 1, M, T
  DO I = 1, N
    DO jj = J, min(J+T-1, M)
      A(I) = A(I) + B(jj)
    END DO
  END DO
END DO
```
该变换又称为strip-mine-and-interchange.
则函数中A的cache miss次数为M*N/(bT)（A被完整遍历M/T遍，每遍miss N/b次），而B的cache miss次数为M/b（B只被遍历一遍；当T等于b的时候即M/T，J每循环一次才miss一次）.
总miss次数为MN/(bT) + M/b（作为对比，原miss次数为N/b + N*M/b）

In [172]:
# Build the transformed schedule and fetch its kernel.
rt_lib_after = tvm.build(sch.mod, target="llvm")
test1_after = rt_lib_after["test1"]

# Re-create the device arrays from a_np / b_np so this run starts from the
# same inputs as the untransformed run and the results are comparable.
a_nd, b_nd = tvm.nd.array(a_np), tvm.nd.array(b_np)
test1_after(a_nd, b_nd)
a_nd

<tvm.nd.NDArray shape=(1024,), cpu(0)>
array([510.15686, 510.15686, 510.15686, ..., 510.15686, 510.15686,
       510.15686], dtype=float32)

In [179]:
# Benchmark the transformed kernel on the CPU device and report the mean time.
f_timer_after = rt_lib_after.time_evaluator("test1", tvm.cpu())
mean_after = f_timer_after(a_nd, b_nd).mean
print("Time cost of transformed sch.mod %g sec" % mean_after)

Time cost of transformed sch.mod 5.27739e-05 sec


参考：https://zhuanlan.zhihu.com/p/292539074