简单练习

作为练习，尝试不同的 `jfactor` 取值（即 `transform` 中拆分 `j` 轴时内层循环的长度），看看它们如何影响代码的性能。

In [2]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np
import IPython

In [3]:
@tvm.script.ir_module
class MyModule:
    # IRModule holding one primitive function, `mm_relu`, which computes
    # C = relu(A @ B) for 128x128 float32 matrices in two stages:
    # a matmul into the intermediate buffer Y (block "Y"), then an
    # element-wise ReLU from Y into C (block "C").
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with T.block("root")
        # Intermediate matmul result, allocated inside the function (not an argument).
        Y = T.alloc_buffer([128, 128], dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # "SSR": i and j are spatial axes, k is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(Y[vi, vj])
                with T.init():
                    # Zero the accumulator before the first reduction step.
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                # Element-wise ReLU: C = max(Y, 0).
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

In [4]:
def transform(mod, jfactor, func_name="mm_relu"):
    """Tile the ``j`` loop of block ``"Y"`` and fuse block ``"C"`` under it.

    Splits the ``j`` loop of block ``"Y"`` into ``(j0, j1)`` with inner extent
    ``jfactor``, reorders the loop nest to ``i, j0, k, j1`` and moves block
    ``"C"`` under ``j0`` with ``reverse_compute_at``.

    Args:
        mod: IRModule containing the function to transform.
        jfactor: Extent of the inner ``j1`` loop produced by the split.
        func_name: Name of the primitive function that holds blocks ``"Y"``
            and ``"C"``. Defaults to ``"mm_relu"``.

    Returns:
        The transformed IRModule.
    """
    sch = tvm.tir.Schedule(mod)
    block_Y = sch.get_block("Y", func_name=func_name)
    _, j, k = sch.get_loops(block_Y)
    # factors=[None, jfactor]: the outer extent is inferred from the loop length.
    j0, j1 = sch.split(j, factors=[None, jfactor])
    # Place k between j0 and j1 so one j0 tile of Y is fully reduced per iteration.
    sch.reorder(j0, k, j1)
    block_C = sch.get_block("C", func_name=func_name)
    sch.reverse_compute_at(block_C, j0)
    return sch.mod

mod_transformed = transform(MyModule, jfactor=8)

dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)

a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")

rt_lib_transformed = tvm.build(mod_transformed, "llvm")
# Before timing, check that the transformed kernel still computes relu(A @ B).
# A faster schedule is only useful if it is also correct. The reference is
# computed in float64 so the only error left is the kernel's float32 rounding.
rt_lib_transformed["mm_relu"](a_nd, b_nd, c_nd)
c_ref = np.maximum(a_np.astype("float64") @ b_np.astype("float64"), 0)
np.testing.assert_allclose(c_nd.numpy(), c_ref, rtol=1e-5)
f_timer_transformed = rt_lib_transformed.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed mod_transformed %g sec" % f_timer_transformed(a_nd, b_nd, c_nd).mean)
# display the code below
IPython.display.Code(mod_transformed.script(), language="python")

Time cost of transformed mod_transformed 0.000471282 sec


# TensorIR练习

- 如何编写
    - 示例：逐元素相加
    - 1.广播加法
    - 2.二维卷积

- 如何变换
    - 示例：并行化，向量化与循环展开
    - 3.变换批量矩阵乘法程序
    - 构建与评估

## 如何编写

In [11]:
#=======================================
print("===================numpy实现====================")
# First, implement the computation with NumPy.
# init data
# dtype is pinned to int64 to match the "int64" TensorIR buffers below.
# NumPy's default integer type is platform-dependent (e.g. int32 on Windows
# with NumPy < 2), which would otherwise cause a dtype mismatch in TVM.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(16, 0, -1, dtype=np.int64).reshape(4, 4)
print(a)
print(b)

# numpy version
c_np = a + b
print(c_np)

