In [1]:
import os, re
import torch
import torch.nn as nn
import typing
import numpy as np
from types import SimpleNamespace as ns
from utils import load_cuda, cuda_begin, cdiv
import cuda_ext

In [2]:
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [3]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.manual_seed(42)

<torch._C.Generator at 0x7fcbf0528650>

# CUDA

## argmax

In [4]:
m = torch.rand(32,64,128,128).cuda()

In [5]:
%%timeit -n 10
res = torch.argmax(m, dim=0)

361 µs ± 183 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
%%timeit -n 10
cuda_res = cuda_ext.argmax(m, dim=0)

374 µs ± 205 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
res = torch.argmax(m, dim=0)
cuda_res = cuda_ext.argmax(m, dim=0)
print(res.dtype, cuda_res.dtype)
torch.isclose(res.to(torch.int8),cuda_res).all()

torch.int64 torch.int8


tensor(True, device='cuda:0')

In [8]:
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,  # Track CPU activity
        torch.profiler.ProfilerActivity.CUDA  # Track GPU activity (for memory and time)
    ],
    record_shapes=True,  # Records shapes of operator inputs
    profile_memory=True, # Tracks memory usage
    with_stack=True      # Captures stack traces (optional, useful for deep debugging)
) as prof:
    result = cuda_ext.argmax(m, dim=0)
    # res = torch.argmax(m, dim=0)
    torch.cuda.synchronize()

print(prof.key_averages().table(
    sort_by="self_cuda_memory_usage",  # Sort by GPU memory usage
    row_limit=10                       # Limit rows (useful for large reports)
))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::empty        87.82%       3.020ms        87.82%       3.020ms       3.020ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b       1.00 Mb       1.00 M

STAGE:2024-10-13 17:05:11 19771:19771 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-10-13 17:05:11 19771:19771 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-10-13 17:05:11 19771:19771 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


In [9]:
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,  # Track CPU activity
        torch.profiler.ProfilerActivity.CUDA  # Track GPU activity (for memory and time)
    ],
    record_shapes=True,  # Records shapes of operator inputs
    profile_memory=True, # Tracks memory usage
    with_stack=True      # Captures stack traces (optional, useful for deep debugging)
) as prof:
    # result = cuda_ext.argmax(m, dim=0)
    res = torch.argmax(m, dim=0)
    torch.cuda.synchronize()

print(prof.key_averages().table(
    sort_by="self_cuda_memory_usage",  # Sort by GPU memory usage
    row_limit=10                       # Limit rows (useful for large reports)
))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::argmax        91.41%       3.416ms        99.89%       3.733ms       3.733ms     272.000us       100.00%     272.000us     272.000us           0 b           0 b       8.00 Mb       8.00 M

STAGE:2024-10-13 17:05:12 19771:19771 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-10-13 17:05:12 19771:19771 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-10-13 17:05:12 19771:19771 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


## layernorm

In [4]:
sentence_length, embedding_dim = 64, 64
embedding = torch.randn(sentence_length, embedding_dim).cuda()
layer_norm = nn.LayerNorm(embedding_dim).cuda()

In [8]:
%%timeit -n 10
with torch.no_grad():
    torch_result = layer_norm(embedding)

47.6 µs ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%%timeit -n 10
cuda_result = cuda_ext.layernorm_welford(embedding)

58.3 µs ± 26.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
%%timeit -n 10
cuda_result = cuda_ext.layernorm(embedding)

58.1 µs ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
with torch.no_grad():
    torch_result = layer_norm(embedding)
cuda_result_welford = cuda_ext.layernorm_welford(embedding)
cuda_result = cuda_ext.layernorm_welford(embedding)
print(torch.allclose(torch_result, cuda_result, rtol=1e-03))
print(torch.allclose(cuda_result_welford, cuda_result))

True
True


