In [1]:
import torch
from torch.utils.cpp_extension import load_inline

# Approximate gelu as a fusion example

In [2]:
def gelu(x):
    """Tanh approximation of GELU written as a chain of elementwise tensor ops.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: input tensor.

    Returns:
        The approximation applied elementwise to ``x``.
    """
    sqrt_two_over_pi = (2 / torch.pi)**0.5
    cubic_term = x + 0.044715 * x**3
    return 0.5 * x * (1 + torch.tanh(sqrt_two_over_pi * cubic_term))

In [3]:
# Fix the RNG seed so the random test input (and the error figures computed
# from it below) are reproducible across kernel restarts.
torch.manual_seed(0)
x = torch.randn(1024, 1024, device='cuda')

In [4]:
# Elementwise difference between the hand-written gelu and PyTorch's built-in
# tanh-approximate GELU.
gelu(x) - torch.nn.functional.gelu(x, approximate='tanh')

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [5]:
# torch.cuda.synchronize() is part of each timed statement, so every
# measurement waits for the GPU work to finish before the clock stops.
%timeit gelu(x); torch.cuda.synchronize()
%timeit torch.nn.functional.gelu(x, approximate='tanh'); torch.cuda.synchronize()

118 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
20.3 µs ± 323 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## Kind of slow: the composed version takes ~118 µs vs. ~20 µs for the built-in. Why?

In [6]:
# Shared preamble prepended to every inline CUDA source in this notebook:
# headers, input-validation macros and a ceiling-division helper.
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

// Each macro raises via TORCH_CHECK when the tensor fails the named requirement.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Wrapped in do { } while (0) so the macro expands to exactly one statement
// and stays correct when used after an un-braced `if` / `else`.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

// Ceiling division: number of size-b chunks needed to cover a elements.
// Written as quotient plus remainder test so it cannot wrap around for large a.
inline unsigned int cdiv(unsigned int a, unsigned int b) { return a / b + (a % b != 0 ? 1u : 0u); }
'''

In [7]:
# Fused tanh-approximate GELU as a single CUDA kernel, plus the host wrappers
# that validate their inputs and launch it.
cuda_src = cuda_begin + r"""
#include <limits>

// One thread per element:
//   out[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),  x = in[i]
__global__ void my_gelu_kernel(float* __restrict__ out, float* __restrict__ in, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;  // threads whose index falls past the end do nothing
    float x = in[i];
    out[i] = 0.5f * x * (1.f + tanhf(sqrtf(2.f / 3.141592653589793f) * (x + 0.044715f * (x * x * x))));
}

// Computes GELU of `in` into the caller-provided `output` and returns `output`.
// Raises (TORCH_CHECK) unless both tensors are contiguous CUDA tensors with
// identical shape, device and dtype.
torch::Tensor my_gelu_out(torch::Tensor output, torch::Tensor const& in) {
    CHECK_INPUT(in);
    CHECK_INPUT(output);
    // Every property has to match (logical AND): the kernel addresses both
    // buffers with the same flat element index.
    TORCH_CHECK(output.sizes() == in.sizes(), "output and in must have the same shape");
    TORCH_CHECK(output.device() == in.device(), "output and in must be on the same device");
    TORCH_CHECK(output.scalar_type() == in.scalar_type(), "output and in must have the same dtype");
    // The kernel takes the element count as a 32-bit int.
    TORCH_CHECK(in.numel() <= std::numeric_limits<int>::max(), "in has too many elements for this kernel");

    int n = static_cast<int>(in.numel());
    if (n == 0) return output;  // nothing to compute; a zero-block grid is not a valid launch

    int threads = 256;
    my_gelu_kernel<<<cdiv(n, threads), threads>>>(
        output.data_ptr<float>(), in.data_ptr<float>(), n);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

// Allocating variant: returns a new tensor shaped like `in` that holds GELU(in).
torch::Tensor my_gelu(torch::Tensor const& in) {
    CHECK_INPUT(in);
    auto output = torch::empty_like(in);
    my_gelu_out(output, in);
    return output;
}
"""

# Declarations of the functions defined in cuda_src; passed to load_inline as
# the C++ source.
cpp_src = """
torch::Tensor my_gelu(torch::Tensor const& in);
torch::Tensor my_gelu_out(torch::Tensor output, torch::Tensor const& in);
"""

In [8]:
# Compile and load the GELU extension; the ptxas verbosity flag is passed
# through as an extra CUDA compiler flag.
gelu_module = load_inline(
    name="test_ext_gelu",
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=['my_gelu', 'my_gelu_out'],
    extra_cuda_cflags=['--ptxas-options=-v'],
)

In [9]:
# Largest absolute deviation of the CUDA kernel from the Python reference gelu.
(gelu_module.my_gelu(x) - gelu(x)).abs().max()

tensor(2.3842e-07, device='cuda:0')

In [10]:
# Time the fused kernel: the call plus a device synchronize in one timed statement.
%timeit gelu_module.my_gelu(x); torch.cuda.synchronize()

23.5 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# Measure launch latency

In [11]:
# Same host-side wrappers as the GELU extension, but around a kernel with an
# empty body, so timing it measures the cost of the launch itself.
cuda_src = cuda_begin + r'''
#include <limits>

