In [1]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np
import IPython

In [5]:
# Input operands for the element-wise add example.
# The dtype is spelled out because the TensorIR kernels in this notebook declare
# "int64" buffers, while np.arange's default integer width is platform dependent.
a = np.arange(16, dtype="int64").reshape(4, 4)
b = np.arange(16, 0, -1, dtype="int64").reshape(4, 4)

# High-level NumPy reference result: every element sums to 16.
c_np = a + b
print(c_np)

[[16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]]


In [7]:
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
  """Low-level NumPy reference for a 4x4 element-wise add.

  Walks the 4x4 index space in row-major order and stores ``a + b`` into
  ``c`` one scalar at a time. ``c`` is modified in place; nothing is returned.
  """
  for idx in np.ndindex(4, 4):
    c[idx] = a[idx] + b[idx]

# Run the loop-level NumPy reference into a fresh, zero-filled int64 output.
c_lnp = np.zeros(shape=(4, 4), dtype="int64")
lnumpy_add(a, b, c_lnp)

# TensorIR module holding one primitive function: a 4x4 int64 element-wise add.
@tvm.script.ir_module
class Tadd:
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            B: T.Buffer[(4, 4), "int64"],
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                # Both block axes are spatial and bound one-to-one to the loops.
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

# Compile the TensorIR module for the host CPU.
tadd = tvm.build(Tadd, target="llvm")

# Copy the NumPy operands into TVM arrays and allocate a zeroed output.
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.zeros(shape=(4, 4), dtype=np.int64))

# Look up the compiled "add" entry point and run it.
add_kernel = tadd["add"]
add_kernel(a_tvm, b_tvm, c_tvm)

# The compiled kernel must agree with the loop-level NumPy reference.
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_lnp, rtol=1e-5)

### 2.5.1.2. 练习 1：广播加法

In [10]:
# Inputs for the broadcast-add exercise: a 4x4 matrix plus a length-4 vector.
# The dtype is explicit because the kernel below declares "int64" buffers and
# np.arange's default integer width is platform dependent.
a = np.arange(16, dtype="int64").reshape(4, 4)
b = np.arange(4, 0, -1, dtype="int64").reshape(4)
# NumPy broadcasting adds b to every row of a; this is the reference result.
c_np = a + b

# TensorIR module for the broadcast add: C[i, j] = A[i, j] + B[j].
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            B: T.Buffer[4, "int64"],
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.spatial(4, i)
                vj = T.axis.spatial(4, j)
                # B is indexed by the column axis only, so it is reused for every row.
                C[vi, vj] = A[vi, vj] + B[vj]



# Compile the broadcast-add module for the host CPU.
rt_lib = tvm.build(MyAdd, target="llvm")

# Move the operands to TVM arrays; the output buffer is left uninitialised
# because the kernel writes every element.
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))

broadcast_add = rt_lib["add"]
broadcast_add(a_tvm, b_tvm, c_tvm)

# Compare against the NumPy broadcasting result.
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

### 2.5.1.3. 练习 2：二维卷积

In [11]:
# Problem sizes: batch, input channels, input height/width, output channels, kernel size.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# Output extent of a stride-1 convolution without padding.
OUT_H, OUT_W = H - K + 1, W - K + 1
# The dtype is explicit because the TensorIR kernel declares "int64" buffers and
# np.arange's default integer width is platform dependent.
data = np.arange(N*CI*H*W, dtype="int64").reshape(N, CI, H, W)
weight = np.arange(CO*CI*K*K, dtype="int64").reshape(CO, CI, K, K)

import torch

# torch.Tensor(...) yields float32 tensors; the reference convolution therefore
# runs in floating point and is cast back to int64 afterwards.
data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
conv_torch = conv_torch.numpy().astype(np.int64)
conv_torch

  from .autonotebook import tqdm as notebook_tqdm


