In [1]:
from collections import namedtuple
from math import ceil

import torch
from torch.utils.cpp_extension import load_inline
import numpy as np

from cuda_utils import load_cuda_inline

%load_ext wurlitzer

In [2]:
# Minimal stand-in for CUDA's dim3. `defaults=(1, 1)` applies to the two rightmost
# fields, so y and z default to 1 while x must always be supplied
# (e.g. dim3(16, 16) has z == 1).
dim3 = namedtuple("dim3", ["x", "y", "z"], defaults=(1, 1))

## Matrix multiplication
### Python

In [3]:
torch.ones((2,3))

tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [4]:
# Small operands: (20 x 50) @ (50 x 70) -> (20 x 70). They are kept tiny because the
# pure-Python kernel emulation defined below runs one Python call per simulated thread.
a = torch.rand(20, 50)
b = torch.rand(50, 70)

In [5]:
def matrix_multiplication_math(blockidx, blockdim, threadidx, a, b, out, h, w, k):
    """Emulate a single CUDA thread of a naive matmul and write one output element.

    The thread's global column and row are derived from its block index, block
    size and thread index. ``a`` and ``b`` are read through flat indices
    (``a`` laid out as h x k, ``b`` as k x w, row by row), while ``out`` is
    written through a 2-D ``[row, col]`` index.

    Threads whose coordinates fall outside the h x w output return without
    touching ``out``.
    """
    col = threadidx.x + blockdim.x * blockidx.x
    row = threadidx.y + blockdim.y * blockidx.y

    # Guard clause: the launch grid is rounded up, so some threads have no element.
    if not (col < w and row < h):
        return

    a_row_start = row * k  # flat offset of the first element of row `row` in `a`
    acc = 0
    for step in range(k):
        # Walk along the row of `a` (stride 1) and down the column of `b` (stride w).
        acc += a[a_row_start + step] * b[step * w + col]
    out[row, col] = acc

def grid_2d_kernel_launch(f, blocks, threads, *args):
    """Sequentially emulate a 2-D CUDA kernel launch.

    Calls ``f(block_index, threads, thread_index, *args)`` once for every
    (block, thread) pair. Blocks are visited row by row (y outer, x inner) and,
    inside each block, threads are visited in the same y-outer / x-inner order.
    Only the x and y extents of ``blocks`` and ``threads`` are iterated.
    """
    block_indices = (
        dim3(bx, by)
        for by in range(blocks.y)
        for bx in range(blocks.x)
    )
    for block_index in block_indices:
        thread_indices = (
            dim3(tx, ty)
            for ty in range(threads.y)
            for tx in range(threads.x)
        )
        for thread_index in thread_indices:
            f(block_index, threads, thread_index, *args)

def matrix_multiplication(a, b):
    """Multiply two 2-D tensors by emulating a CUDA kernel launch in pure Python.

    Args:
        a: tensor of shape (h, k).
        b: tensor of shape (k, w).

    Returns:
        A new (h, w) tensor holding ``a @ b``. It is allocated with the promoted
        dtype of the inputs and on ``a``'s device, so e.g. float64 inputs are not
        silently truncated to the default float32.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
        ValueError: if either input is not 2-D (shape unpacking fails).
    """
    h, k = a.shape
    k2, w = b.shape
    assert k == k2, f"inner dimensions must match: a is {h}x{k}, b is {k2}x{w}"

    # Match the inputs' dtype/device instead of relying on torch's global defaults.
    output = torch.empty(h, w, dtype=torch.result_type(a, b), device=a.device)

    threads_per_block = dim3(16, 16)
    # Round up so partially filled edge blocks are still launched; the kernel
    # bounds-checks each thread. (math.ceil already returns an int.)
    blocks = dim3(ceil(w / threads_per_block.x), ceil(h / threads_per_block.y))

    # The kernel reads its inputs through flat indices, hence the flatten().
    grid_2d_kernel_launch(
        matrix_multiplication_math, blocks, threads_per_block,
        a.flatten(), b.flatten(), output, h, w, k,
    )
    return output

In [6]:
%time out = matrix_multiplication(a, b)

CPU times: user 1.1 s, sys: 0 ns, total: 1.1 s
Wall time: 1.1 s


In [7]:
torch.allclose(out, a @ b)

True

### CUDA

