## Setup

In [None]:
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

In [None]:
# Minimal stand-in for CUDA's `dim3`: `x` is required, while `y` and `z`
# fall back to 0 (the two defaults apply to the last two fields).
dim3 = namedtuple('dim3', ['x','y','z'], defaults=(0,0))

In [None]:
d = dim3(2,3)
d

dim3(x=2, y=3, z=0)

In [None]:
d.x,d.y

(2, 3)

In [None]:
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [None]:
sys.path.insert(0, '..')

In [None]:
from utils import show_img,load_cuda,cuda_begin,cdiv

In [None]:
%load_ext wurlitzer

In [None]:
# Make CUDA kernel launches synchronous so an error is reported at the call that caused it.
os.environ['CUDA_LAUNCH_BLOCKING']='1'
# Fixed seed so the random matrices below are reproducible; the trailing `;` hides the cell output.
torch.manual_seed(42);

In [None]:
# Left operand: 50,000 x 784 random matrix.
m1 = torch.rand(50_000, 784)
# First 8 rows of m1 -- a small slice for the pure-Python kernels in this notebook.
m1s = m1[:8]
# Right operand: 784 x 10 random matrix.
m2 = torch.rand(784,10)

## Reminder

### 2d Python kernel

In [None]:
def blk_kernel2d(f, blocks, threads, *args):
    "Sequentially call `f(blockIdx, threadIdx, blockDim, *args)` for every thread of every block in a 2d grid."
    for by in range(blocks.y):
        for bx in range(blocks.x):
            blockIdx = dim3(bx, by)
            for ty in range(threads.y):
                for tx in range(threads.x):
                    f(blockIdx, dim3(tx, ty), threads, *args)

In [None]:
def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    "Compute one element of `out = m @ n` for flattened row-major `m` (h x k) and `n` (k x w)."
    row = blockIdx.y*blockDim.y + threadIdx.y
    col = blockIdx.x*blockDim.x + threadIdx.x

    # Threads that fall outside the output matrix do nothing.
    if row < h and col < w:
        acc = 0.
        for i in range(k): acc += m[row*k+i] * n[i*w+col]
        out[row*w+col] = acc

