In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [5]:
import numpy as np
import tvm
from tvm import te

# Problem dimensions: one 3-channel 227x227 image, 64 filters of 3x3 per layer.
batch = 1
img_size = 227
img_ch = 3
filter_size = 3
filter_num = 64

# Symbolic tensors for the two-layer network:
#   In : input image batch
#   W1 : first-layer weights  (img_ch -> filter_num channels)
#   W2 : second-layer weights (filter_num -> filter_num channels)
In = te.placeholder(
    (batch, img_ch, img_size, img_size),
    name="In",
)
W1 = te.placeholder(
    (filter_num, img_ch, filter_size, filter_size),
    name="W1",
)
W2 = te.placeholder(
    (filter_num, filter_num, filter_size, filter_size),
    name="W2",
)

# Show the three declared shapes, one per line.
print(In.shape, W1.shape, W2.shape, sep="\n")

[1, 3, 227, 227]
[64, 3, 3, 3]
[64, 64, 3, 3]


In [6]:
def Conv(Input, Kernel, stride, pad, name):
    """Build a TE expression for a direct 2-D convolution.

    The tensors are indexed as ``Input[batch, channel, row, col]`` and
    ``Kernel[out_channel, in_channel, kernel_row, kernel_col]``. Only the
    third extent of each shape is read, so the feature map and the kernel are
    both treated as square.

    Parameters
    ----------
    Input : te.Tensor
        4-D input tensor.
    Kernel : te.Tensor
        4-D weight tensor.
    stride : int
        Step between consecutive output positions, applied to rows and columns.
    pad : int
        Number of zero rows/columns added on every side of the feature map.
        No padding stage is created when ``pad`` is not positive.
    name : str
        Prefix used to name the generated stages and reduction axes
        (``<name>_pad``, ``<name>_conv``, ``<name>_rc`` ...).

    Returns
    -------
    te.Tensor
        Tensor of shape ``(batch, out_channel, out_size, out_size)`` where
        ``out_size = (in_size - kernel_size + 2 * pad) // stride + 1``.
    """
    batch_size, in_channel, in_size, _ = Input.shape
    out_channel, _, kernel_size, _ = Kernel.shape

    out_size = (in_size - kernel_size + 2 * pad) // stride + 1

    data = Input
    if pad > 0:
        # Zero-pad in the same (batch, channel, row, col) order the
        # convolution below reads from. The previous version produced a
        # (row, col, channel, batch) tensor here, which did not match the
        # indexing used by the reduction.
        padded_size = in_size + 2 * pad
        data = te.compute(
            (batch_size, in_channel, padded_size, padded_size),
            lambda nn, cc, yy, xx: tvm.tir.if_then_else(
                tvm.tir.all(yy >= pad, yy - pad < in_size, xx >= pad, xx - pad < in_size),
                Input[nn, cc, yy - pad, xx - pad],
                # Match the input dtype so both branches have the same type.
                tvm.tir.const(0.0, Input.dtype),
            ),
            name=name + "_pad",
        )

    rc = te.reduce_axis((0, in_channel), name=name + "_rc")
    ry = te.reduce_axis((0, kernel_size), name=name + "_ry")
    rx = te.reduce_axis((0, kernel_size), name=name + "_rx")

    # Sum over the kernel window and the input channels. The axis order
    # [ry, rx, rc] is what callers get back from ``op.reduce_axis``.
    out = te.compute(
        (batch_size, out_channel, out_size, out_size),
        lambda nn, ff, yy, xx: te.sum(
            data[nn, rc, yy * stride + ry, xx * stride + rx] * Kernel[ff, rc, ry, rx],
            axis=[ry, rx, rc],
        ),
        name=name + "_conv",
    )
    return out

In [7]:
# Two stacked stride-2, unpadded convolutions: In -> conv1 -> conv2.
conv1 = Conv(Input=In, Kernel=W1, stride=2, pad=0, name="l1")
conv2 = Conv(Input=conv1, Kernel=W2, stride=2, pad=0, name="l2")

In [18]:
# Hand-written schedule: tile conv2 and fully unroll the fused reduction
# loops of both layers, then build and time the result on the CPU.
s = te.create_schedule(conv2.op)
bn = 7  # tile edge length

# Tile conv2 over its axes 1 and 2 (out-channel and output row).
# tile() returns (outer_1, outer_2, inner_1, inner_2).
mo, no, mi, ni = s[conv2].tile(conv2.op.axis[1], conv2.op.axis[2], bn, bn)

# Fuse each layer's three reduction axes into one loop and unroll it.
ry, rx, rc = s[conv1].op.reduce_axis
ryxc = s[conv1].fuse(ry, rx, rc)
s[conv1].unroll(ryxc)

ry, rx, rc = s[conv2].op.reduce_axis
ryxc = s[conv2].fuse(ry, rx, rc)
s[conv2].unroll(ryxc)

# To inspect the generated loop nest:
# print(tvm.lower(s, [In, W1, W2, conv2], simple_mode=True))

func = tvm.build(s, [In, W1, W2, conv2], "llvm", name='conv_normal')
dev = tvm.cpu()

# Random inputs; sizes come from the configuration constants so that this
# cell keeps working when the configuration is changed.
a_np = np.random.uniform(size=[batch, img_ch, img_size, img_size]).astype(In.dtype)
w1_np = np.random.uniform(size=[filter_num, img_ch, filter_size, filter_size]).astype(W1.dtype)
w2_np = np.random.uniform(size=[filter_num, filter_num, filter_size, filter_size]).astype(W2.dtype)
a = tvm.nd.array(a_np, dev)
w1 = tvm.nd.array(w1_np, dev)
w2 = tvm.nd.array(w2_np, dev)

# Output buffer sized from conv2 itself rather than a hard-coded extent.
# NOTE(review): assumes every extent of conv2.shape is a constant integer.
b = tvm.nd.array(np.zeros([int(dim) for dim in conv2.shape], dtype=conv2.dtype), dev)

func(a, w1, w2, b)
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
print("Convolution: %f ms" % (evaluator(a, w1, w2, b).mean * 1e3))

Convolution: 132.171369 ms


In [20]:
# Reference measurement: default (untransformed) schedule, built inside a
# PassContext with opt_level=3, then timed on the CPU.
s = te.create_schedule(conv2.op)

opt_level = 3
target = 'llvm'
with tvm.transform.PassContext(opt_level=opt_level):
    # Use the `target` variable; it was previously assigned but ignored.
    func = tvm.build(s, [In, W1, W2, conv2], target, name='conv_normal')
dev = tvm.cpu()

# Random inputs; sizes come from the configuration constants so that this
# cell keeps working when the configuration is changed.
a_np = np.random.uniform(size=[batch, img_ch, img_size, img_size]).astype(In.dtype)
w1_np = np.random.uniform(size=[filter_num, img_ch, filter_size, filter_size]).astype(W1.dtype)
w2_np = np.random.uniform(size=[filter_num, filter_num, filter_size, filter_size]).astype(W2.dtype)
a = tvm.nd.array(a_np, dev)
w1 = tvm.nd.array(w1_np, dev)
w2 = tvm.nd.array(w2_np, dev)

# Output buffer sized from conv2 itself rather than a hard-coded extent.
# NOTE(review): assumes every extent of conv2.shape is a constant integer.
b = tvm.nd.array(np.zeros([int(dim) for dim in conv2.shape], dtype=conv2.dtype), dev)

func(a, w1, w2, b)
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
print("Convolution: %f ms" % (evaluator(a, w1, w2, b).mean * 1e3))

Convolution: 131.497636 ms
