In [1]:
!pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.12.dev785%2Bgfc09f562f-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.0/52.0 MB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.12.dev785+gfc09f562f


# 1. 如何编写 TensorIR

In [2]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T 

In [3]:
# Two 4x4 integer inputs: `a` counts up from 0, `b` counts down from 16.
a = np.arange(0, 16).reshape((4, 4))
b = np.arange(16, 0, -1).reshape((4, 4))
print(a, b, sep="\n")

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
[[16 15 14 13]
 [12 11 10  9]
 [ 8  7  6  5]
 [ 4  3  2  1]]


In [4]:
# Reference result from NumPy's vectorized element-wise add.
c_np = np.add(a, b)
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

先将高级计算抽象转化为低级Python实现

In [5]:
def lnumpy(a: np.ndarray, b: np.ndarray, c:np.ndarray):
  """Element-wise add over a 4x4 grid, written with explicit scalar indexing.

  Stores ``a[i, j] + b[i, j]`` into ``c[i, j]`` for every (i, j) in the 4x4
  grid. ``c`` is filled in place; nothing is returned.
  """
  # np.ndindex(4, 4) yields (i, j) in row-major order, like two nested range(4) loops.
  for row, col in np.ndindex(4, 4):
    c[row, col] = a[row, col] + b[row, col]
