In [1]:
import tvm
import tvm.testing
from tvm import te
import numpy
import timeit

In [3]:
# Problem size: the product computed throughout is (M, K) x (K, N).
# Other shapes are worth trying; for some of them the TVM-optimized kernel
# outperforms numpy backed by MKL.
M, K, N = 1024, 1024, 1024

# Element type of every tensor below (TVM's default tensor type).
dtype = "float32"

# Generic LLVM CPU target.
# For best performance switch this to "llvm -mcpu=core-avx2" (Intel AVX2,
# Advanced Vector Extensions, for SIMD) or to the specific CPU type in use.
target = "llvm"
dev = tvm.device(target, 0)

# Randomly generated input matrices, placed on the target device.
a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), dev)

In [5]:
# See how numpy handles the same problem: time numpy.dot on equally sized inputs.
np_repeat = 100
np_runing_time = timeit.timeit(
    # The setup snippet creates its own random operands with the same
    # M, K, N as this notebook; the trailing "" yields the final newline.
    setup="\n".join(
        [
            "import numpy",
            f"M = {M}",
            f"K = {K}",
            f"N = {N}",
            'dtype = "float32"',
            "a = numpy.random.rand(M, K).astype(dtype)",
            "b = numpy.random.rand(K, N).astype(dtype)",
            "",
        ]
    ),
    stmt="answer = numpy.dot(a, b)",
    number=np_repeat,
)
# timeit returns the total for all repetitions; report the per-call average.
print("Numpy running time: %f" % (np_runing_time / np_repeat))

# Reference product of the actual inputs a and b, computed with numpy.
answer = numpy.dot(a.numpy(), b.numpy())

Numpy running time: 0.001303


In [7]:
# Algorithm
# The standard way to express the computation in TVM:
# C[m, n] = sum over k of A[m, k] * B[k, n].
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), "k")
C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")

# Default schedule: no transformations applied.
s = te.create_schedule(C.op)
func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func

