In [2]:
# Load CUDA kernel: `Path` is used later to read the .cu source file
from pathlib import Path

In [3]:
# Utils
import torch
import re
import matplotlib.pyplot as plt
from torch.utils.cpp_extension import load_inline

def get_sig(fname, src):
    """Return the declaration of `fname` found in the source text `src`.

    Searches `src` for the first occurrence of ``<prefix> <fname>(<params>)``
    that starts at the beginning of a line and is followed only by optional
    whitespace and an optional opening brace up to the end of the line.
    The matched text is returned with a trailing ``;`` appended.

    `fname` is matched literally (it is passed through `re.escape`), so regex
    metacharacters in the name cannot widen or break the match.

    The parameter list has to fit on one line, because `.` does not match
    newlines in this pattern.

    Returns `None` when no such line exists.
    """
    # `{{` is a literal `{` inside the f-string; the brace itself is optional.
    pattern = rf'^(.+\s+{re.escape(fname)}\(.*?\))\s*{{?\s*$'
    match = re.search(pattern, src, re.MULTILINE)
    return match.group(1) + ';' if match else None


def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None):
    """Thin convenience layer over `torch.utils.cpp_extension.load_inline`.

    cuda_src: CUDA source text, forwarded as the single entry of `cuda_sources`.
    cpp_src:  C++ source text, forwarded as the single entry of `cpp_sources`.
    funcs:    forwarded unchanged as `functions`; its first element is used as
              the extension name when `name` is not given.
    opt:      truthy selects the `-O3` flag set, falsy the `-O0` flag set.
    verbose:  forwarded unchanged to `load_inline`.
    name:     extension name; defaults to `funcs[0]`.

    Returns whatever `load_inline` returns.
    """
    ext_name = funcs[0] if name is None else name
    # The whole flag set travels as one space-separated string in a one-element list.
    if opt:
        nvcc_flags = "-O3 -Xptxas -O3 -Xcompiler -O3"
    else:
        nvcc_flags = "-O0 -Xptxas -O0 -Xcompiler -O0"
    return load_inline(
        name=ext_name,
        cpp_sources=[cpp_src],
        cuda_sources=[cuda_src],
        functions=funcs,
        extra_cuda_cflags=[nvcc_flags],
        verbose=verbose,
    )

def cdiv(a,b):
    "Int ceiling division of `a` over `b`: the smallest integer >= a/b, for either sign of `a` or `b`"
    # Floor-dividing the negated numerator and negating again rounds toward
    # +infinity. The earlier `(a+b-1)//b` form is only a ceiling when b > 0.
    return -(-a // b)

In [23]:
# Run: read the CUDA source, derive the C++ declaration of `fname`, build the extension
cuda_source_path = "./matmul.cu"
fname = "matmul_tiled_sqr"
cuda_source = Path(cuda_source_path).read_text()
cpp_source = get_sig(fname, cuda_source)
# get_sig returns None when no matching definition line is found; stop here
# with a clear message instead of passing None on to load_inline.
if cpp_source is None:
    raise ValueError(f"No signature for {fname!r} found in {cuda_source_path}")
module = load_cuda(cuda_source, cpp_source, funcs=[fname])

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [24]:
# A is (N, K) and B is (K, M); all three sizes are 16 for this check
N = K = M = 16
A = torch.randn(N, K).contiguous().cuda()
B = torch.randn(K, M).contiguous().cuda()
# Result of the custom kernel vs. PyTorch's own matmul on the same inputs
C = module.matmul_tiled_sqr(A, B)
C_torch = torch.matmul(A, B)
# Elementwise difference, shown as the cell output
C - C_torch

tensor([[-8.3780e+00,  8.2713e+00,  3.7221e+00,  1.6051e+01,  3.2614e+00,
         -2.8923e+00, -3.4417e+00,  4.1901e+00,  3.2388e+00,  1.6170e+01,
         -8.2670e+00, -5.3273e+00,  1.2837e+01,  6.5150e-01,  1.9224e+00,
         -8.1438e+00],
        [ 7.8391e+00, -1.4732e+00,  2.1935e+00, -1.2336e+01, -6.3537e+00,
          2.0819e+00,  3.4461e+00, -5.5568e+00, -1.0510e+01, -8.7675e+00,
          5.2813e+00, -5.3011e+00, -9.9769e+00,  5.1394e+00, -5.6581e+00,
          8.6979e+00],
        [ 9.9978e+00, -2.9547e+00, -2.1508e+00,  4.6863e+00, -1.7223e+00,
         -2.6635e-01, -3.4982e+00, -3.7122e+00,  3.3484e+00, -7.2521e-01,
          9.5317e+00,  7.8237e+00, -8.3222e+00, -2.4237e+00,  2.2107e+00,
          3.7652e+00],
        [ 3.0819e+00,  1.0950e-01,  8.6123e-01,  6.9057e+00,  4.4520e+00,
          1.2826e+00, -3.9218e+00,  4.7946e-01,  3.7538e+00,  1.9237e+01,
         -9.3564e+00,  9.3188e+00, -4.1060e+00,  1.1693e+00,  1.9593e+00,
         -7.9771e+00],
        [-7.8652e+00

In [17]:
# Recompute the PyTorch reference product and show its elementwise difference from C
# NOTE(review): relies on A, B and C already existing in the kernel from another cell
C_torch = torch.matmul(A, B)
C - C_torch

tensor([[ 5.7537e+00, -2.4631e-03,  4.1649e+00, -4.1259e+00,  4.4670e+00,
         -1.5951e+00, -6.8162e+00,  1.0624e+01,  1.1907e+00, -2.6340e+00,
         -1.2871e+00, -2.4438e+00, -6.2914e+00, -2.7668e+00, -6.1911e+00,
          9.6855e-01],
        [ 2.0255e-01,  3.0809e+00, -1.0944e+00,  4.6841e+00, -3.3999e+00,
         -2.5092e+00, -2.6668e+00, -1.2059e+00, -4.2102e+00, -4.2659e+00,
          2.4964e+00, -3.7057e+00,  1.0011e+01, -1.2592e+00, -1.6613e+00,
         -7.7185e+00],
        [-7.7027e-02,  1.8407e+00,  2.3225e+00, -5.9372e+00, -4.0125e-01,
          6.6849e-01, -1.5090e+00, -2.1927e+00, -3.0730e+00,  2.3212e+00,
         -3.6026e+00,  1.4466e+00, -2.8881e+00,  3.5854e+00, -3.7300e+00,
         -5.6367e-02],
        [-2.2191e-01,  3.3889e+00, -1.2290e+00,  5.8907e+00, -1.6628e+00,
         -9.0154e-01,  2.1356e+00,  2.6677e+00, -9.1268e-03, -2.8330e+00,
          2.8663e+00,  7.4463e-02,  2.1077e+00, -6.6448e+00, -8.1448e-01,
         -5.8828e+00],
        [ 4.3732e+00