In [6]:
!pip install Ninja



In [7]:
import os
import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load_inline

In [8]:
import os
import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load_inline

cuda_src = '''
__global__
void forward_kernel(float* A, int len, int j) {
  int t = blockIdx.x*blockDim.x + threadIdx.x;
  int l = t ^ j;
  int dir = t & len;

  if(t < l){
    if(((dir == 0) && A[t] > A[l]) || ((dir !=0) && A[t] < A[l])){
        float temp = A[t];
        A[t] = A[l];
        A[l] = temp;
    }
  }

  return;
}

torch::Tensor forward(torch::Tensor A) {
    int n = static_cast<int>(A.size(0));

    torch::Tensor O = A.clone();

    int num_threads = 1024;
    int num_blocks = n/1024;
    dim3 grid_dim(num_blocks);
    dim3 block_dim(num_threads);

    for(int len=2; len <= n; len *= 2){
        for(int j=len/2; j>0; j /= 2){
            forward_kernel<<<grid_dim, block_dim>>>(
                O.data_ptr<float>(),
                len,
                j
            );
        }
    }

    return O;
}
'''

cpp_src = 'torch::Tensor forward(torch::Tensor A);'

build_dir = 'cuda'
if not os.path.exists(build_dir):
    os.mkdir(build_dir)

os.environ['TORCH_CUDA_ARCH_LIST'] = "7.5"

bit_sort = load_inline(
    name='bit_sort',
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=['forward'],
    with_cuda=True,
    extra_cuda_cflags=["-arch=sm_75"],
    build_directory=f'./{build_dir}'
)

In [9]:
def generate_input(size: int, seed: int) -> torch.Tensor:
    """
    Generates random input tensor where elements are drawn from different distributions.

    Args:
        size: Total size of the final 1D tensor
        seed: Base seed for random generation

    Returns:
        1D tensor of size `size` containing flattened values from different distributions
    """
    # Calculate dimensions for a roughly square 2D matrix
    rows = int(size ** 0.5)  # Square root for roughly square shape
    cols = (size + rows - 1) // rows  # Ceiling division to ensure total size >= requested size

    gen = torch.Generator(device='cuda')
    result = torch.empty((rows, cols), device='cuda', dtype=torch.float32)

    # Different seed for each row!
    for i in range(rows):
        row_seed = seed + i
        gen.manual_seed(row_seed)

        # Generate values for this row with mean=row_seed
        result[i, :] = torch.randn(cols, device='cuda', dtype=torch.float32, generator=gen) + row_seed

    # Flatten and trim to exact size requested
    return result.flatten()[:size].contiguous()

In [10]:
sizes = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 16777216]
seed = 4052
gen = torch.Generator(device='cuda')
gen.manual_seed(seed)

print('=== profiling vector add === ')
for size in sizes:
  print(f"------------ vector add on size {size} ------------------------------------")
  a = generate_input(size, seed)
  ref = torch.sort(a)[0]
  with torch.autograd.profiler.profile(use_cuda=True) as prof:
    result = bit_sort.forward(a)
  print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

#   print(result)
#   print(ref)
  print('attn values sanity check:', torch.allclose(result, ref, rtol=0, atol=1e-02))

=== profiling vector add === 
------------ vector add on size 1024 ------------------------------------
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::clone        11.71%      31.880us        27.89%      75.923us      75.923us      39.000us        46.99%      83.000us      83.000us             1  
              aten::copy_         3.69%      10.042us         8.23%      22.414us      22.414us      26.000us        31.33%      26.000us      26.000us             1  
      aten::empty_strided         5.12%      13.939us  

  with torch.autograd.profiler.profile(use_cuda=True) as prof:


-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::clone         0.01%      25.207us         0.04%      73.172us      73.172us      10.000us         1.80%     557.000us     557.000us             1  
              aten::copy_         0.01%      10.397us         0.01%      22.963us      22.963us     545.000us        97.85%     545.000us     545.000us             1  
      aten::empty_strided         0.01%      17.265us         0.01%      17.265us      17.265us       2.000us         0.36%       2.000us       2.000us        