In [1]:
import ray
import numpy as np
import time

In [2]:
# Parallelism: one input block per core and one output (sorting) range per core.
num_cores = 8
num_partition_blocks = num_sorting_blocks = num_cores
# Pivot sampling: draw 25 random elements per partition block.
num_samples_for_pivots = 25 * num_partition_blocks
# Length of the random array to sort.
array_len = 10 ** 8

In [3]:
def compute_pivots(values, num_samples, num_partitions):
    """Sample a subset of the array and choose partition pivots from it.

    Draws ``num_samples`` elements at random (with replacement) using NumPy's
    global RNG, sorts them, and returns ``num_partitions - 1`` of them taken
    at evenly spaced positions in the sorted sample.

    Args:
        values: array whose elements will be range-partitioned.
        num_samples: number of random elements to draw from ``values``.
        num_partitions: number of ranges the returned pivots should define.

    Returns:
        Sorted array of ``num_partitions - 1`` pivot values (empty when
        ``num_partitions`` is 1).

    Raises:
        ValueError: if ``values`` is empty (there is nothing to sample).
    """
    if len(values) == 0:
        raise ValueError("cannot compute pivots of an empty array")
    samples = values[np.random.randint(0, len(values), size=num_samples)]
    samples = np.sort(samples)
    # Multiply before dividing so the positions stay evenly spread even when
    # num_samples is not a multiple of num_partitions (or is smaller than it).
    # A fixed stride of len(samples) // num_partitions would bunch the pivots
    # toward the low end of the sample and unbalance the ranges.
    pivot_indices = (len(samples) * np.arange(1, num_partitions)
                     // num_partitions)
    return samples[pivot_indices]

In [4]:
# NOTE: Ray >= 1.0 renamed the `num_return_vals` option to `num_returns`.
@ray.remote(num_return_vals=num_sorting_blocks)
def partition_block(block, pivots):
    """Sort one block and cut it into contiguous ranges at the given pivots.

    Args:
        block: 1-D array holding one chunk of the input.
        pivots: sorted array of boundary values (e.g. from ``compute_pivots``).

    Returns:
        A list of ``len(pivots) + 1`` arrays; piece ``i`` holds the block's
        elements that fall in the ``i``-th pivot range. Ray exposes them as
        ``num_sorting_blocks`` separate object IDs, so ``len(pivots) + 1``
        must equal ``num_sorting_blocks``.
    """
    # Named `sorted_block` so the builtin `sorted` is not shadowed.
    sorted_block = np.sort(block)
    # Left-side search: an element equal to a pivot lands in the piece that
    # starts at that pivot's cut point.
    cut_points = sorted_block.searchsorted(pivots)
    return np.split(sorted_block, cut_points)

In [5]:
@ray.remote
def merge_and_sort(*partition):
    """Combine the pieces of one pivot range and return them sorted.

    Args:
        *partition: arrays belonging to the same range (in this notebook,
            one piece from each partitioned block).

    Returns:
        A single sorted array holding every element of the given pieces.
    """
    combined = np.concatenate(partition)
    return np.sort(combined)

In [6]:
# Start Ray with one CPU per core. ignore_reinit_error lets this cell be
# re-run in the same kernel instead of raising because Ray is already running.
ray.init(num_cpus=num_cores, ignore_reinit_error=True)

2019-05-14 12:49:10,596	INFO node.py:469 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-05-14_12-49-10_4412/logs.
2019-05-14 12:49:10,707	INFO services.py:409 -- Waiting for redis server at 127.0.0.1:16709 to respond...
2019-05-14 12:49:10,825	INFO services.py:409 -- Waiting for redis server at 127.0.0.1:58219 to respond...
2019-05-14 12:49:10,830	INFO services.py:806 -- Starting Redis shard with 3.44 GB max memory.
2019-05-14 12:49:10,844	INFO node.py:483 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-05-14_12-49-10_4412/logs.
2019-05-14 12:49:10,847	INFO services.py:1441 -- Starting the Plasma object store with 5.15 GB memory using /tmp.


{'node_ip_address': '10.142.32.61',
 'object_store_address': '/tmp/ray/session_2019-05-14_12-49-10_4412/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2019-05-14_12-49-10_4412/sockets/raylet',
 'redis_address': '10.142.32.61:16709',
 'webui_url': None}

In [14]:
# Generate a random array. Seed NumPy's global RNG first so the data (and the
# pivots sampled from it) are reproducible across runs.
np.random.seed(0)
values = np.random.randint(0, 256, size=array_len, dtype=np.uint8)

# Begin timing the parallel sort example.
parallel_sort_start = time.time()

# Choose range pivots. partition_block returns len(pivots) + 1 pieces and is
# declared with num_return_vals=num_sorting_blocks, so the pivots must split
# the values into num_sorting_blocks ranges (not num_partition_blocks).
pivots = compute_pivots(values, num_samples_for_pivots,
                        num_sorting_blocks)

# Split the array into roughly equal blocks; each task sorts one block and
# cuts it into per-range pieces in parallel.
blocks = np.array_split(values, num_partition_blocks)
block_piece_ids = [partition_block.remote(block, pivots) for block in blocks]
# Transpose from one list per block to one list per range: entry i gathers
# the i-th piece of every block.
partition_ids = [list(pieces) for pieces in zip(*block_piece_ids)]

# Merge each range in its own task. Ranges are ordered by pivot, so
# concatenating the sorted ranges yields the fully sorted array.
sorted_ids = [merge_and_sort.remote(*range_ids) for range_ids in partition_ids]
parallel_sorted = np.concatenate(ray.get(sorted_ids))

parallel_sort_end = time.time()
print("Parallel sort took {} seconds."
      .format(parallel_sort_end - parallel_sort_start))

Parallel sort took 1.3176541328430176 seconds.


In [17]:
# Sort the same data on a single core. The result is the reference for the
# correctness check, and the elapsed time is the baseline for comparison.
serial_sort_start = time.time()
serial_sorted = np.sort(values)
serial_sort_end = time.time()
print("Serial sort took {} seconds.".format(
    serial_sort_end - serial_sort_start))

Serial sort took 4.165165185928345 seconds.


In [12]:
# Check that the parallel result matches the serial reference exactly.
# Both are integer arrays, so compare for exact element-wise equality (which
# also requires the shapes to match) instead of np.allclose's tolerance-based,
# broadcasting comparison.
np.testing.assert_array_equal(parallel_sorted, serial_sorted)
print("Parallel sort successful and correct.")

Parallel sort successful and correct.
