In [1]:
# Locations of the CUDA implementation and its C++ declarations; the next
# cells read both files and hand their contents to torch's load_inline.
_SRC_DIR = "./src"
cuda_code_file, header_code_file = (_SRC_DIR + "/gpu.cu", _SRC_DIR + "/gpu.hpp")

In [3]:
# Load the CUDA source without its #include lines: load_inline prepends its own
# torch/CUDA headers to the generated translation unit.
# NOTE(review): headers such as <iostream> are dropped too and must then arrive
# transitively through the torch headers — confirm if compilation breaks.
with open(cuda_code_file, encoding="utf-8") as source_file:
    cuda_code = "".join(
        line for line in source_file
        # lstrip() so indented directives are removed as well
        if not line.lstrip().startswith("#include")
    )
print(cuda_code)


// Print the CUDA version nvcc compiled this file with, followed by the runtime
// and driver versions reported when the code runs. The latter two are printed
// as the raw integers the CUDA API returns (1000 * major + 10 * minor).
void printCudaVersion()
{
    std::cout << "CUDA Compiled version: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__ << std::endl;

    // Initialise to 0 and check the status so a failed query never prints an
    // indeterminate value.
    int runtime_ver = 0;
    cudaError_t err = cudaRuntimeGetVersion(&runtime_ver);
    if (err == cudaSuccess) {
        std::cout << "CUDA Runtime version: " << runtime_ver << std::endl;
    } else {
        std::cout << "CUDA Runtime version: unavailable (" << cudaGetErrorString(err) << ")" << std::endl;
    }

    int driver_ver = 0;
    err = cudaDriverGetVersion(&driver_ver);
    if (err == cudaSuccess) {
        std::cout << "CUDA Driver version: " << driver_ver << std::endl;
    } else {
        std::cout << "CUDA Driver version: unavailable (" << cudaGetErrorString(err) << ")" << std::endl;
    }
}

// saxpy: y[i] = a * x[i] + y[i] for every i in [0, n); y is updated in place.
// Each thread starts at its global index and advances by the total number of
// launched threads, so every element is covered for any grid/block size, even
// when the grid has fewer threads than n. The index is 64-bit so the stride
// increment cannot overflow a signed int near INT_MAX.
__global__
void saxpy(int n, float a, const float *x, float *y) {
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < n;
         i += stride) {
        y[i] = a * x[i] + y[i];
    }
}

// Host entry point bound to Python: computes y = a * x + y over the first n
// elements, writing into y's storage, and returns y.
// Requires x and y to be contiguous float32 tensors on the same CUDA device,
// each holding at least n elements; violations raise a Python RuntimeError.
torch::Tensor saxpy_wrapper(int n, const torch::Tensor& x, torch::Tensor y, float a) {
    TORCH_CHECK(x.is_cuda() && y.is_cuda(), "saxpy_wrapper: x and y must be CUDA tensors");
    TORCH_CHECK(x.device() == y.device(), "saxpy_wrapper: x and y must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && y.scalar_type() == torch::kFloat32,
                "saxpy_wrapper: x and y must be float32");
    // The kernel indexes raw pointers linearly, so strided views would be wrong.
    TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "saxpy_wrapper: x and y must be contiguous");
    TORCH_CHECK(n >= 0 && n <= x.numel() && n <= y.numel(),
                "saxpy_wrapper: n must be in [0, min(x.numel(), y.numel())], got ", n);

    if (n == 0) {
        return y;  // a zero-block launch is an invalid configuration
    }

    // Many threads per block keep warps fully occupied; one thread per block
    // would leave most lanes of every warp idle.
    constexpr int kThreadsPerBlock = 256;
    // Ceil-divide without computing n + 255, which could overflow int.
    const int blocks = n / kThreadsPerBlock + (n % kThreadsPerBlock != 0 ? 1 : 0);

    // NOTE(review): launches on the default stream; if callers use a
    // non-default torch stream, consider launching on the current torch stream.
    saxpy<<<blocks, kThreadsPerBlock>>>(n, a, x.data_ptr<float>(), y.data_ptr<float>());
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "saxpy launch failed: ", cudaGetErrorString(err));

    err = cudaDeviceSynchronize();
    TORCH_CHECK(err == cudaSuccess, "saxpy execution failed: ", cudaGetErrorString(err));

    // Only report once the kernel has actually finished successfully.
    std::cout <<  "Calculated saxpy\n";
    return y;
}