array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [21]:
# TensorIR module for a stride-1, unpadded 2D convolution over int64 buffers:
#   Out[n, co, h, w] = sum over (ci, kh, kw) of
#                      Data[n, ci, h + kh, w + kw] * Weight[co, ci, kh, kw]
# The extents N, CI, H, W, CO, K, OUT_H, OUT_W are captured from the notebook scope.
@tvm.script.ir_module
class MyConv:
    @T.prim_func
    def conv(Data: T.Buffer[(N, CI, H, W), "int64"],
             Weight: T.Buffer[(CO, CI, K, K), "int64"],
             Out: T.Buffer[(N, CO, OUT_H, OUT_W), "int64"]):
        T.func_attr({"global_symbol": "conv", "tir.noalias": True})
        for n, co, h, w, ci, kh, kw in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
            # One reduction block: four spatial output axes followed by three
            # reduction axes (input channel and the two kernel offsets).
            with T.block("conv"):
                vn, vco, vh, vw, vci, vkh, vkw = T.axis.remap(
                    "SSSSRRR", [n, co, h, w, ci, kh, kw])
                # Runs once per output element, at the first reduction step.
                with T.init():
                    Out[vn, vco, vh, vw] = T.int64(0)
                # Index only through block iterators (vkh/vkw, not the loop
                # variables kh/kw) so the block stays self-contained.
                Out[vn, vco, vh, vw] = Out[vn, vco, vh, vw] + Data[vn, vci, vh + vkh, vw + vkw] * Weight[vco, vci, vkh, vkw]


# Compile the convolution module for the host CPU.
rt_lib = tvm.build(MyConv, target="llvm")

# Wrap the inputs as TVM arrays and allocate the (N, CO, OUT_H, OUT_W) output.
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty(shape=(N, CO, OUT_H, OUT_W), dtype=np.int64))

conv_kernel = rt_lib["conv"]
conv_kernel(data_tvm, weight_tvm, conv_tvm)

# The TensorIR result must match the torch reference computed earlier.
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)
print("ok.")

ok.


## 2.5.2. 第二节：如何变换 TensorIR

### 2.5.2.1. 并行化、向量化与循环展开

In [24]:
# 4x4 int64 element-wise add used as the starting point for the schedule
# transformations below. Note: this rebinds the name MyAdd defined earlier.
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            B: T.Buffer[(4, 4), "int64"],
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                # Two spatial axes, bound one-to-one to the surrounding loops.
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

# Create a schedule for MyAdd and fetch the loops around block "C".
sch = tvm.tir.Schedule(MyAdd)
add_block = sch.get_block("C", func_name="add")
row_loop, col_loop = sch.get_loops(add_block)

# Split the row loop 2x2: parallelise the outer half, unroll the inner half,
# and vectorise the column loop.
row_outer, row_inner = sch.split(row_loop, factors=[2, 2])
sch.parallel(row_outer)
sch.unroll(row_inner)
sch.vectorize(col_loop)

# Render the transformed module as TVMScript.
IPython.display.Code(sch.mod.script(), language="python")