In [8]:
cuda_src = """
// Naive matrix multiplication: each thread computes one element of
// out (h x w) = a (h x k) @ b (k x w). All three buffers are indexed as dense
// row-major arrays.
__global__ void matrix_multiplication_math(float *a, float *b, float *out, int h, int w, int k) {
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    int r = blockIdx.y * blockDim.y + threadIdx.y;
    // The grid is rounded up to whole blocks, so edge threads may lie outside the matrix.
    if (c >= w || r >= h) return;
    float o = 0;
    for (int i = 0; i < k; i++) {
        o += a[r * k + i] * b[c + w * i];
    }
    out[r * w + c] = o;
}

// Host wrapper: validates the inputs, allocates the output and launches the kernel.
torch::Tensor matrix_multiplication(torch::Tensor a, torch::Tensor b) {
    // The kernel dereferences raw device pointers, so reject anything it cannot handle.
    TORCH_CHECK(a.is_cuda() && b.is_cuda(), "matrix_multiplication: inputs must be CUDA tensors");
    TORCH_CHECK(a.device() == b.device(), "matrix_multiplication: inputs must be on the same device");
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "matrix_multiplication: inputs must be 2-D");
    TORCH_CHECK(a.scalar_type() == torch::kFloat32 && b.scalar_type() == torch::kFloat32,
                "matrix_multiplication: inputs must be float32");
    // The kernel's flat indexing assumes a dense row-major layout; this is a no-op
    // for tensors that are already contiguous.
    a = a.contiguous();
    b = b.contiguous();

    int h = a.size(0);
    int k = a.size(1);
    int w = b.size(1);
    TORCH_CHECK(k == b.size(0), "matrix_multiplication: inner dimensions must match");
    auto output = torch::empty({h, w}, a.options());
    // A grid with a zero dimension is an invalid launch configuration; nothing to compute anyway.
    if (h == 0 || w == 0) return output;

    dim3 threads_per_block(16, 16);
    // Integer ceil-division avoids the rounding issues of float ceil() for large sizes.
    dim3 blocks((w + threads_per_block.x - 1) / threads_per_block.x,
                (h + threads_per_block.y - 1) / threads_per_block.y);
    matrix_multiplication_math<<<blocks, threads_per_block>>>(
        a.data_ptr<float>(), b.data_ptr<float>(), output.data_ptr<float>(), h, w, k
    );
    // Kernel launches do not report failures by themselves; surface them as exceptions.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "matrix_multiplication: kernel launch failed: ", cudaGetErrorString(err));
    return output;
}
"""

In [9]:
# C++ declaration of the host wrapper that cuda_src implements. The list names the
# function to expose on the resulting module (it is called below as
# cuda_module.matrix_multiplication).
# NOTE(review): load_cuda_inline comes from cuda_utils and is not visible here;
# presumably it compiles both sources via torch's load_inline — confirm.
cpp_src = "torch::Tensor matrix_multiplication(torch::Tensor a, torch::Tensor b);"
cuda_module = load_cuda_inline(cuda_src, cpp_src, ["matrix_multiplication"])

In [10]:
# Move the operands to the GPU. The CUDA kernel indexes its inputs as flat
# row-major buffers, hence the explicit .contiguous() before the copy.
a_cuda = a.contiguous().cuda()
b_cuda = b.contiguous().cuda()

In [11]:
torch.allclose(a_cuda @ b_cuda, cuda_module.matrix_multiplication(a_cuda, b_cuda))

True

In [12]:
%%timeit
ouput_cuda = cuda_module.matrix_multiplication(a_cuda, b_cuda).cpu()

29.1 µs ± 130 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
%timeit (a_cuda @ b_cuda).cpu()

33.9 µs ± 200 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [14]:
# Larger operands for the benchmark cells below: (4096 x 256) @ (256 x 2048).
# NumPy and CUDA versions of the same data are prepared so every backend is
# timed on identical inputs. Note that this rebinds a, b, a_cuda and b_cuda.
a = torch.rand(4096, 256)
b = torch.rand(256, 2048)
a_np = a.numpy()
b_np = b.numpy()
a_cuda = a.contiguous().cuda()
b_cuda = b.contiguous().cuda()

In [15]:
%timeit a_np @ b_np

109 ms ± 3.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
%timeit cuda_module.matrix_multiplication(a_cuda, b_cuda).cpu()

21.2 ms ± 1.37 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
%timeit (a_cuda @ b_cuda).cpu()

17.7 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