#========================================
print("====================low-level numpy===================")
# Before writing TensorIR directly, first translate the high-level computation
# (e.g. ndarray + ndarray) into a low-level Python implementation: plain loops
# with explicit element reads, writes and arithmetic.

# Note that the initial value of the output array (buffer) is not always 0.
# The implementation must write or initialize it explicitly, which matters
# for reduction operators such as matmul and conv.
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Element-wise add ``a`` and ``b`` into the preallocated ``c`` using explicit loops.

    Mirrors the loop nest of the TensorIR ``add`` kernel. The loop bounds come
    from ``c``'s shape instead of being hard-coded to 4x4, so any 2-D arrays
    with matching shapes work.

    Args:
        a: 2-D input array, at least as large as ``c`` in both dimensions.
        b: 2-D input array, at least as large as ``c`` in both dimensions.
        c: 2-D output array; every element is overwritten in place.
    """
    rows, cols = c.shape
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j] + b[i, j]
# Run the loop-based version into a fresh 4x4 int64 buffer and show the result.
c_lnumpy = np.empty(shape=(4, 4), dtype="int64")
lnumpy_add(a, b, c_lnumpy)
print(c_lnumpy)

#========================================
print("====================TensorIR===================")
# TensorIR version: the same element-wise add written as a TVMScript IRModule.
@tvm.script.ir_module
class MyAdd:
  # IRModule with one primitive function `add`: C = A + B for 4x4 int64 buffers.
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    # Symbol name used to look the function up after tvm.build (rt_lib["add"]).
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both block axes are spatial (no reduction), each with extent 4.
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Compile the IRModule for the host CPU, run it, and compare with NumPy.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype="int64"))
add_fn = rt_lib["add"]
add_fn(a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)


[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
[[16 15 14 13]
 [12 11 10  9]
 [ 8  7  6  5]
 [ 4  3  2  1]]
[[16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]]
[[16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]]


In [19]:
# Exercise 1: broadcast add
#=======================================
print("===================numpy实现====================")
# init data
# Pin int64 so the arrays match the "int64" TensorIR buffers; NumPy's default
# integer width is platform-dependent.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(4, 0, -1, dtype=np.int64)  # already 1-D with 4 elements
print(a)
print(b)
# numpy version: b (shape (4,)) is broadcast across every row of a
c_np = a + b
print(c_np)

#=======================================
print("===================TensorIR====================")

@tvm.script.ir_module
class MyAdd:
  # Broadcast add: C[i, j] = A[i, j] + B[j] for a 4x4 A and a length-4 B.
  # NOTE: `(4)` in B's annotation is the integer 4, not a tuple; it is
  # presumably accepted by TVMScript as a 1-D shape of 4 elements.
  @T.prim_func
  def add(
    A: T.Buffer[(4, 4), "int64"],
    B: T.Buffer[(4), "int64"],
    C: T.Buffer[(4, 4), "int64"]
  ):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        # B is indexed by the column only, so the same row of B is added to every row of A.
        C[vi, vj] = A[vi, vj] + B[vj]

# Compile, run, show and verify the broadcast add against the NumPy result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype="int64"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
result = c_tvm.numpy()
print(result)
np.testing.assert_allclose(actual=result, desired=c_np, rtol=1e-5)




[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
[4 3 2 1]
[[ 4  4  4  4]
 [ 8  8  8  8]
 [12 12 12 12]
 [16 16 16 16]]
[[ 4  4  4  4]
 [ 8  8  8  8]
 [12 12 12 12]
 [16 16 16 16]]


In [25]:
# Exercise 2: 2-D convolution
#=======================================
print("===================torch实现====================")
stride = 1
pad = 0
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = (H - K + pad*2)//stride + 1, (W - K + pad*2)//stride + 1
# Pin int64 to match the "int64" TensorIR buffers (NumPy's default int is platform-dependent).
data = np.arange(N*CI*H*W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO*CI*K*K, dtype=np.int64).reshape(CO, CI, K, K)
# torch version
import torch
# float64 keeps every integer partial sum exact over a much wider range than
# float32 (exact only below 2**24); rounding before the int cast avoids truncation.
data_torch = torch.tensor(data, dtype=torch.float64)
weight_torch = torch.tensor(weight, dtype=torch.float64)
# Pass the configured stride/padding so the reference follows OUT_H/OUT_W above.
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch, stride=stride, padding=pad)
conv_torch = conv_torch.numpy().round().astype(np.int64)
print(conv_torch)


#=======================================
print("===================TensorIR====================")

@tvm.script.ir_module
class MyConv:
  # IRModule with one primitive function `conv`: a direct NCHW 2-D convolution,
  # output[b, k, i, j] = sum over (di, dj, q) of data[b, q, i + di, j + dj] * weight[k, q, di, dj].
  
  @T.prim_func
  def conv(
    data: T.Buffer[(1, 1, 8, 8), "int64"],
    weight: T.Buffer[(2, 1, 3, 3), "int64"],
    output: T.Buffer[(1, 2, 6, 6), "int64"]
  ):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    # Loop extents use the Python names N, CO, OUT_H, OUT_W, K, CI (presumably
    # resolved from the enclosing notebook scope when TVMScript parses this
    # module), while buffer shapes and axis extents are literals. Both must
    # agree: (1, 2, 6, 6, 3, 3, 1).
    # NOTE(review): indexing uses vi + vdi with no stride factor and no padding,
    # so this kernel is only correct for stride == 1 and pad == 0.
    for b, k, i, j, di, dj, q in T.grid(N, CO, OUT_H, OUT_W, K, K, CI):
      with T.block("output"):
        # Reduction axes: kernel row, kernel column, input channel.
        vdi = T.axis.reduce(3, di)
        vdj = T.axis.reduce(3, dj)
        vq = T.axis.reduce(1, q)
        # Spatial axes: batch, output channel, output row, output column.
        vb = T.axis.spatial(1, b)
        vk = T.axis.spatial(2, k)
        vi = T.axis.spatial(6, i)
        vj = T.axis.spatial(6, j)
        with T.init():
          # Reset the accumulator before the first reduction step.
          output[vb, vk, vi, vj] = T.int64(0)
        output[vb, vk, vi, vj] = output[vb, vk, vi, vj] + data[vb, vq, vi+vdi, vj+vdj] * weight[vk,vq,vdi,vdj]

# Build the convolution, run it on the same inputs and check it against PyTorch.
rt_lib = tvm.build(MyConv, target="llvm")
out_shape = (N, CO, OUT_H, OUT_W)
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty(out_shape, dtype="int64"))
conv_fn = rt_lib["conv"]
conv_fn(data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)

[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]


### 经验

这是使用 NCHW 布局的卷积的数学定义：
$$Conv[b, k, i, j] =
    \sum_{di, dj, q} A[b, q, strides * i + di, strides * j + dj] * W[k, q, di, dj],$$
其中，`A` 是输入张量，`W` 是权重张量，`b` 是批次索引，`k` 是输出通道，`i` 和 `j` 是图像高度和宽度的索引，`di` 和 `dj` 是权重的索引，`q` 是输入通道，`strides` 是过滤器窗口的步幅。

- 1. 我们的目标是计算结果张量：根据结果张量的维度确定空间轴 (b, k, i, j)，根据计算过程中需要累加的循环确定归约轴 (di, dj, q)，然后按照上面的数学公式代入即可。
- 2. 不要忘记在 `T.init()` 中初始化输出！

## 如何变换

In [26]:
# Example: parallelization, vectorization and loop unrolling
# Input program: a plain 4x4 int64 element-wise add, C = A + B.
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Two spatial axes of extent 4; these loops are what the schedule rewrites.
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Schedule `add`: split loop i into (i0, i1) with extents (2, 2), mark i0 as
# parallel, unroll i1 and vectorize the innermost loop j.
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
i0, i1 = sch.split(i, factors=[2, 2])
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
# NOTE: IPython is already imported at the top of the notebook; re-importing is harmless.
import IPython
IPython.display.Code(sch.mod.script(), language="python")

In [29]:
# Exercise 3: transform a batched matrix multiplication program
#=======================================
print("===================low-level numpy实现====================")
dtype = "float32"
# 16 batches of 128x128 matrices; A is drawn first, then B (same RNG order).
bmm_shape = (16, 128, 128)
a_np = np.random.rand(*bmm_shape).astype(dtype)
b_np = np.random.rand(*bmm_shape).astype(dtype)

def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Compute batched ``C = relu(A @ B)`` with explicit loops (low-level NumPy).

    Mirrors the TensorIR ``bmm_relu`` kernel: a reduction into the intermediate
    ``Y``, followed by an element-wise ReLU into ``C``. Sizes are read from the
    arguments instead of being fixed at 16x128x128.

    Args:
        A: Array of shape ``(batch, rows, depth)``.
        B: Array of shape ``(batch, depth, cols)``.
        C: Preallocated output of shape ``(batch, rows, cols)``, overwritten in place.
    """
    batch, rows, depth = A.shape
    cols = B.shape[2]
    # float32 accumulator, like the TensorIR buffer Y. np.zeros (not np.empty)
    # keeps an empty reduction (depth == 0) well defined as 0.
    Y = np.zeros((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(depth):
                    # Reduction init on the first k step, mirroring T.init().
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

# Reference result for the TensorIR kernel, computed with explicit Python loops.
c_np = np.empty(shape=(16, 128, 128), dtype=dtype)
lnumpy_mm_relu_v2(A=a_np, B=b_np, C=c_np)



In [56]:
#=======================================
print("===================TensorIR实现====================")
# IRModule with one primitive function `bmm_relu`: batched C = relu(A @ B)
# over 16 batches of 128x128 float32 matrices (the TensorIR counterpart of
# lnumpy_mm_relu_v2).
@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(
    A: T.Buffer[(16,128,128), "float32"],
    B: T.Buffer[(16,128,128), "float32"],
    C: T.Buffer[(16,128,128), "float32"]
  ):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Stage 1 (block "Y"): batched matmul into the intermediate buffer Y.
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")
    for n, i, j, k in T.grid(16,128,128,128):
        with T.block("Y"):
            # n, i, j are spatial axes; k is the reduction axis.
            vn = T.axis.spatial(16, n)
            vi = T.axis.spatial(128, i)
            vj = T.axis.spatial(128, j)
            vk = T.axis.reduce(128, k)
            with T.init():
                # Zero the accumulator before the first k step.
                Y[vn, vi, vj] = T.float32(0)
            Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
    # Stage 2 (block "C"): element-wise ReLU, C = max(Y, 0).
    for n, i, j in T.grid(16,128,128):
        with T.block("C"):
            vn = T.axis.spatial(16, n)
            vi = T.axis.spatial(128, i)
            vj = T.axis.spatial(128, j)
            C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))

