[CUDA] Parallel Cuda Mergesort (apache#7099)
Matthew Brookhart authored and masahi committed Dec 24, 2020
1 parent 6184a6d commit 9970cfe
Showing 5 changed files with 234 additions and 82 deletions.
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
@@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host):
lambda f: "calling_conv" not in f.attrs
or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
),
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
292 changes: 224 additions & 68 deletions python/tvm/topi/cuda/sort.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument, no-else-return
"""Sort related operators """
import tvm
from tvm import te
@@ -62,28 +62,43 @@ def traverse(op):
return s


-def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
+def sort_ir(
+    data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
+):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
Parameters
----------
data: Buffer
Buffer of input data. Data will be sorted in place.
-    output : Buffer
-        Output buffer of indicies of sorted tensor with same shape as data.
+    values_out : Buffer
+        Output buffer of values of sorted tensor with same shape as data.
+    values_out_swap : Buffer
+        Output buffer of values with same shape as data to use as swap.
axis : Int
        Axis along which to sort the input tensor.
is_ascend : Boolean
Whether to sort in ascending or descending order.
+    indices_out : Buffer
+        Output buffer of indices of sorted tensor with same shape as data.
+    indices_out_swap : Buffer
+        Output buffer of indices with same shape as data to use as swap.
Returns
-------
stmt : Stmt
The result IR statement.
"""

+    def ceil_div(a, b):
+        return tvm.tir.indexdiv(a + b - 1, b)
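    ## Editorial illustration (not part of the commit): ceil_div rounds up, so the
    ## grid still covers a partially filled final block, e.g. ceil_div(1024, 256) = 4
    ## but ceil_div(1025, 256) = 5.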

axis_mul_before = 1
axis_mul_after = 1
shape = data.shape
@@ -94,64 +109,182 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
axis_mul_before *= value
elif i > axis:
axis_mul_after *= value
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
if indices_out is not None:
indices_out = ib.buffer_ptr(indices_out)
-    nthread_tx = max_threads
-    nthread_bx = shape[axis] // max_threads + 1
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)

-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
-    if indices_out is not None:
-        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
+    # Set up threading
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(shape[axis], max_threads)
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after

# Copy the data to initial output
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
values_out[idx] = data[idx]
if indices_out is not None:
indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)

    ## We loop over the array doing mergesort from the bottom up.
    ## The outer loop runs on the host and launches a cuda kernel for each iteration
    ## of the algorithm.
    ## The basic idea is that at iteration 0, each thread sorts 2 elements.
    ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
    ## to deal with 4 total elements.
    ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
    ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc.
    ## On the final iteration of the algorithm, one thread will merge two sorted lists
    ## to sort the entire array.
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64"
)
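    ## Illustration (not part of the commit): for shape[axis] = 10, lim = ceil(log2(10)) = 4
    ## host iterations with merge widths 2, 4, 8, 16. Each pass ping-pongs between
    ## values_out and the swap buffer, so the sorted result lands in values_out when
    ## lim is even and in the swap buffer when lim is odd.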
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width
# Define and launch the cuda kernel
with ib.new_scope():
i = ib.allocate("int64", (1,), name="i", scope="local")
j = ib.allocate("int64", (1,), name="j", scope="local")
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
# Reduce the number of blocks as the work per thread grows
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)

def compare(a, b):
"""
Compare a and b in proper ascending or descending order
"""
if is_ascend:
out = a <= b
else:
out = b <= a
return out

def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
"""
Merge the two sections of the array assigned to this thread
"""
# pylint: disable=arguments-out-of-order
# initialize iterators
i[0] = start
j[0] = middle
# set up indexes
base_idx = by * shape[axis] * axis_mul_after + bz
# iterate over the output loop
with ib.for_range(0, end - start) as k:
i_idx = base_idx + i[0] * axis_mul_after
j_idx = base_idx + j[0] * axis_mul_after
k_idx = base_idx + (k + start) * axis_mul_after

def swap_values(source, dest, source_idx, dest_idx):
def assign_i():
"""assign i value to current output"""
dest[k_idx] = source[i_idx]
if indices_out is not None:
dest_idx[k_idx] = source_idx[i_idx]
i[0] += 1

def assign_j():
"""assign j value to current output"""
dest[k_idx] = source[j_idx]
if indices_out is not None:
dest_idx[k_idx] = source_idx[j_idx]
j[0] += 1

## if both of the iterators are in range
with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
# compare them and insert whichever is next into the output
with ib.if_scope(compare(source[i_idx], source[j_idx])):
assign_i()
with ib.else_scope():
assign_j()
# otherwise, simply copy the remainder of the valid iterator to the output
with ib.else_scope():
with ib.if_scope(i[0] < middle):
assign_i()
with ib.else_scope():
assign_j()

# Switch which input is the source and which is the destination each iteration
with ib.if_scope(even):
swap_values(source, dest, source_idx, dest_idx)
with ib.else_scope():
swap_values(dest, source, dest_idx, source_idx)

def mergesort(source, dest, source_idx, dest_idx, size, width, even):
# calculate the start, mid, and end points of this section
start[0] = width * tid
with ib.if_scope(start[0] < size):
middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
end[0] = tvm.te.min(start[0] + width, size)
## merge the start->middle and middle->end arrays
bottom_up_merge(
source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even
)
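            ## Illustration (not part of the commit): for size = 10 and width = 4,
            ## thread 2 gets start = 8, middle = min(8 + 2, 10) = 10, end = min(8 + 4, 10) = 10;
            ## the right half is empty, so bottom_up_merge simply copies the remaining
            ## left-half elements through.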

-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            base_idx = i * shape[axis] * axis_mul_after + j
+            # Call the kernel
+            mergesort(
+                values_out,
+                values_out_swap,
+                indices_out,
+                indices_out_swap,
+                shape[axis],
+                width,
+                tvm.tir.indexmod(l2_width, 2) == 0,
+            )

## if the final sorted data ended up in the swap, copy it to the real output
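    ## (each merge pass ping-pongs between values_out and the swap buffer, so after
    ## lim passes the sorted data sits in the swap buffer exactly when lim is odd)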
with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1):
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
-                values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
+                idx = (by * shape[axis] + tid) * axis_mul_after + bz
+                values_out[idx] = values_out_swap[idx]
if indices_out is not None:
-                    indices_out[base_idx + tid * axis_mul_after] = tvm.tir.generic.cast(
-                        tid, indices_out.dtype
-                    )
-    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
-    idxd = tvm.tir.indexdiv
-    idxm = tvm.tir.indexmod

-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            current_sort_num = shape[axis]
-            base_idx = i * shape[axis] * axis_mul_after + j
-            # OddEvenTransposeSort
-            with ib.for_range(0, current_sort_num) as k:
-                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
-                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
-                    if is_ascend:
-                        cond = tvm.tir.all(
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            values_out[offset] > values_out[offset + axis_mul_after],
-                        )
-                    else:
-                        cond = tvm.tir.all(
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            values_out[offset] < values_out[offset + axis_mul_after],
-                        )
-                    with ib.if_scope(cond):
-                        temp_data[0] = values_out[offset]
-                        values_out[offset] = values_out[offset + axis_mul_after]
-                        values_out[offset + axis_mul_after] = temp_data[0]
-                        if indices_out is not None:
-                            temp_index[0] = indices_out[offset]
-                            indices_out[offset] = indices_out[offset + axis_mul_after]
-                            indices_out[offset + axis_mul_after] = temp_index[0]
-    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+                    indices_out[idx] = indices_out_swap[idx]

return ib.get()
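For intuition, the IR above implements a bottom-up mergesort driven from the host, with the two buffers used in ping-pong fashion. A minimal plain-Python model of the same schedule (an illustrative sketch, not part of this commit; bottom_up_mergesort is a hypothetical name):

import math

def bottom_up_mergesort(values):
    # Model of sort_ir's host loop: the merge width doubles each pass (2, 4, 8, ...),
    # reading from one buffer and writing the other.
    n = len(values)
    buf, swap = list(values), list(values)
    lim = math.ceil(math.log2(n))
    for l2_width in range(lim):
        width = 2 << l2_width
        # even passes read buf and write swap; odd passes go the other way
        src, dst = (buf, swap) if l2_width % 2 == 0 else (swap, buf)
        for start in range(0, n, width):
            middle = min(start + width // 2, n)
            end = min(start + width, n)
            i, j = start, middle
            for k in range(start, end):
                # take from the left run while it is live and wins the comparison
                if i < middle and (j >= end or src[i] <= src[j]):
                    dst[k] = src[i]
                    i += 1
                else:
                    dst[k] = src[j]
                    j += 1
    # mirror the kernel's final copy-back: after an odd number of passes the
    # sorted data lives in the swap buffer
    return swap if lim % 2 == 1 else buf

assert bottom_up_mergesort([3, 1, 4, 1, 5, 9, 2, 6]) == [1, 1, 2, 3, 4, 5, 6, 9]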

@@ -336,14 +469,13 @@ def sort(data, axis=-1, is_ascend=1):
out : tvm.te.Tensor
The output of this function.
"""
dtype = "float32"
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
-    indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+    value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
out = te.extern(
[data.shape, data.shape],
[data],
-        lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-        out_buffers=[value_buf, indices_buf],
+        lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+        out_buffers=[value_buf, value_buf_swap],
name="sort_gpu",
tag="sort_gpu",
)[0]
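A hypothetical usage sketch for the updated sort (assuming a CUDA-enabled build; tvm.gpu and topi.cuda.schedule_sort are the API names of this era, and the trailing [0] above is what hides the swap buffer from callers):

import numpy as np
import tvm
from tvm import te, topi

data = te.placeholder((2, 1024), name="data", dtype="float32")
with tvm.target.Target("cuda"):
    out = topi.cuda.sort(data, axis=-1, is_ascend=True)
    s = topi.cuda.schedule_sort(out)
f = tvm.build(s, [data, out], "cuda")

x = tvm.nd.array(np.random.rand(2, 1024).astype("float32"), tvm.gpu(0))
y = tvm.nd.array(np.zeros((2, 1024), dtype="float32"), tvm.gpu(0))
f(x, y)  # each row of y is the corresponding row of x, sorted ascending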
@@ -449,12 +581,24 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
)
else:
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
+        value_swap_buf = tvm.tir.decl_buffer(
+            data.shape, data.dtype, "value_swap_buf", data_alignment=8
+        )
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
out = te.extern(
-            [data.shape, data.shape],
+            [data.shape, data.shape, data.shape, data.shape],
[data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-            out_buffers=[value_buf, indices_buf],
+            lambda ins, outs: sort_ir(
+                ins[0],
+                outs[0],
+                outs[2],
+                axis,
+                is_ascend,
+                indices_out=outs[1],
+                indices_out_swap=outs[3],
+            ),
+            out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf],
name="argsort_gpu",
tag="argsort_gpu",
)[1]
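Note the four out_buffers here: only outs[0] and outs[1] carry the sorted values and indices, while outs[2] and outs[3] exist purely to give sort_ir scratch space of matching shape; the trailing [1] selects the indices tensor that argsort returns.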
@@ -564,25 +708,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
axis = axis + ndim if axis < 0 else axis
assert 0 <= axis < ndim
values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
+    values_swap_buf = tvm.tir.decl_buffer(
+        data.shape, data.dtype, "values_swap_buf", data_alignment=8
+    )
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8)
if ret_type == "values":
output = te.extern(
-            [data.shape],
+            [data.shape, data.shape],
[data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend),
-            out_buffers=[values_buf],
+            lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+            out_buffers=[values_buf, values_swap_buf],
name="topk_gpu",
tag="topk_gpu",
-        )
+        )[0]
else:
output = te.extern(
-            [data.shape, data.shape],
+            [data.shape, data.shape, data.shape, data.shape],
[data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-            out_buffers=[values_buf, indices_buf],
+            lambda ins, outs: sort_ir(
+                ins[0],
+                outs[0],
+                outs[2],
+                axis,
+                is_ascend,
+                indices_out=outs[1],
+                indices_out_swap=outs[3],
+            ),
+            out_buffers=[values_buf, indices_buf, values_swap_buf, indices_swap_buf],
name="topk_gpu",
tag="topk_gpu",
-        )
+        )[0:2]
if isinstance(k, int) and k < 1:
if ret_type == "indices":
return output[1]
