In [6]:
# Load cuda kernel
from pathlib import Path

In [7]:
# Utils
import torch
import re
import matplotlib.pyplot as plt
from torch.utils.cpp_extension import load_inline

def get_sig(fname, src):
    """Return a declaration for `fname` extracted from `src`, or None.

    Searches `src` for a definition header shaped like
    `<something> fname(<params>)`, optionally followed by an opening brace
    at the end of the line, and returns that header with `;` appended.
    The parameter list has to fit on one line because `.` in the pattern
    does not cross newlines.

    Args:
        fname: Function name to look for; it is matched literally.
        src: Source text to search.

    Returns:
        The first matching header followed by `;`, or None if nothing matches.
    """
    # Escape the name so regex metacharacters in it are not read as pattern syntax.
    pattern = rf'^(.+\s+{re.escape(fname)}\(.*?\))\s*{{?\s*$'
    found = re.search(pattern, src, re.MULTILINE)
    return found.group(1) + ';' if found else None


def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None):
    """Simple wrapper for torch.utils.cpp_extension.load_inline.

    Args:
        cuda_src: CUDA source text, forwarded as the only entry of `cuda_sources`.
        cpp_src: C++ source text, forwarded as the only entry of `cpp_sources`.
        funcs: Function names forwarded as `functions`.
        opt: Selects the -O3 flag string when true, the -O0 one otherwise.
        verbose: Forwarded unchanged to `load_inline`.
        name: Extension name; falls back to the first entry of `funcs` when None.

    Returns:
        Whatever `load_inline` returns.
    """
    module_name = funcs[0] if name is None else name
    if opt:
        nvcc_flags = "-O3 -Xptxas -O3 -Xcompiler -O3"
    else:
        nvcc_flags = "-O0 -Xptxas -O0 -Xcompiler -O0"
    # The whole flag string is handed over as a single list element.
    return load_inline(
        cuda_sources=[cuda_src],
        cpp_sources=[cpp_src],
        functions=funcs,
        extra_cuda_cflags=[nvcc_flags],
        verbose=verbose,
        name=module_name,
    )

def cdiv(a, b):
    """Int ceiling division of `a` over `b`.

    Adds `b - 1` to `a` before floor-dividing by `b`, so any non-zero
    remainder pushes the quotient up by one.
    """
    padded = a + b - 1
    return padded // b

In [8]:
# Run: read the CUDA source, extract the declaration of the entry point from
# it with get_sig, and build the extension through load_cuda.
cuda_source_path = "./matmul.cu"
fname = "matmul_tiled"
cuda_source = Path(cuda_source_path).read_text()
cpp_source = get_sig(fname, cuda_source)
# get_sig returns None when no header for `fname` is found; stop here with a
# clear message instead of passing None on to the build step.
if cpp_source is None:
    raise ValueError(f"no signature for {fname!r} found in {cuda_source_path}")
module = load_cuda(cuda_source, cpp_source, funcs=[fname])

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [11]:
# Compare the compiled kernel with torch's own matmul on a small non-square case.
torch.manual_seed(0)  # fixed seed so a failing comparison can be reproduced
N, K, M = 16, 36, 32
A = torch.randn(N, K).contiguous().cuda()
B = torch.randn(K, M).contiguous().cuda()
C = getattr(module, fname)(A, B)
C_torch = A@B
# allclose broadcasts its arguments, so check the shape explicitly first.
assert C.shape == C_torch.shape, f"shape mismatch: {tuple(C.shape)} vs {tuple(C_torch.shape)}"
assert torch.allclose(C, C_torch, atol=1e-4), (
    f"max abs diff {(C - C_torch).abs().max().item():.3e}"
)

# TODO
- Assert correctness for different sizes and non-square matrices
- Run profiling and check whether the kernel is memory-bound or compute-bound
- Write a naive matmul
- Plan flash attention
- Write the README