In [None]:
def matmul_2d(m, n):
    "Multiply `m` (h x k) by `n` (k x w) by simulating a 2d launch of `matmul_bk` with 16x16 thread blocks."
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(16,16)
    # Enough blocks to cover the whole h x w output.
    blocks = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # The kernel writes through the flattened `output`, so the results show up in `output` itself.
    blk_kernel2d(matmul_bk, blocks, tpb, m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output

In [None]:
torch.isclose(matmul_2d(m1s, m2), m1s@m2).all()

tensor(True)

### CUDA

In [None]:
# CUDA source for the naive matmul, appended to the shared `cuda_begin` preamble (imported from utils).
# - `matmul_k`: each thread computes one output element (r, c) as a k-long dot product;
#   threads with r>=h or c>=w return immediately.
# - `matmul`: host wrapper that checks the inner dimensions match, allocates a zeroed
#   (h, w) output with `m`'s options, launches the kernel with 16x16 threads per block
#   on a grid of cdiv(w,16) x cdiv(h,16) blocks, then checks the launch for errors.
# NOTE(review): `CHECK_INPUT` and `cdiv` are presumably provided by `cuda_begin` -- not visible here.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul'

In [None]:
def get_sig(fname, src):
    "Return the first line-anchored signature of function `fname` found in `src`, terminated with `;`, or `None` if there is none."
    found = re.search(rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$', src, re.MULTILINE)
    if found is None: return None
    return found.group(1) + ';'

In [None]:
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [None]:
module = load_cuda(cuda_src, cpp_src, [fname])

In [None]:
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()

In [None]:
module.matmul(m1c,m2c).shape

torch.Size([50000, 10])

In [None]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

## Shared mem

### Python

In [None]:
a = torch.zeros(5)
b,c = a[:3],a[3:]

In [None]:
b[1] = 2
c[0] = 6
a

tensor([0., 2., 0., 6., 0.])

In [None]:
def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
    "Call `f(blockIdx, blockDim, shared, ...)` once per block, giving each block its own zeroed `sh_sz`-element shared tensor."
    for by in range(blocks.y):
        for bx in range(blocks.x):
            # A fresh buffer per block: shared memory is never visible across blocks.
            f(dim3(bx,by), threads, torch.zeros(sh_sz), *args, **kwargs)

In [None]:
def run_threads(f, blockDim, *args, **kwargs):
    "Call `f(ty, tx, *args, **kwargs)` for every thread position in a block, row by row."
    for ty in range(blockDim.y):
        for tx in range(blockDim.x):
            f(ty, tx, *args, **kwargs)

In [None]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Simulate one thread block of a tiled matmul, accumulating into `out`.

    `m` (h x k), `n` (k x w) and `out` (h x w) are flattened row-major.
    `shared` holds two `tw*tw` tiles: the first for `m`, the second for `n`.
    `out` is accumulated into with `+=`, so it must start zeroed.
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    for ph in range(int(math.ceil(k/tw))):
        idx = ph*tw
        # Phase 1: every thread copies one element of each tile into shared memory,
        # zero-padding positions that fall outside the matrices.
        for ty in range(blockDim.y):
            for tx in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + ty, blockIdx.x*blockDim.x + tx
                ms[ty*tw+tx] = m[ tx+idx + r*k] if r<h and idx+tx<k else 0.
                ns[ty*tw+tx] = n[(ty+idx)*w +c] if c<w and idx+ty<k else 0.

        # Phase 2: every in-range thread adds this tile's partial dot product to its output.
        for ty in range(blockDim.y):
            for tx in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + ty, blockIdx.x*blockDim.x + tx
                # Guard on the 2d position, not the flat index: with c>=w the flat index
                # r*w+c aliases an element of a later row and would be written by the wrong thread.
                if r>=h or c>=w: continue
                for i in range(tw): out[r*w+c] += ms[ty*tw+i] * ns[tw*i+tx]

In [None]:
def matmul_2d(m, n, tw=16):
    "Multiply `m` (h x k) by `n` (k x w) with the simulated tiled kernel, using `tw` x `tw` tiles and thread blocks."
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    blocks = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # Shared memory holds two tw x tw tiles: one for m, one for n.
    sh_sz = tw*tw*2
    blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, sh_sz,
                      m.flatten(), n.flatten(), output.flatten(), h, w, k, tw=tw)
    return output

In [None]:
m1s.shape, m2.shape

(torch.Size([8, 784]), torch.Size([784, 10]))

In [None]:
torch.isclose(matmul_2d(m1s, m2, tw=16), m1s@m2).all()

tensor(True)

### Python run_threads

In [None]:
def run_threads(f, blockDim, *args, **kwargs):
    "Simulate the threads of one block: call `f(ty, tx, *args, **kwargs)` for each position, rows outermost."
    for ty in range(blockDim.y):
        for tx in range(blockDim.x):
            f(ty, tx, *args, **kwargs)

In [None]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Simulate one thread block of a tiled matmul, with each phase run via `run_threads`.

    `m` (h x k), `n` (k x w) and `out` (h x w) are flattened row-major.
    `shared` holds two `tw*tw` tiles: the first for `m`, the second for `n`.
    `out` is accumulated into with `+=`, so it must start zeroed.
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    def get_rc(ty, tx): return blockIdx.y*blockDim.y + ty, blockIdx.x*blockDim.x + tx

    def fill_shared_tk(ty, tx, ph):
        # Copy one element of each tile into shared memory, zero-padding out-of-range positions.
        r,c = get_rc(ty, tx)
        ms[ty*tw+tx] = m[ tx + ph*tw + r*k] if r<h and (ph*tw+tx)<k else 0.
        ns[ty*tw+tx] = n[(ty + ph*tw)*w +c] if c<w and (ph*tw+ty)<k else 0.

    def dotprod_tk(ty, tx):
        r,c = get_rc(ty, tx)
        # Guard on the 2d position, not the flat index: with c>=w the flat index
        # r*w+c aliases an element of a later row and would be written by the wrong thread.
        if r>=h or c>=w: return
        for i in range(tw): out[r*w+c] += ms[ty*tw+i] * ns[tw*i+tx]

    # All threads finish filling a tile before any of them starts consuming it.
    for ph in range(int(math.ceil(k/tw))):
        run_threads(fill_shared_tk, blockDim, ph)
        run_threads(dotprod_tk, blockDim)

In [None]:
def matmul_2d(m, n, tw=16):
    "Multiply `m` (h x k) by `n` (k x w) with the `run_threads`-based tiled kernel, using `tw` x `tw` tiles."
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    blocks = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # Shared memory holds two tw x tw tiles: one for m, one for n.
    sh_sz = tw*tw*2
    blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, sh_sz,
                      m.flatten(), n.flatten(), output.flatten(), h, w, k, tw=tw)
    return output

