Improve performance for large dilation convolution #1887
Comments
Specific optimizations are necessary for dilated conv2d, and we do need some of them to skip the unnecessary computations.
I think you are using CPU backends. Currently, we always create an intermediate buffer for dilation (compute_root). We can use compute_inline + unroll to eliminate multiplications of zeros and large intermediate buffers. Significant performance gain can be obtained here.
I am using the GPU backend. I don't quite understand how compute_inline + unroll can solve this problem.
I will add a section to the tutorial later. |
The idea is to inline dilation rather than use a separate stage. Then we can unroll some axes and use the simplifier in tvm to remove the multiplications of zeros. This also avoids large intermediate buffers. You can follow the example below and change the implementation in TOPI.

import tvm
import topi
from tvm.contrib.util import get_lower_ir
from topi.nn.util import get_pad_tuple
from topi.util import equal_const_int
# args of a conv2d
N, H, W, CO, CI, KH, KW, strides, padding = 1, 14, 14, 512, 512, 3, 3, (1, 1), (1, 1)
# large dilation
dilation = (4, 4)

def current_schedule():
    """The current bad schedule, as a reference"""
    data = tvm.placeholder((N, CI, H, W), name='data')
    raw_kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')

    # dilate as a separate stage before conv
    dilated_kernel = topi.nn.dilate(raw_kernel, (1, 1) + dilation, name='DilatedKernel')
    conv = topi.nn.conv2d_nchw(data, dilated_kernel, strides, padding, 'float32')

    s = tvm.create_schedule([conv.op])
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    s[dilated_kernel].compute_inline()

    AA = s.cache_read(pad_data, 'shared', conv)
    BB = s.cache_read(dilated_kernel, 'shared', conv)

    ci = s[conv].op.reduce_axis[0]
    s[AA].compute_at(s[conv], ci)
    s[BB].compute_at(s[conv], ci)

    print(get_lower_ir(s))


def better_schedule():
    """The better schedule optimized for dilation"""
    data = tvm.placeholder((N, CI, H, W), name='data')
    raw_kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')

    dilate_args = (1, 1) + dilation

    def dilate_kernel(*indices):  # This function is the same as topi.nn.dilate, but inlined
        not_zero = []
        index_tuple = []
        for i in range(len(dilate_args)):
            if not equal_const_int(dilate_args[i], 1):
                index_tuple.append(indices[i] // dilate_args[i])
                not_zero.append((indices[i] % dilate_args[i]).equal(0))
            else:
                index_tuple.append(indices[i])
        if not_zero:
            not_zero = tvm.all(*not_zero)
            return tvm.select(not_zero, raw_kernel(*index_tuple), tvm.const(0.0, data.dtype))
        return raw_kernel(*index_tuple)

    kernel_h = (KH - 1) * dilation[0] + 1
    kernel_w = (KW - 1) * dilation[1] + 1

    # vanilla conv
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
    out_height = (H - kernel_h + pad_top + pad_down) // strides[0] + 1
    out_width = (W - kernel_w + pad_left + pad_right) // strides[1] + 1
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    pad_data = topi.nn.pad(data, pad_before, pad_after, name="pad_temp")

    rc = tvm.reduce_axis((0, CI), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    conv = tvm.compute((N, CO, out_height, out_width),
                       lambda nn, ff, yy, xx: tvm.sum(
                           pad_data[nn, rc, yy * strides[0] + ry, xx * strides[1] + rx] *
                           dilate_kernel(ff, rc, ry, rx),  # call the inlined dilation function here
                           axis=[rc, ry, rx]))

    s = tvm.create_schedule([conv.op])
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()

    AA = s.cache_read(pad_data, 'shared', conv)
    BB = s.cache_read(raw_kernel, 'shared', conv)

    n, c, h, w = s[conv].op.axis
    ci, kh, kw = s[conv].op.reduce_axis
    s[AA].compute_at(s[conv], ci)
    s[BB].compute_at(s[conv], ci)

    # use unroll + simplifier to eliminate multiplications of zeros
    s[conv].unroll(kh)
    s[conv].unroll(kw)

    print(get_lower_ir(s))


print("Current Schedule")
current_schedule()
print("=================================")
print("Better Schedule")
better_schedule()

Output:

Current Schedule
// attr [compute] storage_scope = "global"
allocate compute[float32 * 1 * 512 * 8 * 8]
// attr [pad_temp.shared] storage_scope = "shared"
allocate pad_temp.shared[float32 * 1 * 1 * 9 * 9]
// attr [DilatedKernel.shared] storage_scope = "shared"
allocate DilatedKernel.shared[float32 * 1 * 1 * 9 * 9]
produce compute {
  for (ff, 0, 512) {
    for (yy, 0, 8) {
      for (xx, 0, 8) {
        compute[((((ff*8) + yy)*8) + xx)] = 0.000000f
        for (rc, 0, 512) {
          produce pad_temp.shared {
            for (ax2, 0, 9) {
              for (ax3, 0, 9) {
                pad_temp.shared[((ax2*9) + ax3)] = tvm_if_then_else((((((1 - ax2) <= yy) && (yy < (15 - ax2))) && ((1 - ax3) <= xx)) && (xx < (15 - ax3))), data[((((((yy*14) + xx) + (rc*196)) + (ax2*14)) + ax3) + -15)], 0.000000f)
              }
            }
          }
          // BAD: large useless intermediate buffer
          produce DilatedKernel.shared {
            for (ax2, 0, 9) {
              for (ax3, 0, 9) {
                DilatedKernel.shared[((ax2*9) + ax3)] = tvm_if_then_else((((ax2 % 4) == 0) && ((ax3 % 4) == 0)), kernel[((((((ff*512) + rc)*3) + (ax2/4))*3) + (ax3/4))], 0.000000f)
              }
            }
          }
          // BAD: many operands of the multiplications are zero, which means these multiplications are useless
          for (ry, 0, 9) {
            for (rx, 0, 9) {
              compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[((ry*9) + rx)]*DilatedKernel.shared[((ry*9) + rx)]))
            }
          }
        }
      }
    }
  }
}
=================================
Better Schedule
// attr [compute] storage_scope = "global"
allocate compute[float32 * 1 * 512 * 8 * 8]
// attr [pad_temp.shared] storage_scope = "shared"
allocate pad_temp.shared[float32 * 1 * 1 * 9 * 9]
// attr [kernel.shared] storage_scope = "shared"
allocate kernel.shared[float32 * 1 * 1 * 3 * 3]
produce compute {
  for (ff, 0, 512) {
    for (yy, 0, 8) {
      for (xx, 0, 8) {
        compute[((((ff*8) + yy)*8) + xx)] = 0.000000f
        for (rc, 0, 512) {
          produce pad_temp.shared {
            for (ax2, 0, 9) {
              for (ax3, 0, 9) {
                pad_temp.shared[((ax2*9) + ax3)] = tvm_if_then_else((((((1 - ax2) <= yy) && (yy < (15 - ax2))) && ((1 - ax3) <= xx)) && (xx < (15 - ax3))), data[((((((yy*14) + xx) + (rc*196)) + (ax2*14)) + ax3) + -15)], 0.000000f)
              }
            }
          }
          // GOOD: no extra buffer
          produce kernel.shared {
            for (ax2, 0, 3) {
              for (ax3, 0, 3) {
                kernel.shared[((ax2*3) + ax3)] = kernel[((((((ff*512) + rc)*3) + ax2)*3) + ax3)]
              }
            }
          }
          // GOOD: minimal computation
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[0]*kernel.shared[0]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[4]*kernel.shared[1]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[8]*kernel.shared[2]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[36]*kernel.shared[3]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[40]*kernel.shared[4]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[44]*kernel.shared[5]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[72]*kernel.shared[6]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[76]*kernel.shared[7]))
          compute[((((ff*8) + yy)*8) + xx)] = (compute[((((ff*8) + yy)*8) + xx)] + (pad_temp.shared[80]*kernel.shared[8]))
        }
      }
    }
  }
}
@merrymercy Thank you very much for your detailed explanation. I applied this change locally, and the time cost of my model dropped from 300+ ms to 50 ms, a great performance boost! However, I still have some questions:
I limited the extent of the ry and rx axes to the original kernel size KH and KW, and moved the dilation address calculation into the pad_data array indexing. So there is no need to call dilate_kernel, and unrolling is not necessary either. This code also makes the generated IR more intuitive for developers to understand.
I posted all my code here:
Output:
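A minimal sketch of the compute definition being described (illustrative only, reusing the shapes and the pad_data / raw_kernel names from the earlier example rather than the poster's exact code) might look like:

# Sketch (assumed, not the poster's exact code): reduce over the original KH x KW
# and fold the dilation into the pad_data index, so no dilate_kernel call is needed.
rc = tvm.reduce_axis((0, CI), name='rc')
ry = tvm.reduce_axis((0, KH), name='ry')
rx = tvm.reduce_axis((0, KW), name='rx')
conv = tvm.compute((N, CO, out_height, out_width),
                   lambda nn, ff, yy, xx: tvm.sum(
                       pad_data[nn, rc, yy * strides[0] + ry * dilation[0],
                                xx * strides[1] + rx * dilation[1]] *
                       raw_kernel[ff, rc, ry, rx],
                       axis=[rc, ry, rx]))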
Good observation. cache_read can only cache a contiguous region; this is a current limitation. We can add an explicit packing stage as a workaround.

import tvm
import topi
from tvm.contrib.util import get_lower_ir
from topi.nn.util import get_pad_tuple
from topi.util import equal_const_int
# args of a conv2d
N, H, W, CO, CI, KH, KW, strides, padding = 1, 14, 14, 512, 512, 3, 3, (1, 1), (1, 1)
# large dilation
dilation = (4, 4)

def current_schedule():
    """The current bad schedule, as a reference"""
    data = tvm.placeholder((N, CI, H, W), name='data')
    raw_kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')

    # dilate as a separate stage before conv
    dilated_kernel = topi.nn.dilate(raw_kernel, (1, 1) + dilation, name='DilatedKernel')
    conv = topi.nn.conv2d_nchw(data, dilated_kernel, strides, padding, 'float32')

    s = tvm.create_schedule([conv.op])
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    s[dilated_kernel].compute_inline()

    AA = s.cache_read(pad_data, 'shared', conv)
    BB = s.cache_read(dilated_kernel, 'shared', conv)

    ci = s[conv].op.reduce_axis[0]
    s[AA].compute_at(s[conv], ci)
    s[BB].compute_at(s[conv], ci)

    print(get_lower_ir(s))


def better_schedule():
    """The better schedule optimized for dilation"""
    data = tvm.placeholder((N, CI, H, W), name='data')
    raw_kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')

    dilate_args = (1, 1) + dilation

    def dilate_kernel(*indices):  # This function is the same as topi.nn.dilate, but inlined
        not_zero = []
        index_tuple = []
        for i in range(len(dilate_args)):
            if not equal_const_int(dilate_args[i], 1):
                index_tuple.append(indices[i] // dilate_args[i])
                not_zero.append((indices[i] % dilate_args[i]).equal(0))
            else:
                index_tuple.append(indices[i])
        if not_zero:
            not_zero = tvm.all(*not_zero)
            return tvm.select(not_zero, raw_kernel(*index_tuple), tvm.const(0.0, data.dtype))
        return raw_kernel(*index_tuple)

    kernel_h = (KH - 1) * dilation[0] + 1
    kernel_w = (KW - 1) * dilation[1] + 1

    # vanilla conv
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
    out_height = (H - kernel_h + pad_top + pad_down) // strides[0] + 1
    out_width = (W - kernel_w + pad_left + pad_right) // strides[1] + 1
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    pad_data = topi.nn.pad(data, pad_before, pad_after, name="pad_temp")

    ##### EXPLICIT PACKING #####
    packed_data = tvm.compute(
        (N, CI, out_height, out_width, KH, KW),
        lambda n, f, y, x, kh, kw:
            pad_data[n, f, y * strides[0] + kh * dilation[0], x * strides[1] + kw * dilation[1]],
        name='packed_data')

    rc = tvm.reduce_axis((0, CI), name='rc')
    ry = tvm.reduce_axis((0, KH), name='ry')
    rx = tvm.reduce_axis((0, KW), name='rx')
    conv = tvm.compute((N, CO, out_height, out_width),
                       lambda nn, ff, yy, xx: tvm.sum(
                           packed_data[nn, rc, yy, xx, ry, rx] *
                           raw_kernel[ff, rc, ry, rx], axis=[rc, ry, rx]))

    s = tvm.create_schedule([conv.op])
    s[pad_data].compute_inline()

    BB = s.cache_read(raw_kernel, 'shared', conv)

    n, c, h, w = s[conv].op.axis
    ci, kh, kw = s[conv].op.reduce_axis
    s[BB].compute_at(s[conv], ci)
    s[packed_data].compute_at(s[conv], ci)

    # use unroll + simplifier to eliminate multiplications of zeros
    s[conv].unroll(kh)
    s[conv].unroll(kw)

    print(get_lower_ir(s))


print("Current Schedule")
current_schedule()
print("=================================")
print("Better Schedule")
better_schedule()

Output:
@merrymercy Great idea! Based on your code, I made packed_data inlined and added a cache_read from it, so the unnecessary memory loads are eliminated while I can still cache the data for reuse.
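A minimal sketch of that change (illustrative, assuming it is applied to the better_schedule with explicit packing above, not the commenter's exact code):

# Sketch (assumed): inline the packing stage and cache_read from it instead, so no
# separate packed_data buffer is materialized but the gathered data is still reused.
s = tvm.create_schedule([conv.op])
s[pad_data].compute_inline()
s[packed_data].compute_inline()                    # packing is now computed on the fly

AA = s.cache_read(packed_data, 'shared', [conv])   # cache the packed (dilation-gathered) data
BB = s.cache_read(raw_kernel, 'shared', [conv])

ci, kh, kw = s[conv].op.reduce_axis
s[AA].compute_at(s[conv], ci)
s[BB].compute_at(s[conv], ci)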
Your contributions are welcome |
Eliminate unnecessary zero multiplications introduced by dilated kernel
@merrymercy I think there could be a similar issue in the int8 conv2d on CUDA. As for the limitation of cache_read, the redundant data loads can be minimized if I do some tiling so that the holes inside the loaded data also get used, right?
Some networks have convolutions with large dilation. DeepLab v3, for example, may use dilation values of 12, 24, and 36. The current implementation dilates the kernel to a bigger size and then performs a normal conv. But with a dilation of 24, the dilated kernel size becomes 49x49, which is quite large and quite likely to be the bottleneck of the whole network.
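For reference, the dilated kernel extent follows (K - 1) * dilation + 1, the same formula used for kernel_h and kernel_w in the snippets above; a quick check for a 3x3 kernel with dilation 24:

# worked example of the dilated kernel extent
KH, dilation = 3, 24
dilated_kh = (KH - 1) * dilation + 1
print(dilated_kh)  # 49, so the dilated kernel is 49x49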