In [1]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import Float32, Int32, Int64
import math
import numpy as np

In [2]:
BLOCK_SIZE = 32

In [3]:
@triton.jit
def scan(Y, nextY, stride, BLOCK_SIZE: tl.constexpr):
    pid_row = tl.program_id(0)

    for j in tl.static_range(BLOCK_SIZE):
        current_idx = pid_row * BLOCK_SIZE + j
        if current_idx - stride >= 0:
            Yj = tl.load(Y + current_idx)
            Yjminstride = tl.load(Y + current_idx - stride)
            
            tl.store(nextY + current_idx, Yj + Yjminstride)
        else:
            tl.store(nextY + current_idx, tl.load(Y + current_idx))

def triton_pref_sum(X):
    Y = torch.clone(X)
    Ynext = torch.empty_like(Y, device='cuda')
    n = X.shape[0]
    stride = 1
    for i in range(0, int(math.log2(n))):
        scan[(math.ceil(n / BLOCK_SIZE),)](Y, Ynext, stride, BLOCK_SIZE)
        stride *= 2
        Ynext, Y = Y, Ynext

    return Y

In [103]:
@triton.jit
def count(offset, X, under_pivot, over_pivot, pivot_idx, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block = offset + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    item = tl.load(X + block, mask=block<offset + N, other=float('nan'))
    num = tl.sum(tl.where(item < tl.load(X + offset + pivot_idx), 1, 0))

    tl.store(under_pivot + pid, num)
    tl.store(over_pivot + pid, tl.sum(tl.where(block<offset+N, 1, 0)) - num)

In [91]:
#: Int64[Tensor, '...']
@triton.jit
def triton_partition(offset, X, Y, pivot_idx, count_under_pivot, count_over_pivot, start_indices, start_indices2, total_before, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    
    startidx = tl.load(start_indices + pid).to(tl.int64) - tl.load(count_under_pivot + pid).to(tl.int64)
    startidx2 = tl.load(total_before).to(tl.int64) + tl.load(start_indices2 + pid).to(tl.int64) - tl.load(count_over_pivot + pid).to(tl.int64)

    #tl.device_print('total_before', total_before)
    #tl.device_print('start2', tl.load(start_indices2 + pid))
    #tl.device_print('thiscount', tl.load(count_over_pivot + pid))
    
    pivot = tl.load(X + offset + pivot_idx)
    
    for i in tl.static_range(BLOCK_SIZE):
        pos = pid * BLOCK_SIZE + i
        if pos < N:
            value = tl.load(X + offset + pos)
            #tl.device_print('pivot', pivot)
            #tl.device_print('value', value)
            if value < pivot:
                #tl.device_print('path1', startidx)
                tl.store(Y + offset + startidx, value)
                startidx += 1
            else:
                #tl.device_print('path2', startidx2)
                tl.store(Y + offset + startidx2, value)
                startidx2 += 1

    #tl.store(Y + pid, tl.load(start_indices2 + pid) - tl.load(count_over_pivot + pid))

In [109]:
def partition(X, Y, left, right):
    """
    left and right are inclusive
    """
    
    N = right - left + 1
    pivot_idx = np.random.randint(N)
    pivot = X[left + pivot_idx]

    #print('pivot is:', pivot)

    count_under_pivot = torch.zeros((math.ceil(N / BLOCK_SIZE)), device='cuda')
    count_over_pivot = torch.zeros((math.ceil(N / BLOCK_SIZE)), device='cuda')

    count[(math.ceil(N / BLOCK_SIZE),)](left, X, count_under_pivot, count_over_pivot, pivot_idx, N, BLOCK_SIZE)

    #print('under pivot:', count_under_pivot)
    #print('over pivot:', count_over_pivot)
    
    count_under_pivot = count_under_pivot.long()
    count_over_pivot = count_over_pivot.long()

    start_indices = triton_pref_sum(count_under_pivot)
    start_indices2 = triton_pref_sum(count_over_pivot)

    #print('start indices', start_indices)
    #print('start indices2', start_indices2)

    total_before = start_indices[-1]

    triton_partition[(math.ceil(N / BLOCK_SIZE),)]\
    (left, X, Y, pivot_idx, count_under_pivot, count_over_pivot, \
     start_indices, start_indices2, total_before, N, BLOCK_SIZE)

    return pivot, total_before.item()

In [7]:
X = torch.randint(low=0, high=100, size=(50,), device='cuda')
Y = torch.empty_like(X, device='cuda')

In [8]:
pivot, idx = partition(X, Y)

In [10]:
pivot

tensor(92, device='cuda:0')

In [11]:
Y[:idx]

tensor([50, 82, 26, 70, 35, 50, 13, 85, 22, 46, 38, 74, 45, 34,  2, 47, 21, 28,
        16, 84, 91,  6, 72, 32, 34, 23, 51, 36, 28, 45, 27, 51, 45, 31, 27, 40,
        57, 64, 36, 64, 59, 89, 85, 91, 70, 89, 51, 47], device='cuda:0')

In [12]:
Y[idx:]

tensor([92, 95], device='cuda:0')

In [113]:
from collections import deque
import time

def recursiveSort(X):
    Y = torch.empty_like(X, device='cuda')
    N = torch.numel(X)

    def recurse(left, right):
        print(X[left:right+1])
        pivot, numonleft = partition(X, Y, left, right)

        print('pivot=', pivot.item(), ': size=', numonleft, ' offset=', left)
        print(X, Y)
        
        X[left:right+1] = Y[left:right+1]
        
        #time.sleep(3)
        
        if numonleft > 2:
            recurse(left, left + numonleft - 1)
        if right - left + 1 - numonleft > 2:
            recurse(left + numonleft, right)

    recurse(0, N-1)

def sort(X):
    Y = torch.empty_like(X, device='cuda')
    N = torch.numel(X)
    
    q = deque()

    q.append((0, N - 1))

    while q:
        left, right = q.popleft()
        
        pivot, numonleft = partition(X, Y, left, right)

        #print(pivot, ':', numonleft, ' ', left)
        #print(Y[left:right+1])
        #time.sleep(3)
        
        X[left:right+1] = Y[left:right+1]

        if numonleft >= 2:
            q.append((left, left + numonleft - 1))
        if right - left + 1 - numonleft >= 2:
            q.append((left + numonleft, right))

In [114]:
X = torch.arange(50, 0, -1, device='cuda')#torch.randint(low=0, high=100, size=(50,), device='cuda')
print(X)
sort(X)
print(X)

tensor([50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33,
        32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15,
        14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1],
       device='cuda:0')
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
       device='cuda:0')


In [106]:
X = torch.arange(50, 0, -1, device='cuda')
Y = torch.empty_like(X, device='cuda')
print(X[10:])
partition(X, Y, 10, 49)
Y

tensor([40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23,
        22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6,  5,
         4,  3,  2,  1], device='cuda:0')
pivot is: tensor(22, device='cuda:0')
under pivot: tensor([13.,  8.], device='cuda:0')
over pivot: tensor([19.,  0.], device='cuda:0')
start indices tensor([13, 21], device='cuda:0')
start indices2 tensor([19, 19], device='cuda:0')


tensor([ 28,   0,  59, 151, 188, 184, 180, 176, 172, 168,  21,  20,  19,  18,
         17,  16,  15,  14,  13,  12,  11,  10,   9,   8,   7,   6,   5,   4,
          3,   2,   1,  40,  39,  38,  37,  36,  35,  34,  33,  32,  31,  30,
         29,  28,  27,  26,  25,  24,  23,  22], device='cuda:0')