In [1]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

## section 1: how to write `TensorIR`

### exercise 1: broadcast add

In [2]:
# Inputs for the broadcast add: a 4x4 matrix and a length-4 vector.
# The dtype is pinned to int64 because NumPy's default integer width is
# platform-dependent, while the `add` prim_func declares int64 buffers.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(4, 0, -1, dtype=np.int64).reshape(4)

In [3]:
# NumPy reference result: `b` is broadcast across every row of `a`.
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [4]:
# TensorIR module for the broadcast add C[i, j] = A[i, j] + B[j] on int64 buffers.
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    def add(
        A: T.Buffer((4, 4), "int64"),
        B: T.Buffer(4, "int64"),
        C: T.Buffer((4, 4), "int64"),
    ):
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                # Both axes are spatial: every output element is written exactly
                # once, so no T.init() is needed (there is no reduction axis to
                # initialise an accumulator for).
                vi, vj = T.axis.remap("SS", [i, j])
                # B is indexed by the column only, which broadcasts it over rows.
                C[vi, vj] = A[vi, vj] + B[vj]

In [5]:
# Compile MyAdd for the "llvm" target; the "add" function is looked up on the result below.
rt_lib = tvm.build(MyAdd, target="llvm")

In [6]:
# Wrap the NumPy inputs as TVM arrays and allocate an uninitialised int64 output.
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype="int64"))

In [7]:
# Look up the compiled "add" function and run it; the result is written into c_tvm.
add_kernel = rt_lib["add"]
add_kernel(a_tvm, b_tvm, c_tvm)
c_tvm

<tvm.nd.NDArray shape=(4, 4), cpu(0)>
array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [8]:
# The TVM kernel must reproduce the NumPy broadcast-add reference.
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

### exercise 2: 2D convolution

$$ \operatorname{conv}_2[b,k,i,j] = \sum_{d_i, d_j, q} A[b,q,\operatorname{stride}\cdot i+d_i,\operatorname{stride}\cdot j+d_j]\cdot W[k,q,d_i, d_j]
$$

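where
* $b$ is the batch index
* $k$ is the output channel index
* $i, j$ are the output pixel location
* $d_i, d_j$ are the indices into the weight kernel
* $q$ is the input channel index
* $\operatorname{stride}$ is the stride of the kernel

The exercise below uses $\operatorname{stride} = 1$ and no padding, so the output has spatial size $(H - K + 1) \times (W - K + 1)$.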

In [9]:
# Problem sizes: batch, input channels, height, width, output channels, kernel size.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# Output size of a stride-1 convolution without padding.
OUT_H, OUT_W = H - K + 1, W - K + 1
# The dtype is pinned to int64 because NumPy's default integer width is
# platform-dependent, while the `conv` prim_func declares int64 buffers.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)

In [10]:
# torch version: reference result computed with PyTorch's conv2d
import torch

data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
# torch.Tensor(...) yields float tensors, so convert the result back to
# int64 for comparison with the integer TVM output.
conv_torch = (
    torch.nn.functional.conv2d(data_torch, weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [11]:
# TensorIR module for a stride-1, unpadded 2D convolution:
#   C[b, k, i, j] = sum over (d_i, d_j, q) of A[b, q, i + d_i, j + d_j] * D[k, q, d_i, d_j]
@tvm.script.ir_module
class MyConv:
    @T.prim_func
    def conv(
        A: T.Buffer((N, CI, H, W), "int64"),
        D: T.Buffer((CO, CI, K, K), "int64"),
        C: T.Buffer((N, CO, OUT_H, OUT_W), "int64"),
    ):
        T.func_attr({"global_symbol": "conv", "tir.noalias": True})
        # The spatial loops must cover the OUTPUT extent (OUT_H, OUT_W), not the
        # input extent (H, W): iterating to H/W writes C out of bounds and reads
        # A out of bounds at i + d_i / j + d_j.
        for b, k, i, j, d_i, d_j, q in T.grid(N, CO, OUT_H, OUT_W, K, K, CI):
            with T.block("C"):
                # Four spatial axes (one per output index) followed by three
                # reduction axes (kernel row, kernel column, input channel).
                vb, vk, vi, vj, vd_i, vd_j, vq = T.axis.remap(
                    "SSSSRRR", [b, k, i, j, d_i, d_j, q]
                )
                # Zero the accumulator at the start of each output element's reduction.
                with T.init():
                    C[vb, vk, vi, vj] = T.int64(0)
                C[vb, vk, vi, vj] = C[vb, vk, vi, vj] + A[vb, vq, vi + vd_i, vj + vd_j] * D[vk, vq, vd_i, vd_j]

In [12]:
# Compile the convolution module, stage inputs and output as TVM arrays, then run it.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty(shape=(N, CO, OUT_H, OUT_W), dtype="int64"))
conv_kernel = rt_lib["conv"]
conv_kernel(data_tvm, weight_tvm, conv_tvm)

In [13]:
# Display the TVM result for visual comparison with the PyTorch output above.
conv_tvm

<tvm.nd.NDArray shape=(1, 2, 6, 6), cpu(0)>
array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [14]:
# The TVM convolution must reproduce the PyTorch reference.
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)

## section 2: how to transform `TensorIR`

### exercise 3: transform a batch matmul

In [15]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Low-level NumPy reference for a batched matmul followed by ReLU.

    Computes ``C[n, i, j] = max(sum_k A[n, i, k] * B[n, k, j], 0)`` using
    explicit loops and a float32 intermediate buffer ``Y``.

    Loop extents are derived from the operand shapes rather than being
    hard-coded to (16, 128, 128), so the reference can also be exercised on
    small inputs; behaviour for (16, 128, 128) operands is unchanged.

    Args:
        A: Left operand, indexed as ``A[n, i, k]``.
        B: Right operand, indexed as ``B[n, k, j]``.
        C: Output array, indexed as ``C[n, i, j]``; overwritten in place.

    Returns:
        None. The result is written into ``C``.
    """
    batch, rows, inner = A.shape
    cols = B.shape[2]
    Y = np.empty((batch, rows, cols), dtype="float32")
    # Stage 1: batched matrix multiplication into the intermediate buffer.
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                # Reset the accumulator before the reduction over k, so Y is
                # well-defined even when the reduction extent is zero.
                Y[n, i, j] = 0
                for k in range(inner):
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    # Stage 2: element-wise ReLU from the intermediate buffer into the output.
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

In [None]:
# TensorIR module for a batched matmul followed by ReLU, mirroring the loop
# structure of `lnumpy_mm_relu_v2`:
#   Y[n, i, j] = sum over k of A[n, i, k] * B[n, k, j]
#   C[n, i, j] = max(Y[n, i, j], 0)
@tvm.script.ir_module
class MyBmmRelu:
    @T.prim_func
    def bmm_relu(
        A: T.Buffer((16, 128, 128), "float32"),
        B: T.Buffer((16, 128, 128), "float32"),
        C: T.Buffer((16, 128, 128), "float32"),
    ):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the ReLU.
        Y = T.alloc_buffer((16, 128, 128), dtype="float32")
        # Stage 1: batched matmul; n, i, j are spatial and k is the reduction axis.
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                # Zero the accumulator at the start of each output element's reduction.
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        # Stage 2: element-wise ReLU from Y into the output buffer.
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))