In [12]:
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,  # Track CPU activity
        torch.profiler.ProfilerActivity.CUDA  # Track GPU activity (for memory and time)
    ],
    record_shapes=True,  # Records shapes of operator inputs
    profile_memory=True, # Tracks memory usage
    with_stack=True      # Captures stack traces (optional, useful for deep debugging)
) as prof:
    torch_result = layer_norm(embedding)

print(prof.key_averages().table(
    sort_by="self_cuda_memory_usage",  # Sort by GPU memory usage
    row_limit=10                       # Limit rows (useful for large reports)
))
del prof

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::empty         1.60%      36.000us         1.60%      36.000us      12.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      17.00 Kb      17.00 K

STAGE:2024-10-13 17:08:52 19861:19861 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-10-13 17:08:52 19861:19861 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-10-13 17:08:52 19861:19861 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


In [11]:
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,  # Track CPU activity
        torch.profiler.ProfilerActivity.CUDA  # Track GPU activity (for memory and time)
    ],
    record_shapes=True,  # Records shapes of operator inputs
    profile_memory=True, # Tracks memory usage
    with_stack=True      # Captures stack traces (optional, useful for deep debugging)
) as prof:
    cuda_result = cuda_ext.layernorm_welford(embedding)
print(prof.key_averages().table(
    sort_by="self_cuda_memory_usage",  # Sort by GPU memory usage
    row_limit=10                       # Limit rows (useful for large reports)
))
del prof

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::empty        93.47%       2.174ms        93.47%       2.174ms       2.174ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b      16.00 Kb      16.00 K

STAGE:2024-10-13 17:09:34 19914:19914 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-10-13 17:09:34 19914:19914 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-10-13 17:09:34 19914:19914 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


In [12]:
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,  # Track CPU activity
        torch.profiler.ProfilerActivity.CUDA  # Track GPU activity (for memory and time)
    ],
    record_shapes=True,  # Records shapes of operator inputs
    profile_memory=True, # Tracks memory usage
    with_stack=True      # Captures stack traces (optional, useful for deep debugging)
) as prof:
    cuda_result = cuda_ext.layernorm(embedding)
print(prof.key_averages().table(
    sort_by="self_cuda_memory_usage",  # Sort by GPU memory usage
    row_limit=10                       # Limit rows (useful for large reports)
))
del prof

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::empty        94.43%       2.271ms        94.43%       2.271ms       2.271ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b      16.00 Kb      16.00 K

STAGE:2024-10-13 17:09:37 19914:19914 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-10-13 17:09:37 19914:19914 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-10-13 17:09:37 19914:19914 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


# Debug

In [13]:
cuda_src = cuda_begin + r'''

template <typename T>
__device__ bool greater_function(T data1, T data2) {
    return data1 > data2;
}

template <typename T>
__device__ int max_index_function(T* data, int stride, int length) {
    T max_value = data[0];
    int max_index = 0;

    int index = 0;
    for (int i=1; i<length; ++i) {
        index += stride;
        if (greater_function<T>(data[index], max_value)) {
            max_value = data[index];
            max_index = i;
        }
    }
    return max_index;
}


template <typename T>
__global__ void argmax_kernel(void* data, char* index, int number, int stride, int length) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < number) {
        int row = i / stride;
        int col = i % stride;
        argmax_index[i] = max_index_function<T>((T*)data + row * length * stride + col, stride, length);
    }
}

void argmax_cuda(torch::Tensor data, torch::Tensor index, int64_t number, int64_t stride, int64_t length) {
    data = data.contiguous();
    index = index.contiguous();
    int gpu_id = index.device().index();
    cudaSetDevice(gpu_id);

    block_size = 32;
    argmax_kernel<float><<<cdiv(1.0*number/block_size), block_size>>>((void*)data.data_ptr(), (char*)index.data_ptr(), number, stride, length);
    cudaDeviceSynchronize();
}

'''

