Permalink
Browse files

intel graphics conv2d schedule fixed for input shapes (300*300) and (…

…512 * 512) (#1709)
  • Loading branch information...
Laurawly authored and tqchen committed Sep 13, 2018
1 parent 54f5e74 commit b25c15dee797514419fd07e8c298850ee61511c2
Showing with 27 additions and 28 deletions.
  1. +23 −28 topi/python/topi/intel_graphics/conv2d.py
  2. +4 −0 topi/tests/python/test_topi_conv2d_nchw.py
@@ -49,7 +49,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
stride = ast.literal_eval(attrs['strides'])
wkl = _get_workload(data, kernel, stride, padding, data.dtype)
oc_bn = 16
oc_bn = 1
kernel_shape = util.get_const_tuple(kernel.shape)
for oc_bn in range(16, 1, -1):
if kernel_shape[0] % oc_bn == 0:
break
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['kernel_layout'] = 'OIHW%do' % (oc_bn)
@@ -148,9 +152,6 @@ def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
oshape = (batch, out_channel, out_height, out_width)
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(data, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
@@ -190,6 +191,10 @@ def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16
if not out_width % block_w == 0:
c_w = (out_width // block_w + 1) * block_w
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w]
temp = pad(data, pad_before, pad_after, name="pad_temp")
cshape = (batch, out_channel // nv, c_h, c_w, nv)
conv = tvm.compute(
@@ -263,17 +268,8 @@ def _schedule_cl_spatialpack_NCHWc(s, op):
s[conv_L].compute_at(s[conv], vci)
i, oc, h, w, vc = s[conv_L].op.axis
rc, ry, rx = s[conv_L].op.reduce_axis
if in_channel == 2048:
rco, rci = s[conv_L].split(rc, nparts=128)
s[conv_L].unroll(rci)
s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rco)
else:
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
if kernel.shape[3].value != 7:
s[conv_L].unroll(ry)
s[conv_L].unroll(rx)
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
if kernel.shape[3].value != 7:
s[conv_L].unroll(ry)
s[conv_L].unroll(rx)
@@ -396,9 +392,6 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
oshape = (batch, out_channel, out_height, out_width)
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(data, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
@@ -432,13 +425,21 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float
c_h = out_height
c_w = out_width
if not out_width % block_w == 0:
c_w = (out_width // block_w + 1) * block_w
if not out_height % block_h == 0:
c_h = (out_height // block_h + 1) * block_h
if not out_width % block_w == 0:
c_w = (out_width // block_w + 1) * block_w
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w]
temp = pad(data, pad_before, pad_after, name="pad_temp")
nv = 16
if not num_filter % nv == 0:
num_filter = (num_filter // nv + 1) * nv
out_channel = num_filter
cshape = (batch, out_channel // nv, c_h, c_w, nv)
kvshape = (num_filter // nv, channel, kernel_h, kernel_w, nv)
@@ -520,14 +521,8 @@ def _schedule_cl_spatialpack(s, op):
s[conv_L].compute_at(s[conv], vci)
i, oc, h, w, vc = s[conv_L].op.axis
rc, ry, rx = s[conv_L].op.reduce_axis
if in_channel == 2048:
rco, rci = s[conv_L].split(rc, nparts=128)
s[conv_L].unroll(rci)
s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rco)
else:
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
if kernel.shape[3].value != 7:
s[conv_L].unroll(ry)
s[conv_L].unroll(rx)
@@ -161,6 +161,10 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0)
verify_conv2d_nchw(1, 1024, 19, 84, 3, 1, 1)
verify_conv2d_nchw(1, 2048, 10, 126, 3, 1, 1)
verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1)
if __name__ == "__main__":

0 comments on commit b25c15d

Please sign in to comment.