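## Section 1: How to Write TensorIR
In this section, we write TensorIR by hand, starting from high-level operations such as those in NumPy or Torch. We begin with an element-wise addition function to show how a TensorIR function is written.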

Example: Element-wise Addition
First, let's write an element-wise addition function in NumPy.

In [2]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np
import IPython

# init data
dtype = "float32"
a = np.arange(16).reshape(4, 4).astype(dtype)
b = np.arange(16, 0, -1).reshape(4, 4).astype(dtype)

# numpy version
c_np = a + b
c_np

array([[16., 16., 16., 16.],
       [16., 16., 16., 16.],
       [16., 16., 16., 16.],
       [16., 16., 16., 16.]], dtype=float32)

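Before writing TensorIR directly, we should first translate the high-level computation (e.g., ndarray + ndarray) into a low-level Python implementation: plain loops with element-wise access and operations.

Note that the initial contents of the output array (or buffer) are not always 0, so our implementation must write or initialize every output element. This matters most for reduction operators such as matmul and conv.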
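
To make the initialization point concrete, here is a minimal sketch of a low-level matmul on plain Python lists. The function name `lmatmul` and the garbage values in `c` are illustrative assumptions; `c` stands in for an uninitialized buffer such as one returned by `np.empty`. The `k == 0` check resets each output element before accumulation starts.

```python
# Low-level matmul on plain nested lists; all names here are illustrative.
def lmatmul(a, b, c, n):
    for i in range(n):
        for j in range(n):
            for k in range(n):
                if k == 0:
                    # First reduction step: overwrite whatever the buffer held.
                    c[i][j] = 0.0
                c[i][j] = c[i][j] + a[i][k] * b[k][j]

n = 2
a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
# Stand-in for an uninitialized output buffer: arbitrary leftover values.
c = [[123.0, -456.0], [789.0, 0.5]]
lmatmul(a, b, c, n)
print(c)
```

Without the reset, the leftover values in `c` would be added into every sum. In TensorIR the same role is played by a `with T.init():` block inside a reduction block.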

In [10]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
  for i in range(4):
    for j in range(4):
      c[i, j] = a[i, j] + b[i, j]
c_lnumpy = np.empty((4, 4), dtype="float32")
lnumpy_add(a, b, c_lnumpy)
c_lnumpy

array([[16., 16., 16., 16.],
       [16., 16., 16., 16.],
       [16., 16., 16., 16.],
       [16., 16., 16., 16.]], dtype=float32)

In [11]:
# TensorIR version
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "float32"],
          B: T.Buffer[(4, 4), "float32"],
          C: T.Buffer[(4, 4), "float32"]):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype="float32"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

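The original assignment used int64 as the element dtype in this example, but running it always raised an assertion error; after switching to float32 it runs correctly. Other students reported the same issue in the discussion forum.

That completes our first TensorIR function. Please take a moment to work through the following exercises.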

Exercise 1: Broadcast Addition
Write a TensorIR function that adds two arrays with broadcasting.

In [13]:
# init data
dtype = "float32"
a = np.arange(16).reshape(4, 4).astype(dtype)
b = np.arange(4, 0, -1).reshape(4).astype(dtype)
# numpy version
c_np = a + b
a, b, c_np

(array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]], dtype=float32),
 array([4., 3., 2., 1.], dtype=float32),
 array([[ 4.,  4.,  4.,  4.],
        [ 8.,  8.,  8.,  8.],
        [12., 12., 12., 12.],
        [16., 16., 16., 16.]], dtype=float32))
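
Following the same pattern as the element-wise example, the broadcast can first be spelled out as explicit loops. This is a sketch on plain Python lists rather than NumPy arrays, and `lbroadcast_add` is an illustrative name; the input values mirror the cell above.

```python
# Loop-level broadcast add on plain lists; lbroadcast_add is an illustrative name.
def lbroadcast_add(a, b, c):
    for i in range(4):
        for j in range(4):
            # b has shape (4,), so it is indexed by the column j only.
            c[i][j] = a[i][j] + b[j]

a = [[float(4 * i + j) for j in range(4)] for i in range(4)]  # like arange(16).reshape(4, 4)
b = [4.0, 3.0, 2.0, 1.0]                                      # like arange(4, 0, -1)
c = [[0.0] * 4 for _ in range(4)]
lbroadcast_add(a, b, c)
print(c)
```

The only difference from element-wise addition is that `b` drops the row index, which is what the TensorIR block body has to express as well.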

Complete the following IRModule MyAdd and run the code to check your implementation.

In [16]:
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "float32"],
          B: T.Buffer[(4,), "float32"],
          C: T.Buffer[(4, 4), "float32"]):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    # TODO
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        
        # Index through the block axes vi, vj rather than the raw loop variables.
        C[vi, vj] = A[vi, vj] + B[vj]

rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype="float32"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

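Exercise 2: 2D Convolution
Next, let's try something more challenging: 2D convolution, a common operation in image processing.

Here is the mathematical definition of convolution with the NCHW layout: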

$$
\mathrm{Conv}[b, k, i, j] = \sum_{di, dj, q} A[b, q, \mathrm{strides} * i + di, \mathrm{strides} * j + dj] * W[k, q, di, dj],
$$
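where A is the input tensor, W is the weight tensor, b is the batch index, k is the output channel, i and j are the indices along the image height and width, di and dj are the indices into the weight, q is the input channel, and strides is the stride of the filter window.

In this exercise, we pick a small and simple case: stride=1, padding=0.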
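
As a bridge from the formula to TensorIR, here is a sketch of the same convolution written as seven nested loops on plain Python lists, specialized to stride 1 and no padding. The name `lconv` is an illustrative assumption; the sizes and input values follow the exercise setup (arange-style data and weights).

```python
# Loop-level NCHW convolution with stride 1 and no padding, on plain lists.
# lconv is an illustrative name; sizes follow the exercise.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1

def lconv(data, weight, out):
    for b in range(N):
        for k in range(CO):
            for i in range(OUT_H):
                for j in range(OUT_W):
                    out[b][k][i][j] = 0.0  # initialize before reducing
                    for di in range(K):
                        for dj in range(K):
                            for q in range(CI):
                                out[b][k][i][j] += (
                                    data[b][q][i + di][j + dj] * weight[k][q][di][dj]
                                )

# Same values as np.arange(...).reshape(...) in row-major order.
data = [[[[float(((b * CI + q) * H + h) * W + w) for w in range(W)]
          for h in range(H)] for q in range(CI)] for b in range(N)]
weight = [[[[float(((k * CI + q) * K + di) * K + dj) for dj in range(K)]
            for di in range(K)] for q in range(CI)] for k in range(CO)]
out = [[[[0.0] * OUT_W for _ in range(OUT_H)] for _ in range(CO)] for _ in range(N)]

lconv(data, weight, out)
print(out[0][0][0][:2], out[0][1][0][0])
```

The four outer loops (b, k, i, j) address one output element and correspond to spatial axes; the three inner loops (di, dj, q) are the reduction axes.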

In [23]:
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
data = np.arange(N*CI*H*W).reshape(N, CI, H, W).astype("float32")
weight = np.arange(CO*CI*K*K).reshape(CO, CI, K, K).astype("float32")
# torch version
import torch
data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
conv_torch = conv_torch.numpy().astype("float32")
data, data.shape, weight, weight.shape, conv_torch, conv_torch.shape

(array([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11., 12., 13., 14., 15.],
          [16., 17., 18., 19., 20., 21., 22., 23.],
          [24., 25., 26., 27., 28., 29., 30., 31.],
          [32., 33., 34., 35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44., 45., 46., 47.],
          [48., 49., 50., 51., 52., 53., 54., 55.],
          [56., 57., 58., 59., 60., 61., 62., 63.]]]], dtype=float32),
 (1, 1, 8, 8),
 array([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]]],
 
 
        [[[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]]], dtype=float32),
 (2, 1, 3, 3),
 array([[[[ 474.,  510.,  546.,  582.,  618.,  654.],
          [ 762.,  798.,  834.,  870.,  906.,  942.],
          [1050., 1086., 1122., 1158., 1194., 1230.],
          [1338., 1374., 1410., 1446., 1482., 1518.],
          [1626., 1662., 1698., 1734., 1770., 1806.],
          [1914., 1950., 1986., 2022., 2058., 2094.]],
 
         [[1203., 132

Complete the following IRModule MyConv and run the code to check your implementation.

In [27]:
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  def conv(Data:  T.Buffer[(1, 1, 8, 8), "float32"],
           Weight:T.Buffer[(2, 1, 3, 3), "float32"],
           Conv:  T.Buffer[(1, 2, 6, 6), "float32"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    # TODO
    strides = 1
    padding = 0
    for b, k, i, j, di, dj, q in T.grid(1, 2, 6, 6, 3, 3, 1):
      with T.block("Conv"):
        v_b = T.axis.spatial(1, b)
        v_k = T.axis.spatial(2, k)
        v_i = T.axis.spatial(6, i)
        v_j = T.axis.spatial(6, j)
        
        v_di = T.axis.reduce(3, di)
        v_dj = T.axis.reduce(3, dj)
        v_q  = T.axis.reduce(1, q)
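        # Note: at each output coordinate (b, k, i, j) we sum over the three reduction axes di, dj, and q.
        # The sum is accumulated across loop iterations, so the output element must be initialized first.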
        with T.init():
          Conv[v_b, v_k, v_i, v_j] = T.float32(0)
        Conv[v_b, v_k, v_i, v_j] += Data[v_b, v_q, strides*v_i+v_di, strides*v_j+v_dj] * Weight[v_k, v_q, v_di, v_dj]

rt_lib = tvm.build(MyConv, target="llvm")
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype="float32"))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)

## Section 2: How to Transform TensorIR

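In the lecture, we learned that TensorIR is not only a programming language but also an abstraction for program transformation. In this section, let's try transforming a program. We use bmm_relu (batched_matmul_relu), a variant of an operation common in models such as Transformers.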

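Parallelization, Vectorization, and Loop Unrolling
First, we introduce some new primitives: parallel, vectorize, and unroll. All three are applied to loops and indicate how the loop should be executed. Here is an example: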

In [7]:
def code2html(code):
    """Helper function to use pygments to turn the code string into highlighted html."""
    import pygments
    from pygments.lexers import Python3Lexer
    from pygments.formatters import HtmlFormatter
    formatter = HtmlFormatter()
    html = pygments.highlight(code, Python3Lexer(), formatter)
    return "<style>%s</style>%s\n" % (formatter.get_style_defs(".highlight"), html)

@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
i0, i1 = sch.split(i, factors=[2, 2])
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
IPython.display.HTML(code2html(sch.mod.script()))

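parallel distributes the iterations of a loop across multiple threads, providing thread-level parallelism.

vectorize uses SIMD, i.e., a single instruction operates on multiple data elements.

unroll performs loop unrolling, a common optimization in traditional compilers and a classic way to improve IPC (instructions per cycle).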
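
The loop structure produced by `sch.split(i, factors=[2, 2])` followed by unrolling the inner loop can be pictured in plain Python. This is only a sketch of the iteration space: it models neither threads nor SIMD, and the function names are illustrative.

```python
# Plain-Python picture of split + unroll on a 4-iteration loop.
def original(body):
    for i in range(4):
        body(i)

def split_and_unrolled(body):
    # split(i, factors=[2, 2]): the original index is rebuilt as i0 * 2 + i1.
    for i0 in range(2):
        # The inner loop over i1 is unrolled by writing its body out twice.
        body(i0 * 2 + 0)
        body(i0 * 2 + 1)

seen_a, seen_b = [], []
original(seen_a.append)
split_and_unrolled(seen_b.append)
print(seen_a, seen_b)
```

Both versions visit the same indices. In the scheduled program the outer loop i0 is marked parallel, so the order across i0 values is not guaranteed; that is safe here because each iteration writes a different element of C.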

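Exercise 3: Transform a Batched Matrix Multiplication Program
Now, let's return to the bmm_relu exercise. First, here is the definition of bmm followed by relu: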
$$
\begin{aligned}
Y_{n, i, j} &= \sum_k A_{n, i, k} \times B_{n, k, j} \\
C_{n, i, j} &= \mathrm{relu}(Y_{n, i, j}) = \max(Y_{n, i, j}, 0)
\end{aligned}
$$
Now it is your turn to write the TensorIR for bmm_relu. We provide the lnumpy function as a hint:

In [3]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((16, 128, 128), dtype="float32")
    for n in range(16):
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(16):
        for i in range(128):
            for j in range(128):
                C[n, i, j] = max(Y[n, i, j], 0)

In [27]:
@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], 
               B: T.Buffer[(16, 128, 128), "float32"], 
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # TODO
    Y = T.alloc_buffer([16, 128, 128], dtype="float32")
    for n, i, j, k in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        v_n = T.axis.spatial(16,  n)
        v_i = T.axis.spatial(128, i)
        v_j = T.axis.spatial(128, j)
        v_k = T.axis.reduce (128, k)

        with T.init():
          Y[v_n, v_i, v_j] = T.float32(0)
        Y[v_n, v_i, v_j] += A[v_n, v_i, v_k] * B[v_n, v_k, v_j]

    for n, i, j in T.grid(16, 128, 128):
      with T.block("C"):
        v_n = T.axis.spatial(16,  n)
        v_i = T.axis.spatial(128, i)
        v_j = T.axis.spatial(128, j)

        C[v_n, v_i, v_j] = T.max(T.float32(0), Y[v_n, v_i, v_j])


sch = tvm.tir.Schedule(MyBmmRelu)
IPython.display.HTML(code2html(sch.mod.script()))
# Also please validate your result
rt_lib = tvm.build(MyBmmRelu, target="llvm")
A = np.ones(16*128*128).reshape(16, 128, 128).astype("float32")
B = np.ones(16*128*128).reshape(16, 128, 128).astype("float32")
C = np.empty((16, 128, 128), dtype="float32")

lnumpy_mm_relu_v2(A, B, C)

A_tvm = tvm.nd.array(A)
B_tvm = tvm.nd.array(B)
C_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype="float32"))
rt_lib["bmm_relu"](A_tvm, B_tvm, C_tvm)
# Note: NumPy does not recognize TVM's array type and raises a TypeError, so convert it to a NumPy array first
np.testing.assert_allclose(C, C_tvm.numpy(), rtol=1e-5)


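In this exercise, we focus on transforming the original program into a specific target. Note that, because hardware differs, the target program may not be the best-performing program. The goal of the exercise is to understand how to transform a program into the one you want. Here is the target program: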

In [37]:
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

Your task is to transform the original program into the target program.
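
Before touching the schedule, it can help to check that the target loop structure computes the same values as the original. The sketch below emulates both loop nests in plain Python. The sizes (N=2, M=8, tiles TJ=4 and TK=2) are small illustrative assumptions; the exercise itself uses 16 x 128 x 128 with tiles of 8 and 4. All function names are hypothetical.

```python
import random

# Small illustrative sizes; the exercise uses 16 x 128 x 128 with tiles 8 and 4.
N, M, TJ, TK = 2, 8, 4, 2

def zeros():
    return [[[0.0] * M for _ in range(M)] for _ in range(N)]

def bmm_relu_reference(A, B, C):
    # Original structure: one full reduction nest, then a separate relu nest.
    Y = zeros()
    for n in range(N):
        for i in range(M):
            for j in range(M):
                for k in range(M):
                    Y[n][i][j] += A[n][i][k] * B[n][k][j]
    for n in range(N):
        for i in range(M):
            for j in range(M):
                C[n][i][j] = max(Y[n][i][j], 0.0)

def bmm_relu_target_shape(A, B, C):
    # Target structure: j and k are split, init/update/relu share the j0 loop.
    Y = zeros()
    for n in range(N):                      # parallel in the target
        for i in range(M):
            for j0 in range(M // TJ):
                for j1 in range(TJ):        # "Y_init", vectorized in the target
                    Y[n][i][j0 * TJ + j1] = 0.0
                for k0 in range(M // TK):
                    for k1 in range(TK):    # unrolled in the target
                        for j1 in range(TJ):
                            j, k = j0 * TJ + j1, k0 * TK + k1
                            Y[n][i][j] += A[n][i][k] * B[n][k][j]
                for j1 in range(TJ):        # block "C", vectorized in the target
                    j = j0 * TJ + j1
                    C[n][i][j] = max(Y[n][i][j], 0.0)

random.seed(0)
# Small integer-valued floats keep the arithmetic exact.
A = [[[float(random.randint(-3, 3)) for _ in range(M)] for _ in range(M)] for _ in range(N)]
B = [[[float(random.randint(-3, 3)) for _ in range(M)] for _ in range(M)] for _ in range(N)]
C_ref, C_new = zeros(), zeros()
bmm_relu_reference(A, B, C_ref)
bmm_relu_target_shape(A, B, C_new)
print(C_ref == C_new)
```

The relu can be applied right after each tile of j because, by that point, the reduction over k for that tile has already finished. This is the property that moving block C under the j0 loop relies on.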

In [38]:
sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name="bmm_relu")

# Step 2. Get loops
n, i, j, k = sch.get_loops(Y)
c_n, c_i, c_j = sch.get_loops(C)

# Step 3. Organize the loops
j_0, j_1 = sch.split(j, factors=[16, 8])
k_0, k_1 = sch.split(k, factors=[32, 4])
sch.reorder(n, i, j_0, k_0, k_1, j_1)
IPython.display.HTML(code2html(sch.mod.script()))

sch.reverse_compute_at(C, j_0)
# ...
# Step 4. decompose reduction
Y_init = sch.decompose_reduction(Y, k_0)
# ...
IPython.display.HTML(code2html(sch.mod.script()))
# # Step 5. vectorize / parallel / unroll
# sch.vectorize(j_1)
# sch.parallel(n)
# sch.unroll(...)
# ...

# IPython.display.HTML(code2html(sch.mod.script()))

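(Optional) To make sure the transformed program is exactly the same as the given target, we can use assert_structural_equal. Note that this step is optional in this exercise. If you have transformed the program toward the target and obtained a performance improvement, that is enough.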

In [39]:
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

TVMError: Traceback (most recent call last):
  File "D:\a\utils\utils\tvm\src\node\structural_equal.cc", line 123
ValueError: StructuralEqual check failed, caused by lhs:
for (n: int32, 0, 16) {
  for (i: int32, 0, 128) {
    for (j_0: int32, 0, 16) {
      for (j_1_init: int32, 0, 8) {
        block([16, 128, 128], "Y_init") as [v_n, v_i, v_j] {
          bind(v_n, n)
          bind(v_i, i)
          bind(v_j, ((j_0*8) + j_1_init))
          tir.reads([])
          tir.writes([Y: Buffer(Y_1: Pointer(global float32), float32, [16, 128, 128], [])[v_n, v_i, v_j]])
          Y[v_n, v_i, v_j] = 0f32
      }
      for (k_0: int32, 0, 32) {
        for (k_1: int32, 0, 4) {
          for (j_1: int32, 0, 8) {
            block([16, 128, 128, tir.reduce_axis(0, 128)], "Y_update") as [v_n_1, v_i_1, v_j_1, v_k] {
              bind(v_n_1, n)
              bind(v_i_1, i)
              bind(v_j_1, ((j_0*8) + j_1))
              bind(v_k, ((k_0*4) + k_1))
              tir.reads([Y[v_n_1, v_i_1, v_j_1], A: Buffer(A_1: Pointer(global float32), float32, [16, 128, 128], [])[v_n_1, v_i_1, v_k], B: Buffer(B_1: Pointer(global float32), float32, [16, 128, 128], [])[v_n_1, v_k, v_j_1]])
              tir.writes([Y[v_n_1, v_i_1, v_j_1]])
              Y[v_n_1, v_i_1, v_j_1] = (Y[v_n_1, v_i_1, v_j_1] + (A[v_n_1, v_i_1, v_k]*B[v_n_1, v_k, v_j_1]))
          }
        }
      }
      for (ax0: int32, 0, 8) {
        block([16, 128, 128], "C") as [v_n_2, v_i_2, v_j_2] {
          bind(v_n_2, n)
          bind(v_i_2, i)
          bind(v_j_2, ((j_0*8) + ax0))
          tir.reads([Y[v_n_2, v_i_2, v_j_2]])
          tir.writes([C: Buffer(C_1: Pointer(global float32), float32, [16, 128, 128], [])[v_n_2, v_i_2, v_j_2]])
          C[v_n_2, v_i_2, v_j_2] = max(0f32, Y[v_n_2, v_i_2, v_j_2])
      }
    }
  }
}
and rhs:
for (i0: int32, 0, 16) "parallel" {
  for (i1: int32, 0, 128) {
    for (i2_0: int32, 0, 16) {
      for (ax0_init: int32, 0, 8) "vectorized" {
        block([16, 128, 128], "Y_init") as [n, i, j] {
          bind(n, i0)
          bind(i, i1)
          bind(j, ((i2_0*8) + ax0_init))
          tir.reads([])
          tir.writes([Y: Buffer(Y_1: Pointer(global float32), float32, [16, 128, 128], [])[n, i, j]])
          Y[n, i, j] = 0f32
      }
      for (ax1_0: int32, 0, 32) {
        for (ax1_1: int32, 0, 4) "unroll" {
          for (ax0: int32, 0, 8) {
            block([16, 128, 128, tir.reduce_axis(0, 128)], "Y_update") as [n_1, i_1, j_1, k] {
              bind(n_1, i0)
              bind(i_1, i1)
              bind(j_1, ((i2_0*8) + ax0))
              bind(k, ((ax1_0*4) + ax1_1))
              tir.reads([Y[n_1, i_1, j_1], A: Buffer(A_1: Pointer(global float32), float32, [16, 128, 128], [])[n_1, i_1, k], B: Buffer(B_1: Pointer(global float32), float32, [16, 128, 128], [])[n_1, k, j_1]])
              tir.writes([Y[n_1, i_1, j_1]])
              Y[n_1, i_1, j_1] = (Y[n_1, i_1, j_1] + (A[n_1, i_1, k]*B[n_1, k, j_1]))
          }
        }
      }
      for (i2_1: int32, 0, 8) "vectorized" {
        block([16, 128, 128], "C") as [n_2, i_2, j_2] {
          bind(n_2, i0)
          bind(i_2, i1)
          bind(j_2, ((i2_0*8) + i2_1))
          tir.reads([Y[n_2, i_2, j_2]])
          tir.writes([C: Buffer(C_1: Pointer(global float32), float32, [16, 128, 128], [])[n_2, i_2, j_2]])
          C[n_2, i_2, j_2] = max(Y[n_2, i_2, j_2], 0f32)
      }
    }
  }
}

The check fails here because Step 5 is still commented out in the schedule above: the lhs loops carry no parallel, vectorized, or unroll annotations, whereas the rhs (the target) does. The two sides also write the relu operands in a different order, max(0f32, Y) on the lhs versus max(Y, 0f32) on the rhs.

In [40]:
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")
a_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
b_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
c_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  20.4089      20.4089      20.4089      20.4089       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   2.4941       2.4941       2.4941       2.4941       0.0000   
               
