In [1]:
import numpy as np

def classic_bitonic_sort_py(arr: np.ndarray, low: int, cnt: int, direction: bool):
    """
    Recursive function for Bitonic Sort.
    
    Parameters:
    arr       : The input array to be sorted.
    low       : The starting index of the segment to sort.
    cnt       : The number of elements in the segment.
    direction : The sorting direction (1 for ascending, 0 for descending).
    """
    if cnt < 2:
        return
    k = cnt // 2
    # First half sorted in the given direction
    classic_bitonic_sort_py(arr, low, k, True)  # Ascending
    # Second half sorted in the opposite direction
    classic_bitonic_sort_py(arr, low + k, k, False)  # Descending
    classic_bitonic_merge_py(arr, low, cnt, direction)


def classic_bitonic_merge_py(arr: np.ndarray, low: int, cnt: int, direction: bool):
    """
    Merge function for Bitonic Sort (implementation not provided here).

    Parameters:
    arr       : The input array to merge.
    low       : The starting index of the segment to merge.
    cnt       : The number of elements in the segment.
    direction : The sorting direction (1 for ascending, 0 for descending).
    """
    if cnt < 2:
        return
    k = cnt // 2
    # Compare and swap elements to create a bitonic sequence
    for i in range(low, low + k):
        if direction == (arr[i] > arr[i + k]):
            # Swap elements if they are not in the desired order
            arr[i], arr[i + k] = arr[i + k], arr[i]
    # Recursively merge the two halves
    classic_bitonic_merge_py(arr, low, k, direction)
    classic_bitonic_merge_py(arr, low + k, k, direction)

def classic_bitonic_sort_main_py(arr: np.ndarray, direction: bool=True):
    """
    Main function to sort an array using Bitonic Sort.

    Parameters:
    arr       : The input array to be sorted.
    direction : The sorting direction (1 for ascending, 0 for descending).
    """
    n = len(arr)
    # Check if the length is a power of two
    if (n & (n - 1)) != 0:
        raise ValueError("Input length must be a power of 2.")
    # Start the recursive bitonic sort
    classic_bitonic_sort_py(arr, 0, n, direction)


In [2]:
import numba

@numba.njit(parallel=False)
def classic_bitonic_sort_cpu_njit(arr: np.ndarray, low: int, cnt: int, direction: bool):
    """
    Recursive function for Bitonic Sort.
    
    Parameters:
    arr       : The input array to be sorted.
    low       : The starting index of the segment to sort.
    cnt       : The number of elements in the segment.
    direction : The sorting direction (1 for ascending, 0 for descending).
    """
    if cnt < 2:
        return
    k = cnt // 2
    # First half sorted in the given direction
    classic_bitonic_sort_cpu_njit(arr, low, k, True)  # Ascending
    # Second half sorted in the opposite direction
    classic_bitonic_sort_cpu_njit(arr, low + k, k, False)  # Descending
    classic_bitonic_merge_cpu_njit(arr, low, cnt, direction)

@numba.njit(parallel=True)
def classic_bitonic_merge_cpu_njit(arr: np.ndarray, low: int, cnt: int, direction: bool):
    """
    Merge function for Bitonic Sort (implementation not provided here).

    Parameters:
    arr       : The input array to merge.
    low       : The starting index of the segment to merge.
    cnt       : The number of elements in the segment.
    direction : The sorting direction (1 for ascending, 0 for descending).
    """
    if cnt < 2:
        return
    k = cnt // 2
    # Compare and swap elements to create a bitonic sequence
    for i in numba.prange(low, low + k):
        if direction == (arr[i] > arr[i + k]):
            # Swap elements if they are not in the desired order
            arr[i], arr[i + k] = arr[i + k], arr[i]
    # Recursively merge the two halves
    classic_bitonic_merge_cpu_njit(arr, low, k, direction)
    classic_bitonic_merge_cpu_njit(arr, low + k, k, direction)

@numba.njit(parallel=False)
def classic_bitonic_sort_main_cpu_njit(arr: np.ndarray, direction: bool=True):
    """
    Main function to sort an array using Bitonic Sort.

    Parameters:
    arr       : The input array to be sorted.
    direction : The sorting direction (1 for ascending, 0 for descending).
    """
    n = len(arr)
    # Check if the length is a power of two
    if (n & (n - 1)) != 0:
        raise ValueError("Input length must be a power of 2.")
    # Start the recursive bitonic sort
    classic_bitonic_sort_cpu_njit(arr, 0, n, direction)


In [3]:
original_arr = np.array([3, 7, 4, 8, 6, 2, 1, 5])
arr = np.copy(original_arr)
print("Original array:", arr)
# Sort in ascending order
classic_bitonic_sort_main_py(arr)
print("Sorted array (ascending):", arr)


Original array: [3 7 4 8 6 2 1 5]
Sorted array (ascending): [1 2 3 4 5 6 7 8]


In [4]:
arr = np.copy(original_arr)
print("Original array:", arr)
# Sort in ascending order
classic_bitonic_sort_main_cpu_njit(arr)
print("Sorted array (ascending):", arr)


Original array: [3 7 4 8 6 2 1 5]
Sorted array (ascending): [1 2 3 4 5 6 7 8]


In [6]:
import numpy as np
import numba.cuda