# Print the untransformed program, then build it and validate it against the
# low-level NumPy reference c_np.
sch = tvm.tir.Schedule(MyBmmRelu)
print(sch.mod.script())
rt_lib = tvm.build(MyBmmRelu, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a_np), tvm.nd.array(b_np)
c_tvm = tvm.nd.array(np.empty(shape=(16, 128, 128), dtype="float32"))
bmm_relu_fn = rt_lib["bmm_relu"]
bmm_relu_fn(a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

# sch = tvm.tir.Schedule(MyBmmRelu)
# IPython.display.Code(sch.mod.script(), language="python")

# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # body
        # with T.block("root")
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                T.reads(A[vn, vi, vk], B[vn, vk, vj])
                T.writes(Y[vn, vi, vj])
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                T.re

In [58]:
# Target output of the transformation: the schedule applied to MyBmmRelu
# must produce a module structurally equal to this one.
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # The batch loop is parallel; every block below is nested inside it.
        for i0 in T.parallel(16):
            # i2_0 walks the 128 columns in 16 tiles of 8.
            for i1, i2_0 in T.grid(128, 16):
                # Y_init: zero one 8-wide tile of Y (vectorized).
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Y_update: reduction over k = ax1_0 * 4 + ax1_1, with the
                # inner 4-step loop unrolled.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # C: ReLU of the finished 8-wide tile (vectorized).
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [57]:
sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.
# NOTE: the order of the schedule primitives below determines the final loop
# nest, which is later checked against TargetModule with assert_structural_equal.

# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
# `...` below is a leftover placeholder expression; it has no effect.
...

# Step 2. Get loops
# Loops of block "Y": batch, row, column, reduction.
b, i, j, k = sch.get_loops(Y)
# Run the batch loop in parallel.
sch.parallel(b)
print(sch.mod.script())

# Step 3. Organize the loops
# Split the column loop into j0 (outer) x j1 (inner, extent 8) and reorder to
# b, i, j0, k, j1 so one 8-wide tile of Y is fully reduced per j0 iteration.
j0, j1 = sch.split(j, factors=[None, 8])
sch.reorder(j0,k, j1)
print(sch.mod.script())
# Move the ReLU block "C" under j0 so it runs right after its tile of Y is computed.
block_C = sch.get_block("C", func_name="bmm_relu")
sch.reverse_compute_at(block_C, j0)
print(sch.mod.script())

# Step 4. decompose reduction
# Split block "Y" into an init block placed before loop k and an update block
# (fetched by name as "Y_update" below).
Y_init = sch.decompose_reduction(Y, k)
n, i, j_0, j_1_init = sch.get_loops(Y_init)
# Innermost loop of block "C" after reverse_compute_at.
_, _, _, ax0 = sch.get_loops(block_C)
Y_update_block = sch.get_block("Y_update", func_name="bmm_relu")
_, _, _, k, j_1 = sch.get_loops(Y_update_block)
# Split the reduction loop into 32 x 4; the inner k1 is unrolled in Step 5.
k0, k1 = sch.split(k, factors=[32, 4])


# # Step 5. vectorize / parallel / unroll
sch.vectorize(j_1_init)
sch.vectorize(ax0)
sch.unroll(k1)

print(sch.mod.script())

# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # body
        # with T.block("root")
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for n in T.parallel(16):
            for i, j, k in T.grid(128, 128, 128):
                with T.block("Y"):
                    vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                    T.reads(A[vn, vi, vk], B[vn, vk, vj])
                    T.writes(Y[vn, vi, vj])
                    with T.init():
                        Y[vn, vi, vj] = T.float32(0)
                    Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi

In [59]:
# Verify that the scheduled module is structurally identical to TargetModule.
# assert_structural_equal raises on any difference, so reaching the print means success.
scheduled_mod = sch.mod
tvm.ir.assert_structural_equal(scheduled_mod, TargetModule)
print("Pass")

Pass


### 评估性能

In [60]:
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")
# Keep the NumPy inputs so the transformed kernel's output can be checked.
a_eval = np.random.rand(16, 128, 128).astype("float32")
b_eval = np.random.rand(16, 128, 128).astype("float32")
a_tvm = tvm.nd.array(a_eval)
b_tvm = tvm.nd.array(b_eval)
c_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
# Check the transformed kernel against a float64 NumPy reference before timing it.
expected = np.maximum(np.matmul(a_eval.astype("float64"), b_eval.astype("float64")), 0)
np.testing.assert_allclose(c_tvm.numpy(), expected, rtol=1e-5)
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  47.4496      47.4496      47.4496      47.4496       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   3.9868       3.9868       3.9868       3.9868       0.0000   
               