In [None]:
m1s.shape, m2.shape

(torch.Size([8, 784]), torch.Size([784, 10]))

In [None]:
torch.isclose(matmul_2d(m1s, m2, tw=16), m1s@m2).all()

tensor(True)

### Python threads

In [None]:
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor

In [None]:
def g(x, sb):
    "Print `x`, `-x` and `x*10`, waiting on barrier `sb` between consecutive prints."
    print(x)
    for val in (-x, x*10):
        sb.wait()
        print(val)

In [None]:
# Run `g` on `num` threads sharing one barrier. The barrier only releases once `num`
# threads are waiting, so the pool needs at least `num` workers.
num = 3
sb = Barrier(num)
# `list(...)` drains the map so any exception raised in a worker is re-raised here.
with ThreadPoolExecutor(num) as ex: list(ex.map(lambda i: g(i,sb), range(1,num+1)))

1
2
3
-3
-1
-2
10
20
30


In [None]:
def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
    "Run each block in turn, with one real Python thread per simulated CUDA thread, a per-block shared tensor and a per-block barrier."
    for by in range(blocks.y):
        for bx in range(blocks.x):
            shar = torch.zeros(sh_sz)
            # One barrier party per thread in the block, mirroring `__syncthreads()`.
            syncb = Barrier(tpb.y*tpb.x)
            threads = []
            for ty in range(tpb.y):
                for tx in range(tpb.x):
                    threads.append(Thread(target=f, kwargs=kwargs,
                                          args=(dim3(bx,by), dim3(tx,ty), tpb, shar, syncb, *args)))
            for t in threads: t.start()
            # Wait for the whole block to finish before moving on to the next one.
            for t in threads: t.join()

In [None]:
def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
    "Per-thread tiled matmul body: load one element of each tile, sync, accumulate, sync; finally write one output element."
    tx,ty = threadIdx.x,threadIdx.y
    row = blockIdx.y*blockDim.y + ty
    col = blockIdx.x*blockDim.x + tx

    # `shared` is split into the m-tile followed by the n-tile.
    n_tile = tw*tw
    ms,ns = shared[:n_tile],shared[n_tile:]
    slot = ty*tw+tx

    acc = 0.
    for ph in range(cdiv(k,tw)):
        off = ph*tw
        ms[slot] = m[ tx + off + row*k] if row<h and (off+tx)<k else 0.
        ns[slot] = n[(ty + off)*w +col] if col<w and (off+ty)<k else 0.
        # Every thread must finish loading before any thread reads the tiles...
        syncb.wait()
        for i in range(tw): acc += ms[ty*tw+i] * ns[tw*i+tx]
        # ...and finish reading before the next phase overwrites them.
        syncb.wait()

    if row<h and col<w: out[row*w + col] = acc

In [None]:
def matmul_2d(m, n, tw=16):
    "Multiply `m` (h x k) by `n` (k x w) with the thread-based tiled kernel, using `tw` x `tw` tiles and thread blocks."
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    blocks = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # Shared memory holds two tw x tw tiles: one for m, one for n.
    sh_sz = tw*tw*2
    blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, sh_sz,
                      m.flatten(), n.flatten(), output.flatten(), h, w, k, tw=tw)
    return output

In [None]:
%%time
torch.isclose(matmul_2d(m1s, m2, tw=8), m1s@m2).all()

CPU times: user 5.57 s, sys: 4.91 s, total: 10.5 s
Wall time: 5.03 s


tensor(True)

### CUDA

Code auto-generated by ChatGPT 4, using the following prompt:

> Convert the following python code to CUDA C, keeping formatting and variable names the same where possible. You can remove `blockIdx, threadIdx, blockDim, shared` from the argument list, since they're already provided by CUDA. Change `syncb.wait()` to `__syncthreads__`. Use `extern __shared__ float shared[]` to create the `shared` array. Use the C ternary operator to replace the Python equivalent where appropriate. If the Python code uses any non-standard functions, you can assume the same functions are also available to the translated C code with the same name and signature.

The generated code worked first time, although we did some minor cleanups afterwards (e.g. renaming `shared` to `ms`).

