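Utility functions: block-sparse (BSR) weight pruning, plus a hook that makes TVM's RPC `random_fill` copy fixed arrays instead of random data.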

In [7]:
import scipy.sparse
import numpy as np
import tvm
from tvm.rpc import RPCSession

def make_bsr_sparse(dense, sprate, blocksize):
    """Build a block-sparse (BSR) copy of *dense* with a fraction of blocks pruned.

    *dense* is first tiled into ``blocksize`` blocks by
    ``scipy.sparse.bsr_matrix``. Each stored block is scored by the plain sum
    of its entries (not its magnitude). Blocks whose score is below the
    ``sprate`` quantile of all scores are dropped.

    Args:
        dense: 2-D array whose dimensions are multiples of ``blocksize``.
        sprate: Target fraction of stored blocks to drop (0..1).
        blocksize: ``(R, C)`` block shape passed to ``scipy.sparse.bsr_matrix``.

    Returns:
        A ``scipy.sparse.bsr_matrix`` with ``dense.shape`` holding only the
        surviving blocks. If ``sprate`` rounds up to every stored block, the
        result has no stored blocks (previously this raised IndexError).
    """
    bsrdata = scipy.sparse.bsr_matrix(dense, blocksize=blocksize)
    # Score every stored block; the threshold is the cut-th smallest score.
    summed = bsrdata.data.sum((1, 2))
    cut = int(sprate * len(summed) + 0.5)
    if cut >= len(summed):
        # The partition index would be out of range: no block survives.
        keep = np.zeros(len(summed), dtype=bool)
    else:
        val = np.partition(summed, cut)[cut]
        keep = summed >= val
    # Rebuild indptr from running counts of surviving blocks, indexed by the
    # ORIGINAL indptr. Block rows may store different numbers of blocks,
    # because scipy omits all-zero blocks.
    kept_before = np.concatenate(([0], np.cumsum(keep)))
    indptr = kept_before[bsrdata.indptr]
    return scipy.sparse.bsr_matrix(
        (bsrdata.data[keep], bsrdata.indices[keep], indptr),
        shape=dense.shape)


def unpack_bsr(bsrdata):
    """Return the raw ``(data, indices, indptr)`` arrays of a BSR matrix, as a tuple."""
    fields = ('data', 'indices', 'indptr')
    return tuple(getattr(bsrdata, name) for name in fields)


def hook_method(obj, attr):
    """Return a decorator that monkey-patches ``obj.<attr>`` with the decorated function.

    The patch is applied at decoration time. The value being replaced is kept
    as ``func.orig``, so the replacement can delegate to it. Calling
    ``func.revert()`` puts the saved value back.
    """
    def install(replacement):
        saved = getattr(obj, attr)
        setattr(obj, attr, replacement)
        replacement.orig = saved

        def revert():
            setattr(obj, attr, saved)

        replacement.revert = revert
        return replacement

    return install


class NonRandomFill:
    """Deterministic stand-in for a random-fill callable.

    Each instance walks the class-wide source list in order. Every call copies
    the next source into the target via ``tgt.copyfrom``.
    """

    # Class-wide source list; set_srclst rebinds it (it is never mutated in place).
    srclst_ = []

    @classmethod
    def set_srclst(cls, srclst):
        """Wrap every item of *srclst* with ``tvm.nd.array`` and store the list class-wide."""
        cls.srclst_ = list(map(tvm.nd.array, srclst))

    def __init__(self):
        # Snapshot the current class list; each new instance starts from its first entry.
        self.srclst = iter(self.srclst_)

    def __call__(self, tgt):
        """Copy the next source into *tgt*; raises StopIteration once the sources run out."""
        source = next(self.srclst)
        tgt.copyfrom(source)


@hook_method(RPCSession, 'get_function')
def new_get_function(self, fname):
    """Replacement for ``RPCSession.get_function``, installed by ``hook_method``.

    A request for ``tvm.contrib.random.random_fill`` gets a fresh
    NonRandomFill. That object replays the arrays registered via
    ``NonRandomFill.set_srclst`` instead of random data. Every other name is
    forwarded to the original method saved on ``new_get_function.orig``.
    """
    if fname != 'tvm.contrib.random.random_fill':
        return new_get_function.orig(self, fname)
    return NonRandomFill()

In [None]:
import tvm
import numpy as np
from tvm import autotvm, te, tir
from functools import partial, reduce

# Convolution problem: 3x3 kernel over an NHWC input with OHWI weights.
N, H, W, CI = 1, 28, 28, 64
CO = 64
# GEMM view of the convolution: Y output pixels, X output channels,
# K = 3 * 3 * CI reduction length.
Y = N * H * W
X = CO
K = 9 * CI
sprate = 0.9  # fraction of weight blocks pruned by make_bsr_sparse

nhwc_data = np.random.randint(0, 256, (N, H, W, CI)).astype('float32')
weight_ohwi = np.random.rand(CO, 3 * 3 * CI).astype('float32')
spweight_ohwi = make_bsr_sparse(weight_ohwi, sprate, (4, 1))
ret = np.zeros((Y, CO), dtype='float32')

# Stored-block count and block shape of the pruned weights.
nElems, bsrR, bsrC = spweight_ohwi.data.shape
args = (N, H, W, CI, CO, nElems, bsrR, bsrC, 'float32')

"""
print("args=",args)
print("nhwc_data=",nhwc_data)
#print("spweight_ohwi.data=",spweight_ohwi.data)
print(spweight_ohwi.nnz)
print(spweight_ohwi.data)
print(spweight_ohwi.indices)
print(spweight_ohwi.indptr)
"""

In [None]:

#def spconv2d_3x3_gemm(N, H, W, CI, CO, nElems, bsrR, bsrC, dtype='float32'):
# Sparse 3x3 "same" convolution expressed as Im2Col followed by a block-sparse GEMM.
# The weights come in as BSR arrays: Wdat (block values), Wind (block column
# indices) and Wptr (per-block-row offsets into Wdat/Wind).
dtype = 'float32'
# GEMM sizes: Y output pixels x X output channels, reduction length K = 3*3*CI.
Y, X, K = N*H*W, CO, 9*CI
# NOTE(review): outside an autotvm tuning/dispatch context this is presumably a
# fallback config, which the is_fallback branch below fills in.
cfg = autotvm.get_config()
cfg.define_split("tile_y", Y, num_outputs=3)
cfg.define_split("tile_x", X // bsrR, num_outputs=2)
# FLOP estimate: 2 ops per stored weight element per output row, minus one add per output.
cfg.add_flop(Y * (nElems * bsrC * bsrR * 2 - X))
#cfg.define_split("tile_k", K, num_outputs=2)
if cfg.is_fallback:
    # NOTE(review): with Y = 28*28 from the setup cell, 160*8 = 1280 exceeds Y,
    # so the outer tile_y loop has extent 1 — confirm these defaults are intended.
    cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 160, 8])
    cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 4])

Data = te.placeholder((N, H, W, CI), dtype=dtype, name='Data')
Wdat = te.placeholder((nElems, bsrR, bsrC), name='Wdat')
# 'int' index buffers: presumably int32 in TVM, so the arrays passed at call time
# must use the same dtype — TODO confirm.
Wind = te.placeholder((nElems,), dtype='int', name='Wind')
Wptr = te.placeholder((X // bsrR + 1,), dtype='int', name='Wptr')
# idxsplit(x, [a, b]) -> [x % a, (x // a) % b, x // a // b]: mixed-radix
# decomposition of a flat index, fastest-varying component first.
idxsplit = lambda x,y: reduce(lambda a,b: a[:-1]+[a[-1]%b,a[-1]//b], y, [x])

# Im2Col[row, col]: row = ((n*H + h)*W + w), col = ((kh*3 + kw)*CI + c).
# Out-of-bounds taps read 0, giving a padding of 1 on each side.
@partial(te.compute, (Y, K), name='Im2Col')
def Im2Col(row, col):
    jw, jh, jn = idxsplit(row, [W, H])
    jc, kw, kh = idxsplit(col, [CI, 3])
    ih, iw = jh + kh - 1, jw + kw - 1
    return tir.if_then_else(
        tir.all(0 <= ih, ih < H, 0 <= iw, iw < W),
        Data[jn, ih, iw, jc], 0)

# CC[drow, wrow, brow, bcol]: for weight block row wrow, sum over the stored
# blocks in [Wptr[wrow], Wptr[wrow+1]). Each block multiplies the Im2Col column
# Wind[elem]*bsrC + bcol by the block entry (brow, bcol).
@partial(te.compute, (Y, X // bsrR, bsrR, bsrC), name='CC')
def CC(drow, wrow, brow, bcol):
    row_start, row_end = Wptr[wrow], Wptr[wrow+1]
    elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx')
    elem = row_start + elem_idx
    return te.sum(Im2Col[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx)

# Output channel x maps to block row x // bsrR and row-in-block x % bsrR;
# the remaining block-column axis is summed here.
k = te.reduce_axis((0, bsrC), name='k')
C = te.compute((Y, X), lambda y, x: te.sum(CC[y, x // bsrR, x % bsrR, k], axis=k), name='C')

# Schedule for C: tile rows three ways, split columns into one block row (bsrR)
# and tile block rows two ways.
s = te.create_schedule(C.op)
y, x = s[C].op.axis
yt, yo, yi = cfg['tile_y'].apply(s, C, y)
xo, xi = s[C].split(x, factor=bsrR)
xt, xo = cfg['tile_x'].apply(s, C, xo)
(k,) = s[C].op.reduce_axis
s[C].reorder(yt, xt, yo, xo, yi, xi, k)
s[C].unroll(k)
s[C].vectorize(xi)
s[C].unroll(yi)

# CC is computed per xo iteration of C. The names below rebind yi/xi/k to CC's own axes.
s[CC].compute_at(s[C], xo)
yi, xi, r, c = s[CC].op.axis
(k,) = s[CC].op.reduce_axis
s[CC].reorder(xi, k, yi, r, c)
s[CC].unroll(c)
s[CC].vectorize(r)
s[CC].unroll(yi)

# NOTE(review): Im2Col is attached at yo, which sits inside the xt loop (see the
# reorder above), so it appears to be recomputed for every xt tile — confirm intent.
s[Im2Col].compute_at(s[C], yo)
yi, k = s[Im2Col].op.axis
# Split the K axis so the innermost piece covers one channel run (CI) and vectorize it.
ko, ki = s[Im2Col].split(k, factor=CI)
s[Im2Col].vectorize(ki)
#s[Im2Col].unroll(yi)
#return s, [Data, Wdat, Wind, Wptr, C]

In [None]:
#print(tvm.lower(s, [Data, Wdat, Wind, Wptr, C], simple_mode=True))

In [None]:
# Compile the schedule; the argument order here is the call order for func.
# NOTE(review): no target is passed, so TVM picks its default (presumably host
# llvm); the "llvm -mcpu=skylake" string in the timing cell is not applied here.
func = tvm.build(s, [Data, Wdat, Wind, Wptr, C])

In [None]:
# Runtime arguments, in the same order as the tvm.build argument list:
# input, BSR weight data/indices/indptr, then the output buffer for C.
output_placeholder = tvm.nd.array(np.zeros((Y, CO), dtype='float32'))
args = tuple(
    tvm.nd.array(arr)
    for arr in (nhwc_data,
                spweight_ohwi.data,
                spweight_ohwi.indices,
                spweight_ohwi.indptr)
) + (output_placeholder,)

In [7]:
# Run the compiled kernel once. The last element of args (output_placeholder)
# is bound to C, so the result lands there.
func(*args)

In [8]:
#print(output_placeholder)

In [9]:
# Benchmark the kernel on device 0 of the type named by tgtstr (the leading
# "llvm" presumably selects the CPU), averaging over number=3 runs.
# NOTE(review): tgtstr only selects the device here; func was built without an
# explicit target, so "-mcpu=skylake" does not affect the generated code.
tgtstr = "llvm -mcpu=skylake"
dev = tvm.device(tgtstr, 0)
evt = func.time_evaluator(func.entry_name, dev, number=3)
# .mean is the mean time per run (in seconds, per TVM's time_evaluator).
print(evt(*args).mean)

0.0005105583333333333