# Run once into a zero-initialized output and compare with the numpy result.
c = tvm.nd.array(numpy.zeros(shape=(M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

# Time the baseline kernel (a single run per measurement).
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
baseline_time = evaluator(a, b, c).mean
print("Baseline: %f" % baseline_time)

# Show the lowered IR of the default schedule.
print(tvm.lower(s, [A, B, C], simple_mode=True))

Baseline: 3.455440
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m, n in T.grid(1024, 1024):
            C_1 = T.Buffer((1048576,), data=C.data)
            C_1[m * 1024 + n] = T.float32(0.0)
            for k in range(1024):
                cse_var_2: T.int32 = m * 1024
                cse_var_1: T.int32 = cse_var_2 + n
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[cse_var_2 + k] * B_1[k * 1024 + n]


In [9]:
# bn: tile size passed as both tiling factors for the output axes.
# kfactor: split factor for the reduction axis.
bn, kfactor = 32, 4

# Start again from a fresh, untransformed schedule for C.
s = te.create_schedule(C.op)

In [11]:
# Blocking by loop tiling: tile the two output axes with factors bn x bn.
mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], x_factor=bn, y_factor=bn)

# C has a single reduction axis; split it into outer/inner parts by kfactor.
kaxis = s[C].op.reduce_axis[0]
ko, ki = s[C].split(kaxis, factor=kfactor)

In [13]:
# Print the lowered IR of the current schedule `s` to inspect its loop nest.
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m_outer, n_outer, m_inner, n_inner in T.grid(32, 32, 32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            C_1[m_outer * 32768 + m_inner * 1024 + n_outer * 32 + n_inner] = T.float32(0.0)
            for k_outer, k_inner in T.grid(256, 4):
                cse_var_3: T.int32 = n_outer * 32
                cse_var_2: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_1: T.int32 = cse_var_2 + cse_var_3 + n_inner
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[cse_var_2 + k_outer * 4 + k_inner] * B_1[k_

In [15]:
# Hoist reduction domain outside the blocking loop
s[C].reorder(mo, no, ko, ki, mi, ni)

# (Translated note) One cache block is 32, so the loops need to be reordered.
# NOTE(review): the 32 presumably refers to the tile size bn -- confirm.

In [17]:
# Print the lowered IR of the current schedule `s` to inspect its loop nest.
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m_outer, n_outer in T.grid(32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            for m_inner_init, n_inner_init in T.grid(32, 32):
                C_1[m_outer * 32768 + m_inner_init * 1024 + n_outer * 32 + n_inner_init] = T.float32(0.0)
            for k_outer, k_inner, m_inner, n_inner in T.grid(256, 4, 32, 32):
                cse_var_3: T.int32 = n_outer * 32
                cse_var_2: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_1: T.int32 = cse_var_2 + cse_var_3 + n_inner
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[c

In [19]:
# Build the current schedule (tiled, reduction split, reordered) into a kernel.
func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func

# Run once into a zero-initialized output and compare with the numpy result.
c = tvm.nd.array(numpy.zeros(shape=(M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

# By simply tiling the loop 32x32, and hoisting ko, ki outside the blocking loops,
# we can see big speedup compared with the baseline.
evaluator = func.time_evaluator(func.entry_name, dev, number=10)
opt1_time = evaluator(a, b, c).mean
print("Opt1: %f" % opt1_time)

Opt1: 0.208868


In [21]:
# Print the lowered IR of the current schedule `s` to inspect its loop nest.
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m_outer, n_outer in T.grid(32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            for m_inner_init, n_inner_init in T.grid(32, 32):
                C_1[m_outer * 32768 + m_inner_init * 1024 + n_outer * 32 + n_inner_init] = T.float32(0.0)
            for k_outer, k_inner, m_inner, n_inner in T.grid(256, 4, 32, 32):
                cse_var_3: T.int32 = n_outer * 32
                cse_var_2: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_1: T.int32 = cse_var_2 + cse_var_3 + n_inner
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[c

In [31]:
# Opt2: rebuild the blocked schedule from scratch, then vectorize.
s = te.create_schedule(C.op)

# Tile the two output axes by bn x bn and split the single reduction axis.
mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], x_factor=bn, y_factor=bn)
kaxis = s[C].op.reduce_axis[0]
ko, ki = s[C].split(kaxis, factor=kfactor)

# Move the reduction loops outside the inner tile loops.
s[C].reorder(mo, no, ko, ki, mi, ni)

# Vectorization
s[C].vectorize(ni)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func

# Run once into a zero-initialized output and compare with the numpy result.
c = tvm.nd.array(numpy.zeros(shape=(M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

evaluator = func.time_evaluator(func.entry_name, dev, number=10)
opt2_time = evaluator(a, b, c).mean
print("Opt2: %f" % opt2_time)

Opt2: 0.223579


In [33]:
# Print the lowered IR of the current schedule `s` to inspect its loop nest.
print(tvm.lower(s, [A, B, C], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m_outer, n_outer in T.grid(32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            for m_inner_init in range(32):
                C_1[m_outer * 32768 + m_inner_init * 1024 + n_outer * 32:m_outer * 32768 + m_inner_init * 1024 + n_outer * 32 + 32] = T.Broadcast(T.float32(0.0), 32)
            for k_outer, k_inner, m_inner in T.grid(256, 4, 32):
                cse_var_3: T.int32 = n_outer * 32
                cse_var_2: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_1: T.int32 = cse_var_2 + cse_var_3
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
   

In [35]:
# Opt3: same tiling and vectorization as before, with a different loop order.
s = te.create_schedule(C.op)

# Tile the two output axes by bn x bn and split the single reduction axis.
mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], x_factor=bn, y_factor=bn)
kaxis = s[C].op.reduce_axis[0]
ko, ki = s[C].split(kaxis, factor=kfactor)

# re-ordering: mi now sits between ko and ki (previously ko, ki, mi, ni).
s[C].reorder(mo, no, ko, mi, ki, ni)
s[C].vectorize(ni)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func

# Run once into a zero-initialized output and compare with the numpy result.
c = tvm.nd.array(numpy.zeros(shape=(M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

evaluator = func.time_evaluator(func.entry_name, dev, number=10)
opt3_time = evaluator(a, b, c).mean
print("Opt3: %f" % opt3_time)

Opt3: 0.111549


In [37]:
# Opt4: array packing.
# We have to re-write the algorithm slightly: B is first repacked so that
# packedB[bigN, k, littleN] == B[k, bigN * bn + littleN], i.e. the bn columns
# of one tile end up together in the last dimension of packedB.
# The extent uses integer division so the shape consists of ints only
# (this assumes N is a multiple of bn, as with N = 1024 and bn = 32).
packedB = te.compute(
    (N // bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
)
# C is redefined to read from packedB instead of B; the reduction axis `k`
# is the one created for the original algorithm.
C = te.compute(
    (M, N),
    lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
    name="C",
)

s = te.create_schedule(C.op)

# Same blocking, reduction split, loop order and vectorization as Opt3.
mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(kaxis,) = s[C].op.reduce_axis
ko, ki = s[C].split(kaxis, factor=kfactor)

s[C].reorder(mo, no, ko, mi, ki, ni)
s[C].vectorize(ni)

# Schedule the packing stage itself: vectorize its innermost axis and
# parallelize over its outermost axis.
bigN, _, littleN = s[packedB].op.axis
s[packedB].vectorize(littleN)
s[packedB].parallel(bigN)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func

# Run once into a zero-initialized output and compare with the numpy result.
c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

evaluator = func.time_evaluator(func.entry_name, dev, number=10)
opt4_time = evaluator(a, b, c).mean
print("Opt4: %f" % opt4_time)

Opt4: 0.124942


In [39]:
# Opt5: write cache for blocks, plus unrolling of the inner reduction loop.
s = te.create_schedule(C.op)

# Allocate write cache
CC = s.cache_write(C, "global")

mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], x_factor=bn, y_factor=bn)

# Write cache is computed at no
s[CC].compute_at(s[C], no)

# New inner axes
mc, nc = s[CC].op.axis

# The reduction now lives on the cache stage; split, reorder and vectorize there.
kaxis = s[CC].op.reduce_axis[0]
ko, ki = s[CC].split(kaxis, factor=kfactor)
s[CC].reorder(ko, mc, ki, nc)
s[CC].vectorize(nc)

# TODO: Add separate optimization step to discuss loop unrolling
# unrolling is a loop optimization strategy which can reduce branch
# prediction failures and increases the chance of concurrent execution
# unroll kfactor loops
s[CC].unroll(ki)

# Packing stage: vectorize its innermost axis, parallelize its outermost axis.
bigN, _, littleN = s[packedB].op.axis
s[packedB].vectorize(littleN)
s[packedB].parallel(bigN)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func

# Run once into a zero-initialized output and compare with the numpy result.
c = tvm.nd.array(numpy.zeros(shape=(M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

evaluator = func.time_evaluator(func.entry_name, dev, number=10)
opt5_time = evaluator(a, b, c).mean
print("Opt5: %f" % opt5_time)

Opt5: 0.117523


In [41]:
# Opt6: the write-cache schedule with the outermost tile loop parallelized.
s = te.create_schedule(C.op)

# Write cache for C, computed at the `no` loop of the tiled output.
CC = s.cache_write(C, "global")

mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], x_factor=bn, y_factor=bn)

s[CC].compute_at(s[C], no)

# Axes of the cache stage.
mc, nc = s[CC].op.axis

# Split the cache stage's single reduction axis, then reorder, vectorize
# the innermost axis and unroll the inner reduction loop.
kaxis = s[CC].op.reduce_axis[0]
ko, ki = s[CC].split(kaxis, factor=kfactor)
s[CC].reorder(ko, mc, ki, nc)
s[CC].vectorize(nc)
s[CC].unroll(ki)

# parallel
s[C].parallel(mo)

# Packing stage: vectorize its innermost axis, parallelize its outermost axis.
bigN, _, littleN = s[packedB].op.axis
s[packedB].vectorize(littleN)
s[packedB].parallel(bigN)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func

# Run once into a zero-initialized output and compare with the numpy result.
c = tvm.nd.array(numpy.zeros(shape=(M, N), dtype=dtype), dev)
func(a, b, c)
tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

# This cell uses number=50 for the measurement (earlier optimized cells use 10).
evaluator = func.time_evaluator(func.entry_name, dev, number=50)
opt6_time = evaluator(a, b, c).mean
print("Opt6: %f" % opt6_time)

Opt6: 0.005964