@numba.cuda.jit
def bitonic_sort_step(arr: numba.cuda.device_array, n: np.uint32, stage: np.uint32, step: np.uint32): # type: ignore
    """
    Perform one step of the Bitonic Sort on GPU.

    Parameters:
    arr   : cuda.device_array - Global memory array to be sorted.
    n     : int               - Total number of elements in the array.
    stage : int               - The current stage of sorting.
    step  : int               - The current step within the stage.
    """
    # Thread and block indices
    idx = numba.cuda.grid(1) # type: ignore

    # Global memory bounds check
    if idx >= n:
        return

    # Compute partner index for compare-and-swap
    partner = idx ^ step
    # Ensure valid partner within bounds and only process "ascending direction" pairs
    if partner <= idx or partner >= n:
        return
    # Determine the direction of sorting
    direction = ((idx & stage) == 0)

    # Compare-and-swap
    if (direction and arr[idx] > arr[partner]) or (not direction and arr[idx] < arr[partner]):
        # Swap elements
        arr[idx], arr[partner] = arr[partner], arr[idx]


def gpu_bitonic_sort(arr: np.ndarray) -> np.ndarray:
    """
    Main function to perform Bitonic Sort on GPU using Numba CUDA.

    Parameters:
    arr : np.ndarray - Input array to be sorted (1D array, must be a power of 2 in length).

    Returns:
    np.ndarray - Sorted array (copied back to the host).
    """
    n = len(arr)
    # Check if input size is a power of 2
    if (n & (n - 1)) != 0:
        n2 = 1
        while n2 < n:
            n2 <<= 1
        arr = np.concatenate([arr, np.full(n2-n, fill_value=np.inf)])
    else:
        n2 = n
    n2 = np.uint32(n2)
    # Copy array to device
    d_arr: numba.cuda.device_array = numba.cuda.to_device(arr) # type: ignore

    # Define thread and block dimensions
    threads_per_block: int = 256
    blocks_per_grid: int = (n2 + threads_per_block - 1) // threads_per_block

    # Iterative Bitonic Sort
    stage: np.uint32 = np.uint32(2)
    while stage <= n2:
        step: np.uint32 = stage // 2
        while step > 0:
            # Launch kernel for this step
            bitonic_sort_step[blocks_per_grid, threads_per_block](d_arr, n2, stage, step) # type: ignore
            step //= 2
        stage *= 2

    # Copy sorted array back to host
    sorted_arr: np.ndarray = d_arr[:n].copy_to_host()
    return sorted_arr

# Example Usage
# Input array (must be power of 2)
np.random.seed(0)
n = 60
data = (100 * np.random.rand(n)).astype(np.float32)

print("Original array:", data)
sorted_data: np.ndarray = gpu_bitonic_sort(data)
print("Sorted array:", sorted_data)
print(np.all(sorted_data[1:] >= sorted_data[:-1]))


Original array: [54.88135   71.518936  60.276337  54.48832   42.36548   64.58941
 43.75872   89.1773    96.36628   38.34415   79.1725    52.889492
 56.804455  92.55966    7.1036057  8.71293    2.0218399 83.261986
 77.815674  87.00121   97.86183   79.915855  46.147938  78.05292
 11.827442  63.992104  14.335329  94.46689   52.184834  41.466194
 26.45556   77.42337   45.615032  56.843395   1.87898   61.76355
 61.20957   61.6934    94.37481   68.18203   35.95079   43.703194
 69.76312    6.0225472 66.676674  67.06379   21.038256  12.89263
 31.542835  36.37108   57.019676  43.860153  98.83739   10.204481
 20.887676  16.13095   65.31083   25.32916   46.631077  24.442558 ]




Sorted array: [ 1.87898004  2.02183986  6.02254725  7.10360575  8.71292973 10.20448112
 11.82744217 12.89262962 14.33532906 16.13095093 20.88767624 21.03825569
 24.44255829 25.32916069 26.45556068 31.54283524 35.95079041 36.37107849
 38.34415054 41.46619415 42.36547852 43.70319366 43.7587204  43.8601532
 45.6150322  46.14793777 46.63107681 52.18483353 52.88949203 54.4883194
 54.88135147 56.8044548  56.84339523 57.01967621 60.27633667 61.20957184
 61.69340134 61.7635498  63.99210358 64.58940887 65.31082916 66.67667389
 67.06378937 68.18202972 69.76312256 71.51893616 77.42337036 77.81567383
 78.05291748 79.17250061 79.91585541 83.26198578 87.00121307 89.1772995
 92.55966187 94.37480927 94.46688843 96.3662796  97.86183167 98.83738708]
True


In [7]:
import cupy
print(data)
numba_arr: numba.cuda.device_array = numba.cuda.to_device(data) # type: ignore
cp_arr = cupy.asarray(numba_arr)
sorted_idx = cupy.argsort(cp_arr)
print(f'{sorted_idx.dtype=}')
res = numba.cuda.as_cuda_array(cp_arr[sorted_idx]).copy_to_host()
print(res)

[54.88135   71.518936  60.276337  54.48832   42.36548   64.58941
 43.75872   89.1773    96.36628   38.34415   79.1725    52.889492
 56.804455  92.55966    7.1036057  8.71293    2.0218399 83.261986
 77.815674  87.00121   97.86183   79.915855  46.147938  78.05292
 11.827442  63.992104  14.335329  94.46689   52.184834  41.466194
 26.45556   77.42337   45.615032  56.843395   1.87898   61.76355
 61.20957   61.6934    94.37481   68.18203   35.95079   43.703194
 69.76312    6.0225472 66.676674  67.06379   21.038256  12.89263
 31.542835  36.37108   57.019676  43.860153  98.83739   10.204481
 20.887676  16.13095   65.31083   25.32916   46.631077  24.442558 ]
sorted_idx.dtype=dtype('int64')
[ 1.87898    2.0218399  6.0225472  7.1036057  8.71293   10.204481
 11.827442  12.89263   14.335329  16.13095   20.887676  21.038256
 24.442558  25.32916   26.45556   31.542835  35.95079   36.37108
 38.34415   41.466194  42.36548   43.703194  43.75872   43.860153
 45.615032  46.147938  46.631077  52.184834  52