In [1]:
import tvm
from tvm import relay, relax
import numpy as np
from tvm.script import tir as T
from tvm.script import relax as R
import torch

In [2]:
@tvm.script.ir_module
class MyConv:
    @T.prim_func
    def im2col(X: T.Buffer((1, 1024, 15, 15), "float32"),
               Y: T.Buffer((9*1024, 13*13), "float32")):
        for i,j in T.grid(9*1024, 13*13):
            with T.block("Y"):
                vi, vj = T.axis.remap("SS", (i,j))
                n = vi // 9
                ki = vj//13
                kj = vj%13
                di = (vi%9)//3
                dj = vi%3
                Y[vi, vj] = X[0,n, ki+di, kj+dj]


    @R.function
    def forward(X: R.Tensor((1, 1024, 15, 15), "float32"),
                W: R.Tensor((1024, 1024, 3, 3), "float32")):
        cls = MyConv
        with R.dataflow():
            # transposed_x = relax.op.permute_dims(X, [0,2,3,1])
            transposed_w = relax.op.permute_dims(W, [0,2,3,1])
            good_w = relax.op.reshape(transposed_w, (1024, 1024*9))
            lv0 = R.call_tir(cls.im2col, (X, ), out_sinfo=R.Tensor((9*1024,13*13), dtype="float32"))
            lv1 = relax.op.matmul(good_w, lv0)
            lv2 = relax.op.reshape(lv1, (1, 1024,13,13))
            R.output(lv2)
        return lv2
    
mod = relax.transform.LegalizeOps()(MyConv)
mod.show()

In [3]:
ex = relax.build(MyConv, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True)

In [4]:
x = np.random.rand(1,1024,15,15).astype("float32")
tvm_x = tvm.nd.array(x)
w = np.random.rand(1024,1024,3,3).astype("float32")
tvm_w = tvm.nd.array(w)

x_ = torch.Tensor(x)
w_ = torch.Tensor(w)

ref = torch.nn.functional.conv2d(x_, w_).numpy()

res = vm["forward"](tvm_x, tvm_w).numpy()

np.testing.assert_allclose(res, ref, atol=1e-4, rtol=1)

In [5]:
evaluator = vm.profile("forward",
    tvm_x, tvm_w
)
evaluator

Name                          Duration (us)  Percent  Device  Count                                              Argument Shapes  
matmul                         5 810 714,64    99.72    cpu0      1  float32[1024, 9216], float32[9216, 169], float32[1024, 169]  
transpose                         13 943,87     0.24    cpu0      1         float32[1024, 1024, 3, 3], float32[1024, 3, 3, 1024]  
im2col                             1 818,15     0.03    cpu0      1                 float32[1, 1024, 15, 15], float32[9216, 169]  
reshape1                              72,88     0.00    cpu0      1                 float32[1024, 169], float32[1, 1024, 13, 13]  
vm.builtin.check_tensor_info           5,97     0.00    cpu0      1                                     float32[1, 1024, 15, 15]  
vm.builtin.reshape                     5,77     0.00    cpu0      1                                    float32[1024, 3, 3, 1024]  
vm.builtin.match_shape                 3,09     0.00    cpu0      1                

In [6]:
sch=tvm.tir.Schedule(mod)
matmul = sch.get_block("matmul", func_name="matmul")
i,j,k = sch.get_loops(matmul)
 
#i0, i1 = sch.split(i, factors = [None, 8])
#j0, j1 = sch.split(j, factors = [None, 8])
#k0, k1 = sch.split(k, factors = [None, 8])

sch.reorder(i,k,j)
#sch.unroll(j)
#sch.vectorize(j)

#sch.mod.show()

ex_ = relax.build(sch.mod, target="llvm -mcpu=core-avx2")
vm_ = relax.VirtualMachine(ex_, tvm.cpu(), profile=True)
evaluator = vm_.profile("forward",
    tvm_x, tvm_w
)

evaluator

Name                          Duration (us)  Percent  Device  Count                                              Argument Shapes  
matmul                           570 620,60    97.31    cpu0      1  float32[1024, 9216], float32[9216, 169], float32[1024, 169]  
transpose                         13 097,18     2.23    cpu0      1         float32[1024, 1024, 3, 3], float32[1024, 3, 3, 1024]  
im2col                             2 328,63     0.40    cpu0      1                 float32[1, 1024, 15, 15], float32[9216, 169]  
reshape1                              77,67     0.01    cpu0      1                 float32[1024, 169], float32[1, 1024, 13, 13]  
vm.builtin.reshape                     7,36     0.00    cpu0      1                                    float32[1024, 3, 3, 1024]  
vm.builtin.check_tensor_info           3,54     0.00    cpu0      1                                     float32[1, 1024, 15, 15]  
vm.builtin.match_shape                 3,28     0.00    cpu0      1                