In [4]:
# Load the C++ declarations without #include lines; load_inline prepends
# torch/extension.h to cpp_sources itself. These declarations are what the
# generated Python bindings for `functions` refer to.
with open(header_code_file, encoding="utf-8") as header_file:
    header_code = "".join(
        line for line in header_file
        # lstrip() so indented directives are removed as well
        if not line.lstrip().startswith("#include")
    )
print(header_code)


// Prints the CUDA compile-time, runtime and driver versions to stdout.
void printCudaVersion();

// Computes y = a * x + y on the GPU over n elements, writing into y's storage,
// and returns y. Exposed to Python via the `functions` list given to load_inline.
torch::Tensor saxpy_wrapper(int n, const torch::Tensor& x, torch::Tensor y, float a);



In [5]:
# Start from an empty build directory so load_inline recompiles instead of
# reusing stale artifacts. Removing and recreating the directory works whether
# it is missing, empty, or contains subdirectories.
# NOTE: this does not unload an extension already imported in this kernel;
# restart the kernel after changing C++ signatures.
import shutil
from pathlib import Path

BUILD_DIR = Path("./build")
shutil.rmtree(BUILD_DIR, ignore_errors=True)
BUILD_DIR.mkdir(parents=True, exist_ok=True)

In [10]:
import torch
from torch.utils.cpp_extension import load_inline

# JIT-compile the extension: `cpp_sources` holds only the declarations,
# `cuda_sources` the definitions, and Python bindings are generated for each
# name in `functions`. Build artifacts are written to ./build.
# NOTE: an extension module already imported under this name stays loaded, so
# after changing C++ signatures restart the kernel before re-running this cell;
# otherwise calls keep hitting the old bindings.
saxpy_extension = load_inline(
    name="saxpy_extension",
    cpp_sources=header_code,
    cuda_sources=cuda_code,
    functions=["saxpy_wrapper", "printCudaVersion"],
    with_cuda=True,
    verbose=True,
    # Append further nvcc flags (e.g. "--expt-relaxed-constexpr") to this list;
    # passing extra_cuda_cflags a second time would be a SyntaxError.
    extra_cuda_cflags=["-O2"],
    build_directory="./build",
)

# NOTE(review): the recorded output reports compile-time CUDA 12.1 but runtime
# 11070 (11.7); confirm the nvcc toolkit matches the CUDA runtime PyTorch uses.
saxpy_extension.printCudaVersion()

CUDA Compiled version: 12.1
CUDA Runtime version: 11070
CUDA Driver version: 12020


No modifications detected for re-loaded extension module saxpy_extension, skipping build step...
Loading extension module saxpy_extension...


In [8]:
# Run saxpy (y <- alpha * x + y) on random inputs and check it against PyTorch.
torch.manual_seed(0)  # reproducible inputs across Restart & Run All
x = torch.randn((2, 3), device="cuda")
y = torch.randn((2, 3), device="cuda")
alpha = 2.0
print(x)
print(y)

# Build the reference first: the kernel writes its result into y's storage,
# so y no longer holds the original values after the call.
expected = alpha * x + y
# Derive n from the tensor rather than hard-coding it, so shapes can change.
saxpy_result = saxpy_extension.saxpy_wrapper(x.numel(), x, y, alpha)
torch.testing.assert_close(saxpy_result, expected)
saxpy_result

tensor([[-0.4265,  0.2755, -0.3003],
        [ 0.9430,  0.4791, -0.8424]], device='cuda:0')
tensor([[-0.3069,  0.6481,  0.8621],
        [-0.1215,  0.6083,  0.6522]], device='cuda:0')


TypeError: saxpy_wrapper(): incompatible function arguments. The following argument types are supported:
    1. (arg0: int, arg1: float, arg2: float, arg3: float) -> None

Invoked with: 6, tensor([[-0.4265,  0.2755, -0.3003],
        [ 0.9430,  0.4791, -0.8424]], device='cuda:0'), tensor([[-0.3069,  0.6481,  0.8621],
        [-0.1215,  0.6083,  0.6522]], device='cuda:0'), 2.0

In [7]:
# NOTE(review): this cell is In [7] and ran before In [8], so the value shown
# is an earlier `x`, not the one passed to saxpy_wrapper; re-run top to bottom.
x

tensor([[-0.7167, -1.6077, -1.3150],
        [-2.5737,  0.2456,  0.2314]], device='cuda:0')