In [1]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

## 2.5.1. 第一节：如何编写 TensorIR

### 2.5.1.1. 示例：逐位相加

In [2]:
# Operands for the element-wise add: `a` counts up from 0 and `b` counts
# down from 16, both laid out as 4x4 int64 matrices.
a = np.arange(16, dtype=np.int64).reshape((4, 4))
b = (16 - np.arange(16, dtype=np.int64)).reshape((4, 4))

In [3]:
# Reference result from NumPy's vectorised element-wise addition.
c_np = np.add(a, b)
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]], dtype=int64)

In [4]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Store the element-wise sum of ``a`` and ``b`` into ``c``.

    Only the leading 4x4 region is visited, one element at a time in
    row-major order. ``c`` is modified in place and nothing is returned.
    """
    for idx in np.ndindex(4, 4):
        c[idx] = a[idx] + b[idx]
# The loop version writes into a caller-provided buffer, so allocate it first.
c_lnumpy = np.empty(shape=(4, 4), dtype=np.int64)
lnumpy_add(a, b, c=c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]], dtype=int64)

In [5]:
@tvm.script.ir_module
class MyAdd:
    # TensorIR version of the 4x4 element-wise add: C[i, j] = A[i, j] + B[i, j].
    @T.prim_func
    def add(
        A: T.Buffer((4, 4), "int64"),
        B: T.Buffer((4, 4), "int64"),
        C: T.Buffer((4, 4), "int64"),
    ):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                # Both block axes are spatial and bound one-to-one to the loops.
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

# Build with the LLVM target, run the kernel on TVM arrays and check the
# result against the NumPy reference.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

### 2.5.1.2. 练习 1：广播加法

In [6]:
# Operands for the broadcast add: a 4x4 int64 matrix and a length-4 int64
# vector counting down from 4.
a = np.arange(16, dtype=np.int64).reshape((4, 4))
b = np.arange(4, 0, -1, dtype=np.int64)

In [7]:
# Reference result: NumPy broadcasts the length-4 vector across every row.
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]], dtype=int64)

In [8]:
def broadcast_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Add the vector ``b`` to every row of ``a`` and store the result in ``c``.

    Visits the leading 4x4 region element by element; ``b`` is indexed by
    the column only. ``c`` is modified in place and nothing is returned.
    """
    for row, col in np.ndindex(4, 4):
        c[row, col] = a[row, col] + b[col]

# Allocate the output buffer, then let the loop version fill it in place.
c_lnumpy = np.empty(shape=(4, 4), dtype=np.int64)
broadcast_add(a, b, c=c_lnumpy)
c_lnumpy

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]], dtype=int64)

In [9]:
@tvm.script.ir_module
class MyAdd:
  # TensorIR version of the broadcast add: C[i, j] = A[i, j] + B[j].
  @T.prim_func
  def add(A: T.Buffer((4, 4), "int64"),
          # The 1-D shape is written as a real one-element tuple; a bare
          # `(4)` is just the integer 4 wrapped in parentheses.
          B: T.Buffer((4,), "int64"),
          C: T.Buffer((4, 4), "int64")):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both block axes are spatial; B is indexed by the column axis only.
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vj]

# Build with the LLVM target, run the broadcast kernel and check it against
# the NumPy reference.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

### 2.5.1.3. 练习 2：二维卷积

In [10]:
# Sizes of the convolution problem. `data` is laid out as (N, CI, H, W) and
# `weight` as (CO, CI, K, K).
N, CI, H, W = 1, 1, 8, 8
CO, K = 2, 3
# Output extent of a K-wide window slid over the input with step 1 and no padding.
OUT_H = H - K + 1
OUT_W = W - K + 1
data = np.arange(N * CI * H * W, dtype=np.int64).reshape((N, CI, H, W))
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape((CO, CI, K, K))

In [11]:
# torch version
import torch

# The convolution is evaluated on float32 tensors; the result is converted
# back to an int64 NumPy array for comparison.
data_torch = torch.from_numpy(data).float()
weight_torch = torch.from_numpy(weight).float()
conv_torch = (
    torch.nn.functional.conv2d(data_torch, weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]], dtype=int64)

In [12]:
@tvm.script.ir_module
class MyConv:
    # Direct 2-D convolution written as a single reduction block: every
    # output element accumulates over input channels and the K x K window.
    @T.prim_func
    def conv(
        data: T.Buffer((N, CI, H, W), "int64"),
        weight: T.Buffer((CO, CI, K, K), "int64"),
        conv: T.Buffer((N, CO, OUT_H, OUT_W), "int64"),
    ):
        T.func_attr({"global_symbol": "conv", "tir.noalias": True})
        for n, co, ho, wo, ci, di, dj in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
            with T.block("CONV"):
                # Output coordinates are spatial axes.
                vn = T.axis.spatial(N, n)
                vco = T.axis.spatial(CO, co)
                vho = T.axis.spatial(OUT_H, ho)
                vwo = T.axis.spatial(OUT_W, wo)
                # Input channel and window offsets are reduction axes.
                vci = T.axis.reduce(CI, ci)
                vdi = T.axis.reduce(K, di)
                vdj = T.axis.reduce(K, dj)
                with T.init():
                    conv[vn, vco, vho, vwo] = T.int64(0)
                conv[vn, vco, vho, vwo] = (
                    conv[vn, vco, vho, vwo]
                    + data[vn, vci, vho + vdi, vwo + vdj] * weight[vco, vci, vdi, vdj]
                )

# Build with the LLVM target, run the convolution and check it against the
# torch reference.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty(shape=(N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)

## 2.5.2. 第二节：如何变换 TensorIR

### 2.5.2.1. 并行化、向量化与循环展开

In [13]:
@tvm.script.ir_module
class MyAdd:
    # 4x4 element-wise add used as the starting point for the schedule demo.
    @T.prim_func
    def add(
        A: T.Buffer((4, 4), "int64"),
        B: T.Buffer((4, 4), "int64"),
        C: T.Buffer((4, 4), "int64"),
    ):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                # Both block axes are spatial and bound one-to-one to the loops.
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

# Create a schedule for MyAdd and fetch the two loops around block "C".
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
# Split the outer loop of extent 4 into a 2 x 2 pair.
i0, i1 = sch.split(i, factors=[2, 2])
# Annotate one loop with each primitive: parallel outer, unrolled middle,
# vectorized inner.
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
# Print the transformed module.
sch.mod.show()
# Alternative display: IPython.display.Code(sch.mod.script(), language="python")

### 2.5.2.2. 练习 3：变换批量矩阵乘法程序

In [14]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((16, 128, 128), dtype="float32")
    for n in range(16):
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(16):
        for i in range(128):
            for j in range(128):
                C[n, i, j] = max(Y[n, i, j], 0)

In [15]:
@tvm.script.ir_module
class MyBmmRelu:
  # Batched matmul + ReLU in TensorIR: block "Y" reduces over k into the
  # intermediate buffer Y, block "C" applies max(., 0) element-wise.
  @T.prim_func
  def bmm_relu(A: T.Buffer((16, 128, 128), "float32"),
               B: T.Buffer((16, 128, 128), "float32"),
               C: T.Buffer((16, 128, 128), "float32")):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")
    for n, i, j, k in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        # n, i, j are spatial axes; k is the reduction axis.
        vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
        with T.init():
          Y[vn, vi, vj] = T.float32(0)
        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
    for n, i, j in T.grid(16, 128, 128):
      with T.block("C"):
        vn, vi, vj = T.axis.remap("SSS", [n, i, j])
        C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
# Wrap the module in a Schedule and print the untransformed TensorIR.
sch = tvm.tir.Schedule(MyBmmRelu)
sch.mod.show()
# Alternative display: IPython.display.Code(sch.mod.script(), language="python")
# Also please validate your result

In [16]:
# Test inputs: `a` ramps up from 0 and `b` ramps down towards 0, both in
# steps of `stride`, reshaped to (16, 128, 128).
# NOTE(review): np.arange with a non-integer step can produce one element
# more or fewer than expected because of rounding; the reshape below relies
# on exactly 16 * 128 * 128 elements.
stride = 0.03
num_elems = 16 * 128 * 128
a = np.arange(0, num_elems * stride, stride, dtype="float32").reshape((16, 128, 128))
b = np.arange(num_elems * stride, 0, -stride, dtype="float32").reshape((16, 128, 128))
# Reference output from the explicit-loop implementation.
c_lnumpy = np.empty(shape=(16, 128, 128), dtype=np.float32)
lnumpy_mm_relu_v2(a, b, c_lnumpy)
c_lnumpy

array([[[1.83861612e+06, 1.83860925e+06, 1.83860212e+06, ...,
         1.83770850e+06, 1.83770088e+06, 1.83769388e+06],
        [5.58509200e+06, 5.58507100e+06, 5.58504900e+06, ...,
         5.58235450e+06, 5.58233250e+06, 5.58231100e+06],
        [9.33157000e+06, 9.33153200e+06, 9.33149400e+06, ...,
         9.32700100e+06, 9.32696600e+06, 9.32692800e+06],
        ...,
        [4.70148352e+08, 4.70146368e+08, 4.70144576e+08, ...,
         4.69918432e+08, 4.69916672e+08, 4.69914848e+08],
        [4.73894656e+08, 4.73892832e+08, 4.73891008e+08, ...,
         4.73663264e+08, 4.73661088e+08, 4.73659552e+08],
        [4.77641152e+08, 4.77639328e+08, 4.77637344e+08, ...,
         4.77407872e+08, 4.77405984e+08, 4.77404064e+08]],

       [[4.50566240e+08, 4.50564320e+08, 4.50562464e+08, ...,
         4.50331072e+08, 4.50329216e+08, 4.50327328e+08],
        [4.54072992e+08, 4.54071008e+08, 4.54069216e+08, ...,
         4.53835936e+08, 4.53833952e+08, 4.53832160e+08],
        [4.57579488e+08, 

In [17]:
# Build the untransformed module, run it and check it against the
# explicit-loop NumPy reference.
rt_lib = tvm.build(MyBmmRelu, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(16, 128, 128), dtype=np.float32))
rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_lnumpy, rtol=1e-5)

#### 目标程序

In [18]:
@tvm.script.ir_module
class TargetModule:
    # Reference program the exercise's schedule has to reproduce; the
    # scheduled module is compared against it with assert_structural_equal.
    @T.prim_func
    def bmm_relu(A: T.Buffer((16, 128, 128), "float32"), B: T.Buffer((16, 128, 128), "float32"), C: T.Buffer((16, 128, 128), "float32")) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # Batch loop is parallel.
        for i0 in T.parallel(16):
            # The j axis (extent 128) is tiled as 16 outer x 8 inner.
            for i1, i2_0 in T.grid(128, 16):
                # Zero-initialise one tile of 8 outputs; vectorized.
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Reduction axis k (extent 128) is split 32 x 4; the inner 4 is unrolled.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        # The innermost j loop of the update stays serial.
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # ReLU over the same tile of 8 outputs; vectorized.
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

### Version 1

In [19]:
# First attempt at reproducing TargetModule. This ordering does NOT match the
# target: j_1 is vectorized before the reduction is decomposed, so the inner
# loop of Y_update ends up vectorized as well, while the target keeps that
# loop serial.
sch = tvm.tir.Schedule(MyBmmRelu)
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
block_Y = sch.get_block("Y", func_name="bmm_relu")
block_C = sch.get_block("C", func_name="bmm_relu")
# Step 2. Get loops
n, i, j, k = sch.get_loops(block_Y)
# Step 3. Organize the loops: tile j by 8 and k by 4, then move the j tile innermost
j_0, j_1 = sch.split(j, factors=[None, 8])
k_0, k_1 = sch.split(k, factors=[None, 4])
sch.reorder(k_0, k_1, j_1)

# Step 4. parallel / unroll / vectorize the loops around block Y
sch.parallel(n)
sch.unroll(k_1)
sch.vectorize(j_1)

# Step 5. Move block C under the j_0 loop
sch.reverse_compute_at(block_C, j_0)

# Step 6. decompose reduction, then vectorize the inner loop of block C
Y_init = sch.decompose_reduction(block_Y, k_0)
n, i, j0, ax0 = sch.get_loops(block_C)
sch.vectorize(ax0)
sch.mod.show()

In [20]:
# Compare the scheduled module with the target program; this raises when the
# two differ structurally, so "Pass" is printed only on a match.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

TVMError: Traceback (most recent call last):
  File "C:\Users\hjs\tvm\src\node\structural_equal.cc", line 376
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("bmm_relu")].body.block.body.body.body.body.seq[1].body.body.<unknown attribute>:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def bmm_relu(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": T.bool(True)})
        A = T.match_buffer(A_handle, (16, 128, 128))
        B = T.match_buffer(B_handle, (16, 128, 128))
        C = T.match_buffer(C_handle, (16, 128, 128))
        with T.block("root"):
            T.reads()
            T.writes()
            Y = T.alloc_buffer((16, 128, 128))
            for n in T.parallel(16):
                for i in range(128):
                    for j_0 in range(16):
                        for j_1_init in T.vectorized(8):
                            with T.block("Y_init"):
                                vn = T.axis.spatial(16, n)
                                vi = T.axis.spatial(128, i)
                                vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
                                T.reads()
                                T.writes(Y[vn, vi, vj])
                                Y[vn, vi, vj] = T.float32(0)
                        for k_0 in range(32):
                            for k_1 in T.unroll(4):
                                for j_1 in T.vectorized(8):
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    with T.block("Y_update"):
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^
                                        vn = T.axis.spatial(16, n)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        vi = T.axis.spatial(128, i)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        vj = T.axis.spatial(128, j_0 * 8 + j_1)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        vk = T.axis.reduce(128, k_0 * 4 + k_1)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.reads(Y[vn, vi, vj], A[vn, vi, vk], B[vn, vk, vj])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.writes(Y[vn, vi, vj])
                                        ^^^^^^^^^^^^^^^^^^^^^^^
                                        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for ax0 in T.vectorized(8):
                            with T.block("C"):
                                vn = T.axis.spatial(16, n)
                                vi = T.axis.spatial(128, i)
                                vj = T.axis.spatial(128, j_0 * 8 + ax0)
                                T.reads(Y[vn, vi, vj])
                                T.writes(C[vn, vi, vj])
                                C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
and rhs at <root>.functions[I.GlobalVar("bmm_relu")].body.block.body.body.body.body.seq[1].body.body.<unknown attribute>:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def bmm_relu(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": T.bool(True)})
        A = T.match_buffer(A_handle, (16, 128, 128))
        B = T.match_buffer(B_handle, (16, 128, 128))
        C = T.match_buffer(C_handle, (16, 128, 128))
        with T.block("root"):
            T.reads()
            T.writes()
            Y = T.alloc_buffer((16, 128, 128))
            for i0 in T.parallel(16):
                for i1 in range(128):
                    for i2_0 in range(16):
                        for ax0_init in T.vectorized(8):
                            with T.block("Y_init"):
                                n = T.axis.spatial(16, i0)
                                i = T.axis.spatial(128, i1)
                                j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                                T.reads()
                                T.writes(Y[n, i, j])
                                Y[n, i, j] = T.float32(0)
                        for ax1_0 in range(32):
                            for ax1_1 in T.unroll(4):
                                for ax0 in range(8):
                                ^^^^^^^^^^^^^^^^^^^^
                                    with T.block("Y_update"):
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^
                                        n = T.axis.spatial(16, i0)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        i = T.axis.spatial(128, i1)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.reads(Y[n, i, j], A[n, i, k], B[n, k, j])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.writes(Y[n, i, j])
                                        ^^^^^^^^^^^^^^^^^^^^
                                        Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for i2_1 in T.vectorized(8):
                            with T.block("C"):
                                n = T.axis.spatial(16, i0)
                                i = T.axis.spatial(128, i1)
                                j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                                T.reads(Y[n, i, j])
                                T.writes(C[n, i, j])
                                C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

### Version 2
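Version 1 未通过结构相等检查：它在 decompose reduction 之前就对 `j_1` 做了 vectorize，导致 `Y_update` 的最内层循环也被向量化，而目标程序中该循环是串行的。

另外，若把 `sch.parallel(n)` 放在 decompose reduction 之后（按照 exercise 给出的步骤顺序）会报错。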

答疑见：https://github.com/mlc-ai/mlc-zh/discussions/35

原因是 decompose reduction 之后，三个 block 全都既不是 local complete block 也不是 local reduction block。具体因为什么条件不满足，之后有机会再仔细看吧(TODO)

所以要把 sch.parallel(n) 放在 decompose reduction 之前

另：
- 本题解答：https://github.com/mlc-ai/mlc-zh/discussions/145
- 中间变量y是否有必要：https://github.com/mlc-ai/mlc-zh/discussions/32
- parallel/unroll/vectorize的区别：https://github.com/mlc-ai/mlc-zh/discussions/82

In [21]:
# Second attempt at reproducing TargetModule. Compared with the exercise's
# suggested step order, parallel(n) is applied before decompose_reduction,
# and vectorization is applied per block after the decomposition.
sch = tvm.tir.Schedule(MyBmmRelu)
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
block_Y = sch.get_block("Y", func_name="bmm_relu")
block_C = sch.get_block("C", func_name="bmm_relu")
# Step 2. Get loops
n, i, j, k = sch.get_loops(block_Y)
# Step 3. Organize the loops: tile j by 8 and k by 4, move the j tile
# innermost, then place block C under the j_0 loop
j_0, j_1 = sch.split(j, factors=[None, 8])
k_0, k_1 = sch.split(k, factors=[None, 4])
sch.reorder(k_0, k_1, j_1)
sch.reverse_compute_at(block_C, j_0)

# Must come before decompose_reduction; applying it afterwards raises an error.
sch.parallel(n)

# Step 4. decompose reduction
Y_init = sch.decompose_reduction(block_Y, k_0)

# Step 5. vectorize / unroll
# Vectorize only the innermost loop of the init block; the update block's
# innermost loop is left serial.
# Y_init is looked up again by name here (presumably the same block that
# decompose_reduction returned above).
Y_init = sch.get_block("Y_init", func_name="bmm_relu")
n, i, j0, j1 = sch.get_loops(Y_init)
sch.vectorize(j1)

# Unroll the inner reduction loop of the update block.
sch.unroll(k_1)

# Vectorize the innermost loop of block C.
C = sch.get_block("C", func_name="bmm_relu")
n, i, j0, j1 = sch.get_loops(C)
sch.vectorize(j1)

sch.mod.show()

In [22]:
# Compare the scheduled module with the target program; this raises when the
# two differ structurally, so "Pass" is printed only on a match.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

Pass


In [23]:
# Build the untransformed and the transformed module with the LLVM target.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

# Seeded inputs so the cell gives the same data on every run.
rng = np.random.default_rng(0)
a_tvm = tvm.nd.array(rng.random((16, 128, 128), dtype=np.float32))
b_tvm = tvm.nd.array(rng.random((16, 128, 128), dtype=np.float32))
c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype=np.float32))
c_ref_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype=np.float32))

# Check that the schedule did not change the result before timing anything.
before_rt_lib["bmm_relu"](a_tvm, b_tvm, c_ref_tvm)
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_ref_tvm.numpy(), rtol=1e-5)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  44.0825      44.0825      44.0825      44.0825       0.0000                  
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   6.2044       6.2044       6.2044       6.2044       0.0000                  
