In [1]:
# Locations (relative to the notebook) of the CUDA implementation and the
# matching C++ declarations that get JIT-compiled further down.
cuda_code_file, header_code_file = "./src/gpu.cu", "./src/gpu.hpp"

In [3]:
# Load the CUDA source with its `#include` lines removed. Presumably this is
# because load_inline prepends its own torch/CUDA headers and relative includes
# would not resolve from the build directory (NOTE(review): confirm intent).
# A distinct loop name avoids shadowing the file handle, and lstrip() also
# catches indented include directives.
with open(cuda_code_file, encoding="utf-8") as src:
    cuda_code = "".join(
        line for line in src if not line.lstrip().startswith("#include")
    )

print(cuda_code)


void printCudaVersion()
{
    std::cout << "CUDA Compiled version: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__ << std::endl;

    int runtime_ver;
    cudaRuntimeGetVersion(&runtime_ver);
    std::cout << "CUDA Runtime version: " << runtime_ver << std::endl;

    int driver_ver;
    cudaDriverGetVersion(&driver_ver);
    std::cout << "CUDA Driver version: " << driver_ver << std::endl;
}

__global__
void saxpy(int n, float a, float *x, float *y) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = a*x[i] + y[i];
    }
}

torch::Tensor saxpy_wrapper(int n, const torch::Tensor& x, torch::Tensor y, float a) {
    saxpy<<<n, 1>>>(n, a, x.data_ptr<float>(), y.data_ptr<float>());
    std::cout <<  "Calculated saxpy\n";
    cudaDeviceSynchronize();
    return y;
}


In [4]:
# Load the C++ declarations with their `#include` lines removed; load_inline
# is expected to prepend the torch extension header to cpp_sources
# (NOTE(review): confirm this is why includes are stripped).
# A distinct loop name avoids shadowing the file handle, and lstrip() also
# catches indented include directives.
with open(header_code_file, encoding="utf-8") as src:
    header_code = "".join(
        line for line in src if not line.lstrip().startswith("#include")
    )

print(header_code)


void printCudaVersion();

torch::Tensor saxpy_wrapper(int n, const torch::Tensor& x, torch::Tensor y, float a);



In [5]:
!rm ./build/*

In [6]:
import torch
from torch.utils.cpp_extension import load_inline

# JIT-compile the header declarations and CUDA source loaded above into a
# Python extension that exposes the two listed functions. Artifacts are written
# to ./build.
saxpy_extension = load_inline(
    name='saxpy_extension',
    cpp_sources=header_code,
    cuda_sources=cuda_code,
    functions=['saxpy_wrapper', "printCudaVersion"],
    with_cuda=True,
    verbose=True,
    # Add further nvcc flags (e.g. '--expt-relaxed-constexpr') to this list;
    # repeating the extra_cuda_cflags keyword would be a SyntaxError.
    extra_cuda_cflags=["-O2"],
    build_directory='./build',
)

saxpy_extension.printCudaVersion()

Detected CUDA files, patching ldflags
Emitting ninja build file ./build/build.ninja...
Building extension module saxpy_extension...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=saxpy_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/ksharma/anaconda3/envs/cuda-learn/lib/python3.12/site-packages/torch/include -isystem /home/ksharma/anaconda3/envs/cuda-learn/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/ksharma/anaconda3/envs/cuda-learn/lib/python3.12/site-packages/torch/include/TH -isystem /home/ksharma/anaconda3/envs/cuda-learn/lib/python3.12/site-packages/torch/include/THC -isystem /home/ksharma/anaconda3/envs/cuda-learn/include -isystem /home/ksharma/anaconda3/envs/cuda-learn/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/ksharma/dev/git/cuda-learn/build/main.cpp -o main.o 
[2/3] /home/ksharma/anaconda3/envs/cuda-learn/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=saxpy_extension -DT

Loading extension module saxpy_extension...


CUDA Runtime version: 11070
CUDA Driver version: 12020


In [7]:
# Seed so the random inputs (and every printed tensor) are reproducible
# under Restart & Run All.
torch.manual_seed(0)
x = torch.randn((2, 3), device="cuda")
y = torch.randn((2, 3), device="cuda")
print(x)
print(y)

alpha = 2.0
# Build the reference *before* the call: saxpy_wrapper writes the result into
# y's storage and returns y itself, so y no longer holds its inputs afterwards.
expected = alpha * x + y
# Derive the element count from the tensor rather than hard-coding 6.
result = saxpy_extension.saxpy_wrapper(x.numel(), x, y, alpha)
torch.testing.assert_close(result, expected)
result

tensor([[ 0.8709, -0.0807, -1.8766],
        [-0.2632,  0.7398, -0.9265]], device='cuda:0')
tensor([[-1.6356,  0.4335,  0.0403],
        [-0.9111,  2.2448,  0.1985]], device='cuda:0')
Calculated saxpy


tensor([[ 0.1062,  0.2721, -3.7128],
        [-1.4375,  3.7244, -1.6544]], device='cuda:0')