In [1]:
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize

In [11]:
def generate_quantized_np(shape, bits, out_dtype):
    """Return random unsigned `bits`-bit integers of the given shape.

    Values are drawn from NumPy's global RNG, uniformly from [0, 2**bits)
    (``np.random.randint`` excludes its upper bound), and then cast to
    ``out_dtype``.
    """
    exclusive_upper = 1 << bits
    samples = np.random.randint(0, exclusive_upper, size=shape)
    return samples.astype(out_dtype)

def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding, 
    activation_bits, weight_bits, dorefa):
    """Build a TOPI bitserial NCHW conv2d for LLVM and check it against a NumPy reference.

    Parameters
    ----------
    batch : int
        Input batch size.
    in_size : int
        Spatial size of the (square) input; height and width are both ``in_size``.
    in_channel, num_filter : int
        Number of input channels and number of output filters.
    kernel : int
        Size of the (square) kernel.
    stride, padding :
        Forwarded unchanged to ``topi.nn.bitserial_conv2d`` and to
        ``topi.testing.conv2d_nchw_python``.
    activation_bits, weight_bits : int
        Bit widths of the random activations / weights, drawn from
        [0, 2**activation_bits) and [0, 2**weight_bits) respectively.
    dorefa : bool
        If true, the reference maps weight value 1 to +1 and every other
        weight value to -1 before convolving.

    Raises
    ------
    AssertionError
        From ``np.testing.assert_allclose`` if the TVM output differs from
        the reference.
    """
    in_height = in_width = in_size
    input_type = 'uint32'
    out_dtype = 'int32'

    with tvm.target.create('llvm'):
        A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
        W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
        # NOTE(review): `clusters` is not part of upstream TOPI's bitserial_conv2d
        # signature; presumably a fork-specific extension -- confirm its semantics.
        C = tvm.placeholder((4,), dtype='float32', name='C')
        B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, 
                                     out_dtype=out_dtype, layout="NCHW", dorefa=dorefa, clusters=C)
        s = topi.generic.schedule_bitserial_conv2d_nchw([B])

    # Plain Python int tuples (get_const_tuple already converts the TVM shape),
    # so they can be used directly below without converting again.
    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)

    # pickle_memoize keys the cache on this closure's captured variables, so
    # every value read here must stay a plain int/str/float/tuple.
    @memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
    def get_ref_data():
        a_np = generate_quantized_np(a_shape, activation_bits, input_type)
        w_np = generate_quantized_np(w_shape, weight_bits, input_type)
        if dorefa:
            # Vectorized remap: weight 1 -> +1, anything else -> -1 (signed out_dtype).
            w_ = np.where(w_np == 1, 1, -1).astype(out_dtype)
            b_np = topi.testing.conv2d_nchw_python(a_np.astype(out_dtype), w_, stride, padding)
        else:
            b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
        return a_np, w_np, b_np
    a_np, w_np, b_np = get_ref_data()
    # NOTE(review): c_np is passed to the kernel via C, but the NumPy reference
    # above does not use it; confirm clusters do not affect the output.
    c_np = np.array([1., 2., 3., 4.], dtype=np.float32)

    ctx = tvm.cpu(0)
    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    c = tvm.nd.array(c_np, ctx)
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
    func = tvm.build(s, [A, W, C, B], "llvm")
    func(a, w, c, b)
    np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

In [12]:
# Problem size for the bitserial conv2d check: a 56x56 input, 64 input and
# 64 output channels, a 3x3 kernel, stride 1 and padding 1
# (output spatial size (56 + 2*1 - 3) // 1 + 1 = 56).
in_size = 56
ic, oc = 64, 64
k = 3
stride = 1
pad = 1

# Seed NumPy's global RNG so the randomly generated activations/weights are
# reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

In [13]:
# Batch of 1 with 1-bit activations and 1-bit weights, dorefa mode enabled.
verify_bitserial_conv2d_nchw(
    1, in_size, ic, oc, k, stride, pad,
    activation_bits=1,
    weight_bits=1,
    dorefa=True,
)

array([[[[  7.,   5.,  21., ...,  17.,  29.,  14.],
         [  6.,  14.,  24., ...,  34.,  24.,  -3.],
         [ 21.,  16.,  17., ..., -10.,   3., -19.],
         ...,
         [ 12.,  15.,  20., ...,   3.,  17.,   2.],
         [ -8.,  25.,  13., ...,  27.,  11.,   6.],
         [ 10.,   7.,   2., ...,   3.,   0.,  -5.]],

        [[ 21.,  15.,  -5., ...,  -3.,   3.,   6.],
         [ 12.,  -8.,  -8., ...,  12.,  14.,  -7.],
         [ 11.,   2., -13., ...,   8., -15., -11.],
         ...,
         [  0.,  -1.,  -2., ...,  17.,  11.,  22.],
         [ -4.,   9., -15., ...,  -9.,   5.,   8.],
         [  0.,  -7.,  16., ...,  -9., -18.,   3.]],

        [[ -9., -13.,  -1., ..., -13.,   3.,  10.],
         [-16.,   2.,  10., ..., -10., -18.,  -5.],
         [ 11., -16.,   1., ..., -20., -17., -11.],
         ...,
         [-10.,  -5.,  12., ...,   9.,  11.,  -4.],
         [-18.,  -5.,  -3., ..., -23.,  -7.,   8.],
         [ -2., -13., -24., ..., -11., -20.,  -7.]],

        ...,

  