In [None]:
# CUDA source for the tiled matmul kernel, appended to the shared `cuda_begin` preamble.
# - `ms` is dynamically sized shared memory (`extern __shared__`); `ns` starts `tw*tw` floats into it,
#   so the launch must provide at least 2*tw*tw floats.
# - Each phase loads one tile element of `m` and of `n` per thread (0 where out of range),
#   synchronises, accumulates `tw` products into `p`, then synchronises again before the next load.
# - The tile indexing (`ty*tw + tx`) assumes the block is launched with `tw` x `tw` threads.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int r = blockIdx.y * blockDim.y + ty;
    int c = blockIdx.x * blockDim.x + tx;

    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[ty*tw + tx] = r<h && idx+tx<k ? m[ tx+idx + r*k ] : 0.0;
        ns[ty*tw + tx] = c<w && idx+ty<k ? n[(ty+idx)*w + c] : 0.0;
        __syncthreads();
        for (int i = 0; i < tw; ++i) p += ms[ty * tw + i] * ns[tw * i + tx];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

In [None]:
# Host-side launcher for the tiled kernel, appended to the kernel source above.
# It picks the largest square tile the device allows and sizes the dynamic shared memory to match.
cuda_src += r'''
torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    // Query the device the kernel will actually launch on instead of assuming device 0.
    int dev = 0;
    CUDA_ERR(cudaGetDevice(&dev));
    cudaDeviceProp devProp;
    CUDA_ERR(cudaGetDeviceProperties(&devProp, dev));

    // Largest square tile whose thread count fits in one block.
    int TW = static_cast<int>(std::sqrt(static_cast<double>(devProp.maxThreadsPerBlock)));
    // The kernel indexes two TW x TW float tiles, so it needs exactly this much shared memory.
    size_t size = static_cast<size_t>(TW) * TW * 2 * sizeof(float);
    // Fail loudly rather than launching with less shared memory than the kernel will index.
    TORCH_CHECK(size <= devProp.sharedMemPerBlock,
                "Tiles need ", size, " bytes of shared memory, but only ",
                devProp.sharedMemPerBlock, " bytes are available per block");
    printf("Shared per block: %zu bytes; tile width: %d\n", size, TW);

    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    printf("blocks.x: %u blocks.y: %u\n", blocks.x, blocks.y);
    matmul_k<<<blocks,tpb,size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul_dyn'

In [None]:
cpp_src = get_sig(fname, cuda_src)

In [None]:
module = load_cuda(cuda_src, cpp_src, [fname])

In [None]:
torch.isclose(module.matmul_dyn(m1c,m2c), m1c@m2c).all()

Shared per block: 8192 bytes; tile width: 32
blocks.x: 1 blocks.y: 1563


tensor(True, device='cuda:0')

## Numba

In [None]:
from numba import cuda
from numba.cuda import as_cuda_array as ca

In [None]:
@cuda.jit
def matmul_k_numba(m, n, out, tw):
    # Tiled matmul kernel: each thread accumulates out[r, c] of m @ n, staging
    # tw x tw tiles of `m` and `n` in shared memory one phase at a time.
    # The tile indexing (ty*tw+tx) assumes the launch uses tw x tw threads per block.
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tx,ty = tid.x,tid.y
    r,c = cbi.y * cbd.y + ty, cbi.x * cbd.x + tx
    h,k  = m.shape
    k2,w = n.shape
    
    # Size 0 requests dynamic shared memory; its byte size comes from the launch configuration.
    shar = cuda.shared.array(0, dtype=np.float32)
    # First tw*tw floats hold the tile of `m`, the next tw*tw the tile of `n`.
    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]

    p = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        idx = ph*tw
        # Load one element of each tile, zero-padding positions outside the matrices.
        ms[ty*tw+tx] = m[r, tx+idx] if r<h and idx+tx<k else 0.
        ns[ty*tw+tx] = n[ty+idx, c] if c<w and idx+ty<k else 0.
        # All loads must finish before any thread reads the tiles.
        cuda.syncthreads()
        for i in range(tw): p += ms[ty*tw+i] * ns[i*tw+tx]
        # All reads must finish before the next phase overwrites the tiles.
        cuda.syncthreads()
    if r < h and c < w: out[r, c] = p

In [None]:
def matmul_2d_numba(m, n, tw=16):
    "Multiply CUDA tensors `m` (h x k) and `n` (k x w) with the Numba tiled kernel, using `tw` x `tw` thread blocks."
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype).cuda()
    grid = cdiv(w,tw), cdiv(h,tw)
    # Two tw x tw tiles of 4-byte float32 values.
    shared_bytes = 2 * tw * tw * 4
    # Launch configuration: grid, block, stream (0), dynamic shared memory in bytes.
    matmul_k_numba[grid, (tw,tw), 0, shared_bytes](ca(m), ca(n), ca(out), tw)
    return out

In [None]:
torch.isclose(matmul_2d_numba(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')