In [1]:
from collections import namedtuple
from math import ceil

import torch
from torch.utils.cpp_extension import load_inline
import numpy as np

from cuda_utils import load_cuda_inline

%load_ext wurlitzer

In [2]:
# Python stand-in named after CUDA's `dim3` (x, y, z): x is required, y and z default to 1.
dim3 = namedtuple(typename="dim3", field_names="x y z", defaults=(1, 1))

## Tiled matrix multiplication
Matrix multiplication optimized by computing the product tile by tile, so that each tile can be staged in fast on-chip shared memory. The tile size is chosen at runtime based on the available CUDA device, but the kernel is precompiled for several candidate tile sizes so that each variant can use statically sized shared memory.

In [3]:
# Seed the global torch RNG so that these and all later random inputs are
# reproducible under Restart & Run All.
torch.manual_seed(0)
# (700 x 400) @ (400 x 500) inputs for the correctness check below.
a = torch.rand(700, 400)
b = torch.rand(400, 500)

In [4]:
# Read the CUDA source that is passed to `load_cuda_inline` in the next cell.
# An explicit encoding avoids depending on the platform's locale default.
with open("tiled_matmul.cu", "r", encoding="utf-8") as f:
    cuda_src = f.read()

In [5]:
# C++ declaration of the entry point (presumably implemented in tiled_matmul.cu).
# The project helper `load_cuda_inline` receives the CUDA source, this
# declaration, and the list of function names to expose on the returned module.
cpp_src = "torch::Tensor tiled_matrix_multiplication(torch::Tensor a, torch::Tensor b);"
exported_names = ["tiled_matrix_multiplication"]
cuda_module = load_cuda_inline(cuda_src, cpp_src, exported_names)

In [6]:
# Move the inputs to the GPU and compare the custom kernel against PyTorch's
# own matmul on the same data. `.contiguous()` presumably matters because the
# custom op expects dense row-major inputs — TODO confirm against the kernel.
a_cuda = a.contiguous().cuda()
b_cuda = b.contiguous().cuda()
out_torch = (a_cuda @ b_cuda).cpu()
out_module = cuda_module.tiled_matrix_multiplication(a_cuda, b_cuda).cpu()
# Fail loudly (instead of merely displaying False) so Run All stops before the
# benchmarks below time an incorrect kernel; report the worst element error.
outputs_match = torch.allclose(out_torch, out_module)
assert outputs_match, f"max abs diff: {(out_torch - out_module).abs().max().item():.3e}"
outputs_match

True

In [7]:
%%timeit
ouput_cuda = cuda_module.tiled_matrix_multiplication(a_cuda, b_cuda).cpu()

859 µs ± 28.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
# Baseline: PyTorch's built-in matmul on the same GPU tensors, timed the same way
# (including the `.cpu()` copy back to host).
%timeit (a_cuda @ b_cuda).cpu()

273 µs ± 5.02 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
# Larger benchmark: a tall (4096 x 256) matrix times a wide (256 x 2048) one.
a, b = torch.rand(4096, 256), torch.rand(256, 2048)
# NumPy views of the CPU tensors, used for the CPU baseline.
a_np, b_np = a.numpy(), b.numpy()
# GPU copies, used by both the custom kernel and PyTorch's matmul.
a_cuda, b_cuda = a.contiguous().cuda(), b.contiguous().cuda()

In [10]:
# CPU baseline: NumPy matmul on the same input data.
%timeit a_np @ b_np

110 ms ± 2.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
# Custom tiled kernel; `.cpu()` copies the result back to host, so that copy is
# included in the measured time.
%timeit cuda_module.tiled_matrix_multiplication(a_cuda, b_cuda).cpu()

23.3 ms ± 2.78 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# PyTorch's built-in GPU matmul on the same inputs, timed the same way (including `.cpu()`).
%timeit (a_cuda @ b_cuda).cpu()

17.5 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Expectation: the custom kernel should perform better on more balanced (closer to square) matrices, because square tiles then require less padding.

In [13]:
# More balanced benchmark: (4096 x 3768) @ (3768 x 4096), much closer to square
# than the previous shapes.
a, b = torch.rand(4096, 3768), torch.rand(3768, 4096)
# NumPy views of the CPU tensors, used for the CPU baseline.
a_np, b_np = a.numpy(), b.numpy()
# GPU copies, used by both the custom kernel and PyTorch's matmul.
a_cuda, b_cuda = a.contiguous().cuda(), b.contiguous().cuda()

In [14]:
# CPU baseline: NumPy matmul on the same input data.
%timeit a_np @ b_np

291 ms ± 18.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
# Custom tiled kernel; `.cpu()` copies the result back to host, so that copy is
# included in the measured time.
%timeit cuda_module.tiled_matrix_multiplication(a_cuda, b_cuda).cpu()

85.8 ms ± 291 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
# PyTorch's built-in GPU matmul on the same inputs, timed the same way (including `.cpu()`).
%timeit (a_cuda @ b_cuda).cpu()

41.4 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