In [26]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Low-level NumPy reference for batched matrix multiply followed by ReLU.

    For every batch ``n`` this computes ``C[n] = max(A[n] @ B[n], 0)`` with
    explicit scalar loops, accumulating along ``k`` in ascending order into a
    float32 scratch buffer.

    Loop extents come from the argument shapes instead of being fixed at
    16 x 128 x 128, so the same reference also works on small inputs.

    Args:
        A: left operand, shape ``(batch, rows, inner)``.
        B: right operand, indexed as ``B[n, k, j]``; its last axis gives the
            number of output columns.
        C: output array indexed as ``C[n, i, j]``; overwritten in place.

    Returns:
        None.

    Raises:
        ValueError: if ``A`` is not three-dimensional (from shape unpacking).
    """
    batch, rows, inner = A.shape
    cols = B.shape[2]
    Y = np.empty((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                # Zero the accumulator before reducing so an empty inner
                # dimension yields 0 rather than uninitialised memory.
                Y[n, i, j] = 0
                for k in range(inner):
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    # ReLU applied element-wise on the accumulated products.
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)


# TensorIR module for batched matmul + ReLU on 16 x 128 x 128 float32 buffers.
# Block "Y" accumulates the matrix product into an intermediate buffer and
# block "C" clamps it at zero; both block names are looked up by the schedules.
@tvm.script.ir_module
class MyBmmRelu:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
                 B: T.Buffer[(16, 128, 128), "float32"],
                 C: T.Buffer[(16, 128, 128), "float32"]):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the ReLU.
        Y = T.alloc_buffer((16, 128, 128), "float32")
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                vn = T.axis.spatial(16, n)
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                # k is the reduction axis of the matrix product.
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn = T.axis.spatial(16, n)
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))

            

# Wrap the untransformed MyBmmRelu module in a schedule and render its TVMScript.
sch = tvm.tir.Schedule(MyBmmRelu)
bmm_relu_script = sch.mod.script()
IPython.display.Code(bmm_relu_script, language="python")
# Exercise reminder: the result still has to be validated.

In [61]:
"""target ir:
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))
"""

# Start again from the untransformed module.
sch = tvm.tir.Schedule(MyBmmRelu)

# Step 1: handles to the matmul block and the ReLU block.
block_Y = sch.get_block("Y", func_name="bmm_relu")
block_C = sch.get_block("C", func_name="bmm_relu")

# Step 2: loops around the matmul block; the batch loop runs in parallel.
n_loop, i_loop, j_loop, k_loop = sch.get_loops(block_Y)
sch.parallel(n_loop)

# Step 3: tile j as 16 x 8 and k as 32 x 4, move the 8-wide j tile innermost,
# then pull the ReLU block under the outer j loop.
j_outer, j_inner = sch.split(j_loop, factors=[16, 8])
k_outer, k_inner = sch.split(k_loop, factors=[32, 4])
sch.reorder(n_loop, i_loop, j_outer, k_outer, k_inner, j_inner)
sch.reverse_compute_at(block_C, j_outer)

# Step 4: separate the reduction's init from its update; the init block is
# inserted in front of the outer k loop.
block_Y_init = sch.decompose_reduction(block_Y, k_outer)

# Step 5: unroll the inner k loop and vectorise the innermost loops of the
# init block and of the ReLU block.
sch.unroll(k_inner)
*_, init_j = sch.get_loops(block_Y_init)
*_, relu_j = sch.get_loops(block_C)
sch.vectorize(init_j)
sch.vectorize(relu_j)

IPython.display.Code(sch.mod.script(), language="python")
# Not enabled: tvm.ir.assert_structural_equal(sch.mod, MyBmmRelu)

In [62]:
# Build the untransformed module and the scheduled module for the host CPU.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

# Seeded generator so the validation inputs are reproducible across runs.
rng = np.random.default_rng(0)
a = rng.random((16, 128, 128)).astype("float32")
b = rng.random((16, 128, 128)).astype("float32")
# c starts with random contents so a kernel that fails to write its output
# cannot pass the comparison by accident.
c = rng.random((16, 128, 128)).astype("float32")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(c)

# Loop-level NumPy reference; overwrites c in place.
lnumpy_mm_relu_v2(a, b, c)

# Validate the transformed module against the reference.
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c, rtol=1e-5)

# Validate the baseline as well, into its own output buffer, so the timing
# comparison below is between two kernels known to be correct.
c_before_tvm = tvm.nd.array(np.zeros((16, 128, 128), dtype="float32"))
before_rt_lib["bmm_relu"](a_tvm, b_tvm, c_before_tvm)
np.testing.assert_allclose(c_before_tvm.numpy(), c, rtol=1e-5)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  33.7151      33.7151      33.7151      33.7151       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   2.9314       2.9314       2.9314       2.9314       0.0000   
               
