In [66]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [67]:
# Problem size: D holds N elements and B is an N x M matrix.
N, M = 4096, 4096

In [68]:
@tvm.script.ir_module
class MyModule:
    # Row-wise accumulation: D[i] += B[i, j] for every j in [0, M).
    # D is both read and written and the block has no init statement, so the
    # caller is responsible for the initial contents of D.
    @T.prim_func
    def reduce(D: T.Buffer[(N), "float32"],
                B: T.Buffer[(N, M), "float32"]):
        T.func_attr({"global_symbol": "reduce", "tir.noalias": True})
        # j is the outer loop and i the inner loop, so the inner loop varies
        # the first index of B.
        for j, i in T.grid(M, N):
            with T.block("Y"):
                vi = T.axis.spatial(N, i)
                # j only accumulates into D[vi]; it never selects an output
                # element, so it is a reduction axis rather than a spatial one.
                vj = T.axis.reduce(M, j)
                D[vi] = D[vi] + B[vi, vj]

```
DO J = 1, M
  DO I = 1, N
    D(I) = D(I) + B(I, J)
  END DO
END DO
```
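设 b 为一条 cache line 可容纳的元素个数,并假设 B 按行主序(row-major)存储——与上面 TVM buffer 的布局一致,而不是 Fortran 实际使用的列主序。内层循环遍历 I 时:D(I) 连续访问,每趟外层循环产生 N/b 次 miss,共 M 趟,合计 M*N/b 次(假设 D 无法一直留在 cache 中);B(I, J) 每次跨过一整行(步长为 M 个元素),几乎每次访问都 miss,合计 M*N 次。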


In [69]:
dtype = "float32"
# Fixed seed so the random input, and every printed result derived from it,
# is identical across Restart & Run All.
rng = np.random.default_rng(0)
# D starts at zero because the kernel accumulates into it.
d_np = np.zeros(N, dtype=dtype)
# Uniform samples in [0, 1), converted to the kernel's dtype.
b_np = rng.random((N, M)).astype(dtype)

In [70]:
# Fresh NDArrays for this run: the kernel accumulates into D in place.
d_nd = tvm.nd.array(d_np)
b_nd = tvm.nd.array(b_np)
rt_lib = tvm.build(MyModule, target="llvm")
func_reduce = rt_lib["reduce"]
func_reduce(d_nd, b_nd)
# Check against a NumPy reference instead of eyeballing the truncated repr.
# The kernel adds the row sums of B onto the initial D; the tolerance is loose
# to absorb float32 accumulation-order differences.
np.testing.assert_allclose(d_nd.numpy(), d_np + b_np.sum(axis=1), rtol=1e-3)
d_nd

<tvm.nd.NDArray shape=(4096,), cpu(0)>
array([2077.4783, 2037.7194, 2062.459 , ..., 2023.0256, 2077.0017,
       2049.1416], dtype=float32)

In [71]:
f_timer_before = rt_lib.time_evaluator("reduce", tvm.cpu())
# Time against a scratch output: the kernel accumulates into D, so timing on
# d_nd would add another row-sum on every timed run and clobber the result
# displayed by the previous cell.
d_scratch = tvm.nd.array(d_np)
print("Time cost of MyModule %g sec" % f_timer_before(d_scratch, b_nd).mean)

Time cost of MyModule 0.20495 sec


In [72]:
# Create a schedule over the original module and look up block "Y" in `reduce`.
sch = tvm.tir.Schedule(MyModule)
block_Y = sch.get_block("Y", func_name="reduce")
# The loops around the block are unpacked in the same order as
# `for j, i in T.grid(M, N)`: j is the outer loop, i the inner loop.
j, i = sch.get_loops(block_Y)
# Swap the nest so that i becomes the outer loop and j the inner loop.
sch.reorder(i, j)
# Render the transformed module as TVMScript.
IPython.display.Code(sch.mod.script(), language="python")

In [73]:
# Fresh NDArrays for this run: the kernel accumulates into D in place, so the
# output must be rebuilt from d_np before running the transformed module.
d_nd = tvm.nd.array(d_np)
b_nd = tvm.nd.array(b_np)
rt_lib_after = tvm.build(sch.mod, target="llvm")
rt_lib_after["reduce"](d_nd, b_nd)
# The loop reorder must not change the result: compare with a NumPy reference
# (loose tolerance for float32 accumulation-order differences).
np.testing.assert_allclose(d_nd.numpy(), d_np + b_np.sum(axis=1), rtol=1e-3)
d_nd

<tvm.nd.NDArray shape=(4096,), cpu(0)>
array([2077.4783, 2037.7194, 2062.459 , ..., 2023.0256, 2077.0017,
       2049.1416], dtype=float32)

In [74]:
f_timer_after = rt_lib_after.time_evaluator("reduce", tvm.cpu())
# Time against a scratch output: the kernel accumulates into D, so timing on
# d_nd would add another row-sum on every timed run and clobber the result
# displayed by the previous cell.
d_scratch = tvm.nd.array(d_np)
print("Time cost of transformed sch.mod %g sec" % f_timer_after(d_scratch, b_nd).mean)

Time cost of transformed sch.mod 0.0189379 sec


```
DO I = 1, N
  DO J = 1, M
    D(I) = D(I) + B(I, J)
  END DO
END DO
```
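交换循环顺序后(仍假设行主序存储,b 为一条 cache line 可容纳的元素个数):内层循环中 D(I) 固定不变,因此 D 总共只有 N/b 次 miss;B(I, J) 沿 J 方向连续访问,共 M*N/b 次 miss。B 的 miss 次数约降为原来的 1/b,这与上面实测的约 10 倍加速(0.205 s → 0.019 s)相符。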