"""Example code to do square matrix multiplication."""
import tvm
import os
from tvm.contrib import nvcc
from tvm.contrib import spirv
import numpy as np
def tvm_callback_cuda_compile(code):
    """Compile CUDA source text to PTX with ``nvcc.compile_cuda``.

    Args:
        code: CUDA C source produced by TVM code generation.

    Returns:
        Whatever ``nvcc.compile_cuda(code, target="ptx")`` returns
        (presumably the PTX image — confirm against tvm.contrib.nvcc).

    NOTE(review): nothing visible in this file registers this function with
    TVM (e.g. via ``@tvm.register_func``); confirm it is actually invoked as
    the CUDA compile hook.
    """
    return nvcc.compile_cuda(code, target="ptx")
def write_code(code, fname):
    """Write ``code`` to the file ``fname``, replacing any existing contents.

    Args:
        code: text to write (assumed to be a str).
        fname: destination path. Its parent directory must already exist,
            because ``open(..., "w")`` does not create directories.
    """
    with open(fname, "w") as f:
        # The original body was missing, so nothing was ever written.
        f.write(code)
def tvm_callback_cuda_postproc(code, use_manual_code=False):
    """Save generated CUDA source to disk and optionally swap in a hand-written kernel.

    The incoming source is always written to ``perf/<TASK>_generated.cu`` so
    it can be inspected after a run.

    Args:
        code: CUDA source string produced by TVM code generation.
        use_manual_code: when True, the returned source is the contents of
            ``perf/<TASK>_manual.cu`` instead of ``code``. Defaults to False,
            so a one-argument call returns ``code`` unchanged.

    Returns:
        The CUDA source that should be compiled.

    NOTE(review): relies on a module-level ``TASK`` string that is not
    visible in this view, and nothing visible registers this function with
    TVM (e.g. ``@tvm.register_func``) — confirm both.
    """
    # Create the output directory; writing into "perf/" fails without it.
    if not os.path.exists("perf"):
        os.mkdir("perf")
    # The path needs a %s placeholder; "perf/" % TASK raised TypeError.
    write_code(code, "perf/%s_generated.cu" % TASK)
    if use_manual_code:
        # Close the handle deterministically instead of leaking it.
        with open("perf/%s_manual.cu" % TASK) as f:
            code = f.read()
    return code
def test_gemm():
    """Build a tiled square GEMM schedule, then verify and time it on each available backend.

    The computation is ``C[ii, jj] = sum_k A[k, jj] * B[k, ii]``, i.e.
    ``C = B.T @ A`` with ``A`` of shape ``(l, n)``, ``B`` of shape ``(l, m)``
    and ``C`` of shape ``(m, n)``; all three sizes equal ``nn`` here.

    For every target in the device list whose context exists, the schedule
    is built, its output is compared against NumPy, and the kernel is timed.
    """
    # graph
    nn = 2048
    n = tvm.convert(nn)
    m, l = n, n
    A = tvm.placeholder((l, n), name='A')
    B = tvm.placeholder((l, m), name='B')
    k = tvm.reduce_axis((0, l), name='k')
    C = tvm.compute(
        (m, n),
        lambda ii, jj: tvm.sum(A[k, jj] * B[k, ii], axis=k),
        name='C')

    # schedule
    s = tvm.create_schedule(C.op)
    # Stage both operands through "shared" and then "local" caches, and
    # accumulate the result in a "local" cache stage (CC) before writing C.
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    scale = 8
    num_thread = 8
    # Edge length of the C tile assigned to one (blockIdx.y, blockIdx.x) pair.
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, 2), "vthread", name="vy")

    # Tile C over the block grid.
    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    # Inside a block: a 2x2 grid of virtual threads, each spread over a
    # num_thread x num_thread grid of real threads.
    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)

    # Each thread accumulates its sub-tile in the local CC stage.
    s[CC].compute_at(s[C], tx)
    yo, xo = CC.op.axis
    # Reduction tiling: shared loads happen once per ko step (8 k values),
    # local loads once per kt step.
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)

    # Schedule for A's shared memory load: threads of the block cooperate.
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)

    # Schedule for B's shared memory load, mirroring A's.
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)

    # correctness
    def check_device(device):
        """Build for ``device``, check against NumPy, and print timing; skip if absent."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)
        # launch the kernel.
        n, m, l = nn, nn, nn
        # Host buffers use the same shapes as the placeholders:
        # A is (l, n), B is (l, m), C is (m, n).
        a_np = np.random.uniform(size=(l, n)).astype(A.dtype)
        b_np = np.random.uniform(size=(l, m)).astype(B.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((m, n), dtype=C.dtype), ctx)
        for _ in range(2):
            f(a, b, c)
        # C[ii, jj] = sum_k A[k, jj] * B[k, ii]  ==  (B.T @ A)[ii, jj]
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

        num_flops = 2 * nn * nn * nn
        num_runs = 10
        timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
        t = timer_f(a, b, c).mean
        GFLOPS = num_flops / (t * 1e3) / 1e6
        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))

    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        # Explicit unrolling is requested for every backend except "cuda".
        with tvm.build_config(auto_unroll_max_step=128,
                              unroll_explicit=(device != "cuda")):
            check_device(device)
if __name__ == "__main__":
    # Run the GEMM build/verify/benchmark when executed as a script.
    test_gemm()