In [1]:
import tvm
from tvm import te
import numpy as np

def conv_out_size(n,k,p,s):
    """Return the output length of a sliding window along one dimension.

    Parameters
    ----------
    n : input length
    k : window (kernel) length
    p : padding added on each side
    s : stride
    """
    padded_span = n + 2*p - k
    return 1 + padded_span // s
def padding(X, ph, pw, val=0):
    """Pad the last two dimensions of ``X`` with ``val``.

    ``ph`` rows are added above and below, ``pw`` columns left and right;
    any leading dimensions of ``X`` are carried over unchanged. Returns the
    ``te.compute`` result, named 'PaddedX'.
    """
    assert len(X.shape) >= 2
    height, width = X.shape[-2], X.shape[-1]
    padded_shape = (*X.shape[0:-2], height+ph*2, width+pw*2)

    def fill(*idx):
        # idx holds one index per output dimension; the last two are the
        # row and column inside the padded plane.
        row, col = idx[-2], idx[-1]
        in_border = te.any(row<ph, row>=height+ph, col<pw, col>=width+pw)
        # Shift back by the padding to address the original tensor.
        source = X[idx[:-2]+(row-ph, col-pw)]
        return te.if_then_else(in_border, val, source)

    return te.compute(padded_shape, fill, name='PaddedX')

In [2]:
def pool(pool_type, c, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):
    """Build a 2-D pooling computation over a ``(c, nh, nw)`` input.

    Parameters
    ----------
    pool_type : 'max' or 'avg'
    c : number of channels
    nh, nw : input height and width
    kh, kw : pooling window height and width
    ph, pw : padding on each side along height and width (default 0)
    sh, sw : strides along height and width (default 1)

    Returns
    -------
    (X, Y, PaddedX) : the input placeholder, the pooled output, and the
    padded input (``X`` itself when no padding is requested).

    Raises
    ------
    ValueError
        If ``pool_type`` is neither 'max' nor 'avg'.
    """
    # reduction axes
    rkh = te.reduce_axis((0,kh), name='rkh')
    rkw = te.reduce_axis((0,kw), name='rkw')

    oh = conv_out_size(nh, kh, ph, sh)
    ow = conv_out_size(nw, kw, pw, sw)

    X = te.placeholder((c, nh, nw), name='X')

    # Pad when EITHER dimension asks for it. Testing `ph * pw != 0` skipped
    # padding when only one of them was non-zero, while oh/ow above were
    # still sized for the padded input, so the window ran past the edge of X.
    needs_padding = ph != 0 or pw != 0

    if pool_type == 'max':
        # Pad with the dtype's minimum so padded cells never win the max.
        PaddedX = padding(X, ph, pw, val=te.min_value(X.dtype)) \
            if needs_padding else X
        Y = te.compute((c, oh, ow),
                       lambda c, h, w:
                       te.max(PaddedX[c, h*sh+rkh, w*sw+rkw],
                              axis=[rkh,rkw]),
                       tag='pool_max', name='PoolMax')
    elif pool_type == 'avg':
        # Zero padding: padded cells add nothing to the sum but are still
        # counted in the kh*kw divisor below.
        PaddedX = padding(X, ph, pw) if needs_padding else X
        tsum = te.compute((c, oh, ow),
                          lambda c, h, w:
                          te.sum(PaddedX[c, h*sh+rkh, w*sw+rkw],
                                 axis=[rkh,rkw]),
                          tag='pool_avg1', name='PoolSum')
        Y = te.compute((c,oh,ow),
                       lambda c, h, w:
                       tsum[c,h,w] / (kh*kw),
                       tag='pool_avg2', name='PoolAvg')
    else:
        raise ValueError("Pool type should be 'avg' or 'max'.")
    return X, Y, PaddedX
        


In [3]:
def get_conv_data(oc,ic,n,k,p=0,s=1,constructor=None,conv_type='direct'):
    """Create float32 input, weight and output buffers for a square convolution.

    The NumPy global seed is reset to 0 on every call, so repeated calls
    with the same arguments produce the same data and weight.

    Parameters
    ----------
    oc, ic : output and input channel counts
    n : input height/width
    k : kernel height/width
    p : padding (default 0)
    s : stride (default 1)
    constructor : optional callable applied to each of the three arrays
    conv_type : 'depthwise' gives the weight a single input channel;
        any other value uses ``ic``

    Returns
    -------
    (data, weight, out) with shapes (ic, n, n), (oc, ic or 1, k, k) and
    (oc, on, on), where ``on`` is the output size; ``out`` is uninitialised.
    """
    np.random.seed(0)
    data = np.random.normal(size=(ic, n, n)).astype('float32')
    weight_channels = 1 if conv_type == 'depthwise' else ic
    weight = np.random.normal(size=(oc, weight_channels, k, k)).astype('float32')
    out_size = conv_out_size(n, k, p, s)
    out = np.empty((oc, out_size, out_size), dtype='float32')
    if not constructor:
        return data, weight, out
    # Convert in data, weight, out order.
    return tuple(constructor(arr) for arr in (data, weight, out))

In [4]:
# Max pooling demo: 4 channels, 12x12 input, 3x3 window, padding 1, stride 1.
# These five names are reused by the later cells.
c, n, k, p, s = 4, 12, 3, 1, 1
X, Y, PaddedX = pool('max', c, n, n, k, k, p, p, s, s)
# No schedule transformations are applied after create_schedule.
sch = te.create_schedule(Y.op)
mod = tvm.build(sch, [X, Y])
# Print the lowered form of the schedule for inspection.
print(tvm.lower(sch, [X, Y], simple_mode=True))
# Only data and the output buffer are needed; the generated weight is discarded.
data, _, out_max = get_conv_data(c, c, n, k, p, s, tvm.nd.array)
# Run the built module; the result is written into out_max.
mod(data, out_max)