// Intentionally empty: no loads, no stores, no arithmetic.
__global__ void my_empty_kernel(float* __restrict__ out, float* __restrict__ in, int n) {
}

// Launches the empty kernel over `in.numel()` elements and returns `output`.
// Raises (TORCH_CHECK) unless both tensors are contiguous CUDA tensors with
// identical shape, device and dtype.
torch::Tensor my_empty_out(torch::Tensor output, torch::Tensor const& in) {
    CHECK_INPUT(in);
    CHECK_INPUT(output);
    // Every property has to match (logical AND), not just one of them.
    TORCH_CHECK(output.sizes() == in.sizes(), "output and in must have the same shape");
    TORCH_CHECK(output.device() == in.device(), "output and in must be on the same device");
    TORCH_CHECK(output.scalar_type() == in.scalar_type(), "output and in must have the same dtype");
    // The kernel takes the element count as a 32-bit int.
    TORCH_CHECK(in.numel() <= std::numeric_limits<int>::max(), "in has too many elements for this kernel");

    int n = static_cast<int>(in.numel());
    if (n == 0) return output;  // a zero-block grid is not a valid launch

    int threads = 256;
    my_empty_kernel<<<cdiv(n, threads), threads>>>(
        output.data_ptr<float>(), in.data_ptr<float>(), n);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

// Allocating variant: creates an uninitialized tensor shaped like `in`, runs
// the empty kernel over it and returns it.
torch::Tensor my_empty(torch::Tensor const& in) {
    CHECK_INPUT(in);
    auto output = torch::empty_like(in);
    my_empty_out(output, in);
    return output;
}
'''

# Declarations of the functions defined in cuda_src; passed to load_inline as
# the C++ source.
cpp_src = """
torch::Tensor my_empty(torch::Tensor const& in);
torch::Tensor my_empty_out(torch::Tensor output, torch::Tensor const& in);
"""

In [12]:
# Compile and load the empty-kernel extension used for the launch-latency
# measurement.
empty_module = load_inline(
    name="test_ext_empty",
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=['my_empty', 'my_empty_out'],
    extra_cuda_cflags=['--ptxas-options=-v'],
)

In [13]:
%timeit empty_module.my_empty_out(x, x); torch.cuda.synchronize()

with torch.profiler.profile() as prof:
    for i in range(10_000):
        empty_module.my_empty_out(x, x)
        torch.cuda.synchronize()
print(prof.key_averages().table())

10 µs ± 13 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


STAGE:2024-03-13 18:40:24 678474:678474 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-03-13 18:40:24 678474:678474 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-03-13 18:40:24 678474:678474 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        cudaLaunchKernel        74.39%      33.327ms        74.39%      33.327ms       3.333us       0.000us         0.00%       0.000us       0.000us         10000  
    my_empty_kernel(float*, float*, int)         0.00%       0.000us         0.00%       0.000us       0.000us      40.244ms       100.00%      40.244ms       4.024us         10000  
                   cudaDeviceSynchronize        25.61%      11.471ms        25.61%  

# Tiled Matmul

In [14]:
# Naive matrix multiplication: one thread per output element, every operand
# read straight from the input pointers inside the inner loop.
cuda_src = cuda_begin + r"""
// Computes out = a @ b for row-major a (m x k), b (k x n), out (m x n).
// Thread (r, c) accumulates the dot product of row r of a and column c of b.
__global__ void simple_matmul_kernel(float* __restrict__ a, float* __restrict__ b, float* __restrict__ out, int m, int n, int k) {
    int r = blockIdx.y * blockDim.y + threadIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;

    if (r >= m || c >= n) return;  // edge blocks may overhang the output
    float tmp = 0.f;
    for (int i = 0; i < k; ++i) {
        tmp += a[r * k + i] * b[i * n + c];
    }
    out[r * n + c] = tmp;
}

// Host wrapper: validates the operands, allocates the (m x n) result and
// launches one 16x16 thread block per 16x16 patch of the output.
// Raises (TORCH_CHECK) when an operand is not a contiguous CUDA tensor, is not
// 2-D, lives on a different device than the other, or the inner sizes differ.
torch::Tensor simple_matmul(torch::Tensor a, torch::Tensor b) {
    CHECK_INPUT(a);
    CHECK_INPUT(b);
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "simple_matmul expects 2-D tensors");
    TORCH_CHECK(a.device() == b.device(), "a and b must be on the same device");
    int m = a.size(0);
    int n = b.size(1);
    int k = a.size(1);
    TORCH_CHECK(k == b.size(0), "Size mismatch!");

    // The kernel stores every element with r < m and c < n, so the buffer
    // does not need to be zero-filled first.
    auto output = torch::empty({m, n}, a.options());
    if (m == 0 || n == 0) return output;  // empty result; a zero-block grid is not a valid launch

    dim3 threads_per_block{16, 16};
    dim3 blocks{cdiv(n, threads_per_block.x), cdiv(m, threads_per_block.y)};

    simple_matmul_kernel<<<blocks, threads_per_block>>>(
        a.data_ptr<float>(), b.data_ptr<float>(), output.data_ptr<float>(), m, n, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}
"""

# Declaration of the function defined in cuda_src; passed to load_inline as
# the C++ source.
cpp_src = "torch::Tensor simple_matmul(torch::Tensor a, torch::Tensor b);"

In [15]:
# Compile and load the naive matmul extension.
simple_matmul_module = load_inline(
    name="test_ext_simple_matmul",
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=['simple_matmul'],
    extra_cuda_cflags=['--ptxas-options=-v'],
)

In [16]:
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
%timeit simple_matmul_module.simple_matmul(a, b)

(simple_matmul_module.simple_matmul(a, b) - a@b).abs().max()

1.11 ms ± 1.58 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


tensor(0.0002, device='cuda:0')

In [17]:
# Tiled matrix multiplication: each thread block stages TILE_SIZE x TILE_SIZE
# tiles of both operands in shared memory and accumulates from there.
cuda_src = cuda_begin + r"""
constexpr int TILE_SIZE = 16;

// Computes out = M @ N for row-major M (m x k), N (k x n), out (m x n).
// Thread (ir, ic) of a block produces output element (r, c); the block walks
// along the k dimension one tile at a time.
__global__ void tiled_matmul_kernel(float* out, float* M, float* N, int m, int n, int k) {
    __shared__ float M_tile[TILE_SIZE][TILE_SIZE];
    __shared__ float N_tile[TILE_SIZE][TILE_SIZE];

    // indices into the tile
    int ir = threadIdx.y;
    int ic = threadIdx.x;

    // indices into the output matrix
    int r = blockIdx.y * blockDim.y + threadIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;

    // note: cannot just exit if we want to do padding! Out-of-range threads
    // still have to fill their tile slot (with zeros) and reach both barriers.

    float res = 0.0f;
    for (int K_tileidx = 0; K_tileidx < (k + TILE_SIZE -1) / TILE_SIZE; K_tileidx++) {
        // note how threadIdx.x is the fastest-moving index --> coalesced memory access
        // Positions outside the matrices are loaded as 0 so they add nothing to the sum.
        M_tile[ir][ic] = (((r < m) && (K_tileidx * TILE_SIZE + ic < k)) ? M[r * k + K_tileidx * TILE_SIZE + ic] : 0.f);
        N_tile[ir][ic] = ((((K_tileidx * TILE_SIZE + ir) < k) && (c < n)) ? N[(K_tileidx * TILE_SIZE + ir) * n + c] : 0.f);
        __syncthreads();  // both tiles must be completely written before anyone reads them
        for (int idx = 0; idx < TILE_SIZE; idx++) {
            res += M_tile[ir][idx] * N_tile[idx][ic];
        }
        // important! The next iteration overwrites both tiles, so every thread
        // has to be done reading the current ones first.
        __syncthreads();
    }
    if ((r < m) && (c < n)) {
        out[r * n + c] = res;
    }
}

// Host wrapper: validates the operands, allocates the (m x n) result and
// launches one TILE_SIZE x TILE_SIZE thread block per output tile.
// Raises (TORCH_CHECK) when an operand is not a contiguous CUDA tensor, is not
// 2-D, lives on a different device than the other, or the inner sizes differ.
torch::Tensor tiled_matmul(torch::Tensor const& a, torch::Tensor const& b) {
    CHECK_INPUT(a); CHECK_INPUT(b);
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "tiled_matmul expects 2-D tensors");
    TORCH_CHECK(a.device() == b.device(), "a and b must be on the same device");
    int m = a.size(0);
    int n = b.size(1);
    int k = a.size(1);
    TORCH_CHECK(k==b.size(0), "Size mismatch");

    auto output = torch::empty({m, n}, a.options());
    if (m == 0 || n == 0) return output;  // empty result; a zero-block grid is not a valid launch

    dim3 tpb{TILE_SIZE, TILE_SIZE};
    dim3 blocks{cdiv(n, tpb.x), cdiv(m, tpb.y)};
    tiled_matmul_kernel<<<blocks, tpb>>>(
        output.data_ptr<float>(), a.data_ptr<float>(), b.data_ptr<float>(), m, n, k);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

"""
# Declaration of the function defined in cuda_src; parameter names match the
# definition (a, b) so they are not confused with the dimensions m and n.
cpp_src = """
torch::Tensor tiled_matmul(torch::Tensor const& a, torch::Tensor const& b);
"""

In [18]:
# Compile and load the tiled matmul extension. Uses the `load_inline` name
# imported at the top of the notebook, like every other build cell.
tiled_matmul_module = load_inline(
    name="test_ext_tiled_matmul",
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=['tiled_matmul'],
    extra_cuda_cflags=['--ptxas-options=-v'],
)

In [None]:
%timeit tiled_matmul_module.tiled_matmul(a, b)

In [None]:
# None of 500, 200 and 1000 is a multiple of the 16-wide tile, so this input
# exercises the kernel's zero-padding of partial tiles.
aa = torch.randn(500, 200, device="cuda")
bb = torch.randn(200, 1000, device="cuda")

# Largest absolute deviation from PyTorch's own matmul.
tiled_result = tiled_matmul_module.tiled_matmul(aa, bb)
reference = aa @ bb
(tiled_result - reference).abs().max()