# The caller allocates the output buffer; lnumpy writes every element of it.
c_lnumpy = np.empty(shape=(4, 4), dtype=np.int64)
lnumpy(a, b, c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

将低级 NumPy 实现转化为 TensorIR

In [6]:
@tvm.script.ir_module
class MyAdd:
  # TensorIR module with one primitive function: element-wise C = A + B on 4x4 int64 buffers.
  @T.prim_func
  def add(A: T.Buffer((4, 4), "int64"),
          B: T.Buffer((4, 4), "int64"),
          C: T.Buffer((4, 4), "int64")):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both block axes are spatial ("S") and bind one-to-one to loops i and j.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Compile MyAdd for the host CPU, run it on TVM NDArrays, and compare with the NumPy result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

## 1.1 广播加法

In [7]:
# `a` is a 4x4 matrix; `b` is a length-4 vector that is added to every row of `a`.
a = np.arange(0, 16).reshape((4, 4))
b = np.arange(4, 0, -1).reshape((4,))
print(a, b, sep="\n")

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
[4 3 2 1]


In [8]:
# Reference result: NumPy broadcasts the length-4 vector `b` across the rows of `a`.
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

先转化为低级numpy实现

In [9]:
def lnumpy_boargcast_add(a: np.ndarray, b: np.ndarray, c:np.ndarray):
  """Broadcast add over a 4x4 grid, written with explicit scalar indexing.

  Stores ``a[i, j] + b[j]`` into ``c[i, j]`` for every (i, j) in the 4x4
  grid, i.e. the vector ``b`` is added to each row of ``a``. ``c`` is filled
  in place; nothing is returned.
  """
  # np.ndindex(4, 4) yields (i, j) in row-major order, like two nested range(4) loops.
  for row, col in np.ndindex(4, 4):
    c[row, col] = a[row, col] + b[col]
# The caller allocates the output buffer; the loop version writes every element of it.
c_lnumpy = np.empty(shape=(4, 4), dtype=np.int64)
lnumpy_boargcast_add(a, b, c_lnumpy)
c_lnumpy

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

用TensorIR实现

In [10]:
@tvm.script.ir_module
class MyAddBoardCast:
  # TensorIR module for a broadcast add: C[i, j] = A[i, j] + B[j].
  @T.prim_func
  def add(A: T.Buffer((4, 4), "int64"),
          # The shape is written as a real 1-tuple: `(4)` is just the int 4 in Python.
          B: T.Buffer((4,), "int64"),
          C: T.Buffer((4, 4), "int64")):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both axes are spatial; B is indexed by the column axis only.
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vj]
# Compile the broadcast add for the host CPU, run it, and compare with the NumPy result.
rt_lib = tvm.build(MyAddBoardCast, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

## 1.2 二维卷积

In [11]:
# Problem size: batch, input channels, height, width, output channels, kernel size.
N, CI, H, W = 1, 1, 8, 8
CO, K = 2, 3
# Each output extent is the input extent shrunk by K - 1.
OUT_H = H - K + 1
OUT_W = W - K + 1
data = np.arange(N * CI * H * W).reshape((N, CI, H, W))
weight = np.arange(CO * CI * K * K).reshape((CO, CI, K, K))

In [12]:
import torch
# torch.Tensor(...) builds tensors of torch's default floating dtype from the integer arrays;
# the convolution result is cast back to int64 so it can be compared with the TVM output.
data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
conv_torch = (
    torch.nn.functional.conv2d(input=data_torch, weight=weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch.shape


(1, 2, 6, 6)

In [13]:
@tvm.script.ir_module
class MyConv:
  # 2D convolution in NCHW layout, stride 1 and no padding:
  #   CONV[b, k, i, j] = sum over (c, di, dj) of DATA[b, c, i + di, j + dj] * WEIGHT[k, c, di, dj]
  @T.prim_func
  def conv(DATA_TVM: T.Buffer((N, CI, H, W), "int64"),
        WEIGHT_TVM: T.Buffer((CO, CI, K, K), "int64"),
        CONV_TVM: T.Buffer((N, CO, OUT_H, OUT_W), "int64")
    ):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    # Loop extents come from the problem-size constants instead of hard-coded literals.
    for b, k, i, j, c, di, dj in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
      with T.block("CONV_TVM"):
        # b, k, i, j index the output (spatial); c, di, dj are accumulated over (reduce).
        vb, vk, vi, vj, vc, vdi, vdj = T.axis.remap("SSSSRRR", [b, k, i, j, c, di, dj])
        # Zero the accumulator inside the block so the result does not depend on
        # the caller having pre-zeroed the output buffer.
        with T.init():
          CONV_TVM[vb, vk, vi, vj] = T.int64(0)
        CONV_TVM[vb, vk, vi, vj] = CONV_TVM[vb, vk, vi, vj] + DATA_TVM[vb, vc, vi + vdi, vj + vdj] * WEIGHT_TVM[vk, vc, vdi, vdj]
# Compile the convolution for the host CPU, run it, and compare with the torch reference.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
# The output buffer starts out zero-filled.
conv_tvm = tvm.nd.array(np.zeros(shape=(N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)

# 2.如何变换TensorIR

In [14]:
@tvm.script.ir_module
class MyAdd:
  # Element-wise C = A + B on 4x4 int64 buffers; used below as the input to the schedule.
  @T.prim_func
  def add(A: T.Buffer((4, 4), "int64"),
          B: T.Buffer((4, 4), "int64"),
          C: T.Buffer((4, 4), "int64")):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both block axes are spatial ("S") and bind one-to-one to loops i and j.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Build a schedule for MyAdd and fetch the loops around block "C".
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block(name="C", func_name="add")
i, j = sch.get_loops(block)

# Split loop i into a 2 x 2 nest, then annotate each loop.
i0, i1 = sch.split(i, factors=[2, 2])
sch.parallel(i0)   # outer part of i
sch.unroll(i1)     # inner part of i
sch.vectorize(j)   # innermost loop

# Show the transformed module as TVMScript.
IPython.display.Code(data=sch.mod.script(), language="python")

变换批量矩阵乘法。

`bmm_relu` 的低级 NumPy 写法如下：

In [15]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
  """Batched matrix multiply followed by ReLU, written as explicit scalar loops.

  Computes ``C[n] = max(A[n] @ B[n], 0)`` for every batch index ``n``.
  ``C`` is filled in place; nothing is returned.

  Args:
    A: left operands, shape (batch, rows, inner).
    B: right operands, shape (batch, inner, cols).
    C: output buffer, shape (batch, rows, cols).
  """
  # Loop bounds come from the operands, so any consistent shapes work
  # (the notebook uses 16 x 128 x 128).
  batch, rows, inner = A.shape
  cols = B.shape[2]
  # Intermediate buffer holding the matmul result before the ReLU stage.
  Y = np.empty((batch, rows, cols), dtype="float32")
  for n in range(batch):
    for i in range(rows):
      for j in range(cols):
        # Reset the accumulator before reducing over k.
        Y[n, i, j] = 0
        for k in range(inner):
          # A is indexed by (i, k) and B by (k, j): k is the contracted axis.
          Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
  for n in range(batch):
    for i in range(rows):
      for j in range(cols):
        C[n, i, j] = max(Y[n, i, j], 0)
        

In [24]:
@tvm.script.ir_module
class MyBmmRelu:
    # Unscheduled batched matmul followed by ReLU: C[b] = max(A[b] @ B[b], 0).
    # Buffers are declared with the call form T.Buffer(shape, dtype), matching TargetModule.
    @T.prim_func
    def bmm_relu(A: T.Buffer((16, 128, 128), "float32"),
                 B: T.Buffer((16, 128, 128), "float32"),
                 C: T.Buffer((16, 128, 128), "float32")) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the ReLU stage.
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for b, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                # b, i, j are spatial axes; k is the reduction axis.
                vb, vi, vj, vk = T.axis.remap("SSSR", [b, i, j, k])
                with T.init():
                    Y[vb, vi, vj] = T.float32(0)
                Y[vb, vi, vj] = Y[vb, vi, vj] + A[vb, vi, vk] * B[vb, vk, vj]
        for b, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vb, vi, vj = T.axis.remap("SSS", [b, i, j])
                C[vb, vi, vj] = T.max(Y[vb, vi, vj], T.float32(0))

  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:


下面是目标程序，需要将原始程序转化为目标程序

In [25]:
@tvm.script.ir_module
class TargetModule:
    # Hand-written, already-scheduled form of bmm_relu. The notebook compares the
    # scheduled MyBmmRelu against this module with tvm.ir.assert_structural_equal.
    @T.prim_func
    def bmm_relu(A: T.Buffer((16, 128, 128), "float32"), B: T.Buffer((16, 128, 128), "float32"), C: T.Buffer((16, 128, 128), "float32")) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the ReLU stage.
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # The batch loop is parallel.
        for i0 in T.parallel(16):
            # The j axis (extent 128) is tiled as i2_0 (16) x an inner tile of 8.
            for i1, i2_0 in T.grid(128, 16):
                # Stage 1: zero the 8-wide tile of Y (vectorized).
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Stage 2: accumulate over k, tiled as ax1_0 (32) x ax1_1 (4, unrolled),
                # with the 8-wide j tile as the innermost loop.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # Stage 3: ReLU over the same 8-wide tile (vectorized).
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [26]:
# Transform MyBmmRelu step by step into the loop structure of TargetModule.
sch = tvm.tir.Schedule(MyBmmRelu)

# Handles to the matmul block ("Y") and the ReLU block ("C").
block_Y = sch.get_block(name="Y", func_name="bmm_relu")
block_C = sch.get_block(name="C", func_name="bmm_relu")

# Tile j as 16 x (inferred) and k as (inferred) x 4, then order the split loops
# as k-outer, k-inner, j-inner.
b, i, j, k = sch.get_loops(block_Y)
i2_0, ax0 = sch.split(j, factors=[16, None])
ax1_0, ax1_1 = sch.split(k, factors=[None, 4])
sch.reorder(ax1_0, ax1_1, ax0)
sch.parallel(b)
sch.unroll(ax1_1)

# Move the ReLU block under the outer j loop, then separate the reduction's
# init from its update at the outer k loop.
sch.reverse_compute_at(block_C, i2_0)
sch.decompose_reduction(block_Y, ax1_0)

# Vectorize the innermost loop of the ReLU block.
i0, i1, i2_0, i2_1 = sch.get_loops(block_C)
sch.vectorize(i2_1)

# Vectorize the innermost loop of the init block created by decompose_reduction.
block_Y_init = sch.get_block(name="Y_init", func_name="bmm_relu")
b, i, j0, j_1_init = sch.get_loops(block_Y_init)
sch.vectorize(j_1_init)

In [27]:
# Raises if the scheduled module differs structurally from the hand-written target.
tvm.ir.assert_structural_equal(lhs=sch.mod, rhs=TargetModule)
print("Pass")

Pass


In [29]:
# Build the original and the scheduled module for the same host CPU target.
before_rt_lib, after_rt_lib = (
    tvm.build(mod, target="llvm") for mod in (MyBmmRelu, sch.mod)
)

In [30]:
# Seeded inputs so the correctness check below is reproducible across runs.
rng = np.random.default_rng(0)
a_np = rng.random((16, 128, 128), dtype=np.float32)
b_np = rng.random((16, 128, 128), dtype=np.float32)
a_tvm = tvm.nd.array(a_np)
b_tvm = tvm.nd.array(b_np)

# Validate both builds against a NumPy reference before timing them. Each run gets
# a fresh zero-filled output so a kernel that wrote nothing could not pass by
# reusing the previous result.
c_ref = np.maximum(a_np @ b_np, 0)
for lib in (before_rt_lib, after_rt_lib):
  c_tvm = tvm.nd.array(np.zeros((16, 128, 128), dtype="float32"))
  lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
  np.testing.assert_allclose(c_tvm.numpy(), c_ref, rtol=1e-4)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  75.5335      75.5335      75.5335      75.5335       0.0000   
               
After transformation
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  20.1290      20.1290      20.1290      20.1290       0.0000   
               