// attr [PaddedX] storage_scope = "global"
allocate PaddedX[float32 * 784]
produce PaddedX {
  for (i0, 0, 4) {
    for (i1, 0, 14) {
      for (i2, 0, 14) {
        PaddedX[(((i0*196) + (i1*14)) + i2)] = tvm_if_then_else(((((i1 < 1) || (13 <= i1)) || (i2 < 1)) || (13 <= i2)), -3.40282e+38f, X[((((i0*144) + (i1*12)) + i2) - 13)])
      }
    }
  }
}
produce PoolMax {
  for (c, 0, 4) {
    for (h, 0, 12) {
      for (w, 0, 12) {
        PoolMax[(((c*144) + (h*12)) + w)] = -3.40282e+38f
        for (rkh, 0, 3) {
          for (rkw, 0, 3) {
            PoolMax[(((c*144) + (h*12)) + w)] = max(PoolMax[(((c*144) + (h*12)) + w)], PaddedX[(((((c*196) + (h*14)) + (rkh*14)) + w) + rkw)])
          }
        }
      }
    }
  }
}



In [5]:
# Average pooling demo, reusing c, n, k, p, s from the max pooling cell.
X,Y,PaddedX = pool('avg',c,n,n,k,k,p,p,s,s)
sch = te.create_schedule(Y.op)
mod = tvm.build(sch,[X,Y])
# Print the lowered form of the schedule for inspection.
print(tvm.lower(sch,[X,Y],simple_mode=True))
# get_conv_data reseeds NumPy, so data matches the input used for max pooling.
data,_,out_avg = get_conv_data(c,c,n,k,p,s,tvm.nd.array)
# Run the built module; the result is written into out_avg.
mod(data,out_avg)

// attr [PaddedX] storage_scope = "global"
allocate PaddedX[float32 * 784]
// attr [PoolSum] storage_scope = "global"
allocate PoolSum[float32 * 576]
produce PaddedX {
  for (i0, 0, 4) {
    for (i1, 0, 14) {
      for (i2, 0, 14) {
        PaddedX[(((i0*196) + (i1*14)) + i2)] = tvm_if_then_else(((((i1 < 1) || (13 <= i1)) || (i2 < 1)) || (13 <= i2)), 0f, X[((((i0*144) + (i1*12)) + i2) - 13)])
      }
    }
  }
}
produce PoolSum {
  for (c, 0, 4) {
    for (h, 0, 12) {
      for (w, 0, 12) {
        PoolSum[(((c*144) + (h*12)) + w)] = 0f
        for (rkh, 0, 3) {
          for (rkw, 0, 3) {
            PoolSum[(((c*144) + (h*12)) + w)] = (PoolSum[(((c*144) + (h*12)) + w)] + PaddedX[(((((c*196) + (h*14)) + (rkh*14)) + w) + rkw)])
          }
        }
      }
    }
  }
}
produce PoolAvg {
  for (c, 0, 4) {
    for (h, 0, 12) {
      for (w, 0, 12) {
        PoolAvg[(((c*144) + (h*12)) + w)] = (PoolSum[(((c*144) + (h*12)) + w)]*0.111111f)
      }
    }
  }
}



### MXNet Baseline 

In [7]:
import mxnet as mx

def get_pool_data_mxnet(c, n, k, p, s, ctx='cpu'):
    """Create MXNet input and output arrays for pooling, with a leading axis.

    ``ctx`` names an attribute of ``mx`` that is called to obtain the
    context (default 'cpu'). The arrays come from ``get_conv_data`` with
    ``c`` as both channel counts; the weight it generates is discarded.
    Returns ``(data, out)``, each expanded with a new axis 0.
    """
    device = getattr(mx, ctx)()

    def to_mx(arr):
        return mx.nd.array(arr, ctx=device)

    data, _, out = get_conv_data(c, c, n, k, p, s, to_mx)
    return data.expand_dims(axis=0), out.expand_dims(axis=0)

def pool_mxnet(pool_type, data, out, k, p, s):
    """Run ``mx.nd.Pooling`` on ``data`` and write the result into ``out``.

    ``k``, ``p`` and ``s`` are the kernel size, padding and stride, each
    applied to both spatial dimensions. Returns None.
    """
    pooling_args = dict(
        kernel=(k, k),
        stride=(s, s),
        pad=(p, p),
        pool_type=pool_type,
        out=out,
    )
    mx.nd.Pooling(data, **pooling_args)
# MXNet reference results, using the same c, n, k, p, s as the TVM cells.
# get_pool_data_mxnet is called once per pool type so each run gets its own
# output array; get_conv_data reseeds NumPy, so both runs see the same input.
data, out_max_mx = get_pool_data_mxnet(c,n,k,p,s)
pool_mxnet('max',data,out_max_mx,k,p,s)

data, out_avg_mx = get_pool_data_mxnet(c,n,k,p,s)
pool_mxnet('avg',data,out_avg_mx,k,p,s)

In [8]:
import numpy as np

# Check the TVM results against the MXNet baseline. Indexing with [0] drops
# the leading axis that get_pool_data_mxnet added via expand_dims.
np.testing.assert_allclose(out_max_mx[0].asnumpy(), out_max.asnumpy(), atol=1e-5)
np.testing.assert_allclose(out_avg_mx[0].asnumpy(), out_avg.asnumpy(), atol=1e-5)