In [14]:
def get_sig(fname, src):
    res = re.findall(rf'^(.+\s+{fname}\s*\([^\)]*\))\s*\{{', src, re.MULTILINE)
    return res[0] + ';' if res else None


fname = 'argmax_cuda'
cpp_src = get_sig(fname, cuda_src)

In [15]:
print(cpp_src)

void argmax_cuda(torch::Tensor data, torch::Tensor index, int64_t number, int64_t stride, int64_t length);


In [96]:
module = load_cuda(cuda_src, cpp_src, [fname], opt=True)
def arg_max(
        matrix: typing.Union[torch.cuda.FloatTensor, torch.cuda.HalfTensor], 
        dim: int, 
        keepdim: bool=False) -> torch.cuda.CharTensor:
    
    assert matrix.dtype in [torch.float32, torch.float16], "only support float32 and float16"
    assert isinstance(matrix, torch.Tensor) and matrix.is_cuda, "only support torch.Tensor, gpu"
    assert matrix.ndim > dim, f"matrix.ndim({matrix.ndim}) should be larger than dim({dim})"
    assert dim >= 0, f"dim({dim}) should be >= 0"

    shape = [i for i in matrix.shape]
    assert shape[dim] <= 128, f"shape[dim]({shape[dim]}) should be <= 128"

    number, stride, length = 1, 1, shape[dim]
    
    for i in range(dim+1, matrix.ndim):
        stride *= shape[i]

    number = stride

    for i in range(dim):
        number *= shape[i]

    if keepdim:
        shape[dim] = 1
    else:
        shape.pop(dim)

    index_cuda = torch.zeros(shape, dtype = torch.int8, device=matrix.device)

    # torch.ops.ext.argmax(matrix, index_cuda, number, stride, length)
    module.argmax_cuda(matrix, index_cuda, number, stride, length)

    return index_cuda

RuntimeError: Error building extension 'argmax_cuda_v3': [1/3] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=argmax_cuda_v3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/TH -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /root/miniconda3/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 --compiler-options '-fPIC' -O3 -Xptxas -O3 -Xcompiler -O3 -std=c++17 -c /root/.cache/torch_extensions/py310_cu121/argmax_cuda/cuda.cu -o cuda.cuda.o 
[31mFAILED: [0mcuda.cuda.o 
/usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=argmax_cuda_v3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/TH -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /root/miniconda3/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 --compiler-options '-fPIC' -O3 -Xptxas -O3 -Xcompiler -O3 -std=c++17 -c /root/.cache/torch_extensions/py310_cu121/argmax_cuda/cuda.cu -o cuda.cuda.o 
/root/.cache/torch_extensions/py310_cu121/argmax_cuda/cuda.cu(52): error: identifier "argmax_index" is undefined
          argmax_index[i] = max_index_function<T>((T*)data + row * length * stride + col, stride, length);
          ^

/root/.cache/torch_extensions/py310_cu121/argmax_cuda/cuda.cu(62): error: identifier "block_size" is undefined
      block_size = 32;
      ^

/root/.cache/torch_extensions/py310_cu121/argmax_cuda/cuda.cu(63): error: too few arguments in function call
      argmax_kernel<float><<<cdiv(1.0*number/block_size), block_size>>>((void*)data.data_ptr(), (char*)index.data_ptr(), number, stride, length);
                                                       ^

3 errors detected in the compilation of "/root/.cache/torch_extensions/py310_cu121/argmax_cuda/cuda.cu".
[2/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=argmax_cuda_v3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/TH -isystem /root/miniconda3/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /root/miniconda3/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /root/.cache/torch_extensions/py310_cu121/argmax_cuda/main.cpp -o main.o 
ninja: build stopped: subcommand failed.


In [83]:
result = arg_max(m, dim=0)

ImportError: /root/.cache/torch_extensions/py310_cu121/argmax_cuda/argmax_cuda_v1.so: cannot open shared object file: No such file or directory