Here, I will use Numba and Triton lmao

In [1]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import Float32, Int32, Int64
import math
import numpy as np

In [2]:
BLOCK_SIZE = 32

In [3]:
@triton.jit
def scan(Y, nextY, stride, BLOCK_SIZE: tl.constexpr):
    pid_row = tl.program_id(0)

    for j in tl.static_range(BLOCK_SIZE):
        current_idx = pid_row * BLOCK_SIZE + j
        if current_idx - stride >= 0:
            Yj = tl.load(Y + current_idx)
            Yjminstride = tl.load(Y + current_idx - stride)
            
            tl.store(nextY + current_idx, Yj + Yjminstride)
        else:
            tl.store(nextY + current_idx, tl.load(Y + current_idx))

In [4]:
def triton_pref_sum(X):
    Y = torch.clone(X)
    Ynext = torch.empty_like(Y, device='cuda')
    n = X.shape[0]
    stride = 1
    for i in range(0, int(math.log2(n))):
        scan[(math.ceil(n / BLOCK_SIZE),)](Y, Ynext, stride, BLOCK_SIZE)
        stride *= 2
        Ynext, Y = Y, Ynext

    return Y


In [5]:
N = 50
X = torch.randint(low=0, high=100, size=(N,), device='cuda')
X

tensor([18, 16, 74, 86, 36, 11, 68, 94, 68, 52, 87, 30, 83, 18, 71, 26, 17, 89,
        72,  5, 70, 49, 77, 43, 76, 18,  6, 98, 76, 26, 42, 36, 33,  1, 13, 47,
        37, 74, 54, 36, 26, 85, 60, 38, 69, 31, 26, 66, 23, 57],
       device='cuda:0')

In [6]:
pivot_idx = np.random.randint(N)
pivot_idx

13

In [7]:
X[pivot_idx]

tensor(18, device='cuda:0')

In [8]:
count_under_pivot = torch.zeros((math.ceil(N / BLOCK_SIZE)), device='cuda')
count_under_pivot
count_over_pivot = torch.zeros((math.ceil(N / BLOCK_SIZE)), device='cuda')
count_over_pivot

tensor([0., 0.], device='cuda:0')

In [9]:
@triton.jit
def count(X, under_pivot, over_pivot, pivot_idx, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    item = tl.load(X + block, mask=block<N, other=float('nan'))
    num = tl.sum(tl.where(item < tl.load(X + pivot_idx), 1, 0))

    #tl.device_print("stuff", num)

    tl.store(under_pivot + pid, num)
    tl.store(over_pivot + pid, tl.sum(tl.where(block<N, 1, 0)) - num)

In [10]:
count[(math.ceil(N / BLOCK_SIZE),)](X, count_under_pivot, count_over_pivot, pivot_idx, N, BLOCK_SIZE)

<triton.compiler.compiler.CompiledKernel at 0x7fddec38a420>

In [11]:
torch.where(X < X[pivot_idx], 1, 0)

tensor([0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0], device='cuda:0')

In [12]:
count_under_pivot = count_under_pivot.long()
count_under_pivot

tensor([5, 2], device='cuda:0')

In [13]:
count_over_pivot = count_over_pivot.long()
count_over_pivot

tensor([27, 16], device='cuda:0')

In [14]:
start_indices = triton_pref_sum(count_under_pivot)
start_indices

start_indices2 = triton_pref_sum(count_over_pivot)
start_indices2

tensor([27, 43], device='cuda:0')

In [15]:
#: Int64[Tensor, '...']
@triton.jit
def partition(X, Y, pivot_idx, count_under_pivot, count_over_pivot, start_indices, start_indices2, total_before, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    
    startidx = tl.load(start_indices + pid).to(tl.int64) - tl.load(count_under_pivot + pid).to(tl.int64)
    startidx2 = tl.load(total_before).to(tl.int64) + tl.load(start_indices2 + pid).to(tl.int64) - tl.load(count_over_pivot + pid).to(tl.int64)

    #tl.device_print('total_before', total_before)
    #tl.device_print('start2', tl.load(start_indices2 + pid))
    #tl.device_print('thiscount', tl.load(count_over_pivot + pid))
    
    pivot = tl.load(X + pivot_idx)
    
    for i in tl.static_range(BLOCK_SIZE):
        pos = pid * BLOCK_SIZE + i
        if pos < N:
            value = tl.load(X + pos)
            #tl.device_print('pivot', pivot)
            #tl.device_print('value', value)
            if value < pivot:
                #tl.device_print('path1', startidx)
                tl.store(Y + startidx, value)
                startidx += 1
            else:
                #tl.device_print('path2', startidx2)
                tl.store(Y + startidx2, value)
                startidx2 += 1

    #tl.store(Y + pid, tl.load(start_indices2 + pid) - tl.load(count_over_pivot + pid))

In [16]:
Y = torch.empty_like(X, device='cuda')
#Y = torch.empty((1024,), device='cuda')

In [17]:
total_before = start_indices[start_indices.shape[0] - 1]
total_before

tensor(7, device='cuda:0')

In [18]:
partition[(math.ceil(N / BLOCK_SIZE),)]\
    (X, Y, pivot_idx, count_under_pivot, count_over_pivot, \
     start_indices, start_indices2, total_before, N, BLOCK_SIZE)

<triton.compiler.compiler.CompiledKernel at 0x7fddec4e4470>

In [19]:
Y

tensor([16, 11, 17,  5,  6,  1, 13, 18, 74, 86, 36, 68, 94, 68, 52, 87, 30, 83,
        18, 71, 26, 89, 72, 70, 49, 77, 43, 76, 18, 98, 76, 26, 42, 36, 33, 47,
        37, 74, 54, 36, 26, 85, 60, 38, 69, 31, 26, 66, 23, 57],
       device='cuda:0')

In [20]:
%env CUDA_LAUNCH_BLOCKING=1
%env TORCH_USE_CUDA_DSA=1

env: CUDA_LAUNCH_BLOCKING=1
env: TORCH_USE_CUDA_DSA=1


In [21]:
X[pivot_idx]

tensor(18, device='cuda:0')

In [22]:
Y

tensor([16, 11, 17,  5,  6,  1, 13, 18, 74, 86, 36, 68, 94, 68, 52, 87, 30, 83,
        18, 71, 26, 89, 72, 70, 49, 77, 43, 76, 18, 98, 76, 26, 42, 36, 33, 47,
        37, 74, 54, 36, 26, 85, 60, 38, 69, 31, 26, 66, 23, 57],
       device='cuda:0')