## Setup

In [1]:
!pip install ninja



In [2]:
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

In [3]:
import torch
import matplotlib.pyplot as plt
from torch.utils.cpp_extension import load_inline

def show_img(x, figsize=(4,3), **kwargs):
    "Plot image tensor `x` (HW or CHW layout) with the axes hidden; `kwargs` are forwarded to `plt.imshow`"
    plt.figure(figsize=figsize)
    plt.axis('off')
    is_chw = len(x.shape)==3
    # imshow wants channels last, so a 3d input is reordered CHW -> HWC
    img = x.permute(1,2,0) if is_chw else x
    plt.imshow(img.cpu(), **kwargs)

# Shared C++/CUDA preamble that is prepended to every kernel source string in this notebook:
# - CHECK_CUDA / CHECK_CONTIGUOUS / CHECK_INPUT: TORCH_CHECK-based validation of tensor arguments
# - CUDA_ERR / gpuAssert: print the CUDA error string with file and line to stderr, then exit (by default)
# - cdiv: ceiling division on unsigned ints, callable from both host and device code
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CUDA_ERR(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess) 
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b;}
'''

def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False, name=None):
    "Build `cuda_src`/`cpp_src` with `load_inline`, exposing `funcs`; the module `name` defaults to the first function"
    module_name = funcs[0] if name is None else name
    cuda_flags = []
    if opt: cuda_flags.append("-O2")
    return load_inline(name=module_name, cpp_sources=[cpp_src], cuda_sources=[cuda_src],
                       functions=funcs, extra_cuda_cflags=cuda_flags, verbose=verbose)

def cdiv(a,b):
    "Ceiling of `a/b`, computed with integer arithmetic only"
    # adding b-1 before the floor division rounds any non-zero remainder up
    bumped = a + b - 1
    return bumped // b


In [4]:
# Python stand-in for CUDA's `dim3`: `x` is required, while `y` and `z` default to 1
dim3 = namedtuple('dim3', ['x','y','z'], defaults=(1,1))

In [5]:
d = dim3(2,3)
d

dim3(x=2, y=3, z=1)

In [6]:
d.x,d.y

(2, 3)

In [7]:
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [8]:
sys.path.insert(0, '..')

In [9]:
%load_ext wurlitzer

In [10]:
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
torch.manual_seed(42);

In [11]:
# Random test matrices, plus small slices that are used below as inputs to the pure-Python kernels
m1 = torch.rand(5120, 256)
m1s = m1[:4]  # first 4 rows of m1 -> shape (4, 256)
m2 = torch.rand(256,5120)
m2s = m2[:,:4]  # first 4 columns of m2 -> shape (256, 4)

## Reminder

### 2d Python kernel

In [12]:
def blk_kernel2d(f, blocks, threads, *args):
    "Simulate a 2d CUDA launch serially: call `f(blockIdx, threadIdx, blockDim, *args)` for every block/thread pair"
    for by in range(blocks.y):
        for bx in range(blocks.x):
            block_id = dim3(bx,by)
            for ty in range(threads.y):
                for tx in range(threads.x):
                    f(block_id, dim3(tx,ty), threads, *args)

In [13]:
def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    "Kernel body for one thread: write the dot product of row `r` of `m` (h x k) and column `c` of `n` (k x w) to `out`"
    # All three matrices are flat, row-major sequences
    row = blockIdx.y*blockDim.y + threadIdx.y
    col = blockIdx.x*blockDim.x + threadIdx.x

    # Threads that fall outside the h x w output do nothing
    if row<h and col<w:
        acc = 0.
        for i in range(k): acc += m[row*k+i] * n[i*w+col]
        out[row*w+col] = acc

In [14]:
def matmul_2d(m, n):
    "Multiply 2d tensors `m` @ `n` by running `matmul_bk` over a simulated grid of 16x16-thread blocks"
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(16,16)
    # x spans the output columns (w) and y the output rows (h)
    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # `output` is contiguous, so its flattened form shares storage and the kernel's writes land in `output`
    flat_m,flat_n,flat_out = m.flatten(),n.flatten(),output.flatten()
    blk_kernel2d(matmul_bk, grid, tpb, flat_m, flat_n, flat_out, h, w, k)
    return output

In [15]:
torch.isclose(matmul_2d(m1s, m2s), m1s@m2s).all()

tensor(True)

### CUDA

In [16]:
# Naive CUDA matmul, a direct port of the Python `matmul_bk`/`matmul_2d` pair above:
# - `matmul_k`: each thread computes one `out[r, c]`, reading `m` and `n` straight through their pointers
#   (no shared memory); threads outside the h x w output return immediately.
# - `matmul`: host wrapper that validates the inputs, allocates a zeroed h x w output with `m`'s options,
#   and launches enough 16x16-thread blocks to cover the output.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [17]:
fname = 'matmul'

In [18]:
def get_sig(fname, src):
    "Return the declaration of function `fname` found in `src` with a trailing `;`, or None if there is no match"
    # Matches a whole line ending in `<anything> fname(...)`, optionally followed by the opening brace
    sig_re = re.compile(rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$', re.MULTILINE)
    hit = sig_re.search(src)
    if hit is None: return None
    return hit.group(1)+';'

In [19]:
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [20]:
module = load_cuda(cuda_src, cpp_src, [fname])

In [21]:
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()

In [22]:
module.matmul(m1c,m2c).shape

torch.Size([5120, 5120])

In [23]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [24]:
%%timeit -n 10
module.matmul(m1c,m2c)
torch.cuda.synchronize()

28.8 ms ± 3.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


When I removed the call to the kernel itself, it took around 50 µs (0.05 ms) to run, so that's the overhead of the call on my machine.

## Shared mem

### Python

In [25]:
a = torch.zeros(5)
b,c = a[:3],a[3:]

In [26]:
b[1] = 2
c[0] = 6
a

tensor([0., 2., 0., 6., 0.])

In [27]:
def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
    "Call `f` once per block of a 2d grid, giving every block its own zeroed shared buffer of `sh_sz` elements"
    for by in range(blocks.y):
        for bx in range(blocks.x):
            block_id = dim3(bx,by)
            # a fresh buffer per block: blocks never see each other's shared memory
            block_shared = torch.zeros(sh_sz)
            f(block_id, threads, block_shared, *args, **kwargs)

In [28]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    "Tiled matmul for one block: accumulate into flat `out` (h x w) the product of flat `m` (h x k) and `n` (k x w)"
    # `shared` holds two tw x tw tiles back to back: `ms` (tile of m) then `ns` (tile of n).
    # Slicing gives views, so writing to `ms`/`ns` writes into `shared`.
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    # Walk along the shared dimension `k` one tile-width at a time
    for ph in range(cdiv(k,tw)):
        idx = ph*tw
        # fill shared (entries that fall outside the matrices are zero-padded)
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                ms[tr*tw+tc] = m[ tc+idx + r*k] if r<h and idx+tc<k else 0.
                ns[tr*tw+tc] = n[(tr+idx)*w +c] if c<w and idx+tr<k else 0.

        # do dotprods from shared
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                # Only threads that map to a real output element may write. Checking the flat index alone would
                # let a thread with c>=w accumulate into another row's element, which is only harmless while the
                # padded product is exactly zero (inf*0 would give nan).
                if r>=h or c>=w: continue
                for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

In [29]:
def matmul_2d(m, n, tw=16):
    "Multiply 2d tensors `m` @ `n` with the tiled kernel `matmul_tiled_bk`, using `tw` x `tw` tiles and blocks"
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    shared_elems = tw*tw*2  # one tw x tw tile each for `m` and `n`
    blk_kernel2d_shar(matmul_tiled_bk, grid, tpb, shared_elems,
                      m.flatten(), n.flatten(), output.flatten(),
                      h, w, k, tw=tw)
    return output

In [30]:
m1s.shape, m2.shape

(torch.Size([4, 256]), torch.Size([256, 5120]))

In [31]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

### Python run_threads

In [32]:
def run_threads(f, blockDim, *args, **kwargs):
    "Serially call `f(tr, tc, *args, **kwargs)` for every thread position of a block, row by row"
    n_rows,n_cols = blockDim.y,blockDim.x
    for tid in range(n_rows*n_cols):
        # unravel the flat thread id into (row, col) in row-major order
        tr,tc = divmod(tid, n_cols)
        f(tr, tc, *args, **kwargs)

In [33]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    "Tiled matmul for one block, written as two per-thread steps that `run_threads` runs over the whole block"
    # `shared` holds two tw x tw tiles back to back: `ms` (tile of m) then `ns` (tile of n)
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    # Global (row, col) of the output element owned by thread (tr, tc) of this block
    def get_rc(tr, tc): return blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc

    # Step 1: each thread loads one element of each tile for phase `ph`, zero-padding outside the matrices
    def fill_shared_tk(tr, tc, ph):
        r,c = get_rc(tr, tc)
        ms[tr*tw+tc] = m[ tc + ph*tw + r*k] if r<h and (ph*tw+tc)<k else 0.
        ns[tr*tw+tc] = n[(tr + ph*tw)*w +c] if c<w and (ph*tw+tr)<k else 0.

    # Step 2: each thread adds this phase's partial dot product to its own output element
    def dotprod_tk(tr, tc):
        r,c = get_rc(tr, tc)
        # Only threads that map to a real output element may write. Checking the flat index alone would let a
        # thread with c>=w accumulate into another row's element, which is only harmless while the padded
        # product is exactly zero (inf*0 would give nan).
        if r>=h or c>=w: return
        for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

    # All threads finish filling before any thread starts its dot product, because the two passes run in sequence
    for ph in range(int(math.ceil(k/tw))):
        run_threads(fill_shared_tk, blockDim, ph)
        run_threads(dotprod_tk, blockDim)

In [34]:
def matmul_2d(m, n, tw=16):
    "Multiply 2d tensors `m` @ `n` with the tiled kernel `matmul_tiled_bk`, using `tw` x `tw` tiles and blocks"
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    shared_elems = tw*tw*2  # one tw x tw tile each for `m` and `n`
    blk_kernel2d_shar(matmul_tiled_bk, grid, tpb, shared_elems,
                      m.flatten(), n.flatten(), output.flatten(),
                      h, w, k, tw=tw)
    return output

In [35]:
m1s.shape, m2s.shape

(torch.Size([4, 256]), torch.Size([256, 4]))

In [36]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

### Python threads

In [37]:
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor

In [38]:
def g(x, sb):
    "Demo worker: print `x`, then `-x`, then `x*10`, waiting on barrier `sb` between consecutive prints"
    stages = (lambda: x, lambda: -x, lambda: x*10)
    for i,stage in enumerate(stages):
        if i: sb.wait()  # rendezvous with the other workers before moving on to the next print
        print(stage())

In [39]:
num = 3
sb = Barrier(num)
with ThreadPoolExecutor(num) as ex: list(ex.map(lambda i: g(i,sb), range(1,num+1)))

1
2
3
-3
-2
-1
30
10
20


In [40]:
def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
    "Run `f` on one Python thread per simulated CUDA thread, block by block, with a shared buffer and barrier per block"
    for by in range(blocks.y):
        for bx in range(blocks.x):
            shar = torch.zeros(sh_sz)
            # the barrier needs one party per thread in the block, so `wait()` acts like `__syncthreads`
            syncb = Barrier(tpb.y*tpb.x)
            workers = []
            for ty in range(tpb.y):
                for tx in range(tpb.x):
                    call_args = (dim3(bx,by), dim3(tx,ty), tpb, shar, syncb, *args)
                    workers.append(Thread(target=f, args=call_args, kwargs=kwargs))
            # start every thread of the block, then wait for the whole block before moving to the next one
            for wk in workers: wk.start()
            for wk in workers: wk.join()

In [41]:
def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
    "Tiled matmul body for a single thread: cooperatively fill the shared tiles, sync on `syncb`, accumulate one output"
    tr,tc = threadIdx.y,threadIdx.x
    r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc

    # `shared` holds the tile of `m` followed by the tile of `n`, each tw*tw elements
    tile_sz = tw*tw
    ms,ns = shared[:tile_sz],shared[tile_sz:]
    slot = tr*tw+tc  # the element of each tile that this thread is responsible for loading

    p = 0.
    for ph in range(cdiv(k,tw)):
        off = ph*tw
        m_ok = r<h and (off+tc)<k
        n_ok = c<w and (off+tr)<k
        ms[slot] = m[ tc + off + r*k] if m_ok else 0.
        ns[slot] = n[(tr + off)*w +c] if n_ok else 0.
        syncb.wait()  # both tiles must be completely filled before anyone reads them
        for i in range(tw): p += ms[tr*tw+i] * ns[tw*i+tc]
        syncb.wait()  # everyone must finish reading before the next phase overwrites the tiles

    if r<h and c<w: out[r*w + c] = p

In [42]:
def matmul_2d(m, n, tw=16):
    "Multiply 2d tensors `m` @ `n` with the tiled kernel `matmul_tiled_bk`, using `tw` x `tw` tiles and blocks"
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    shared_elems = tw*tw*2  # one tw x tw tile each for `m` and `n`
    blk_kernel2d_shar(matmul_tiled_bk, grid, tpb, shared_elems,
                      m.flatten(), n.flatten(), output.flatten(),
                      h, w, k, tw=tw)
    return output

In [43]:
torch.isclose(matmul_2d(m1s, m2s, tw=8), m1s@m2s).all()

tensor(True)

### CUDA dynamic shared

Code auto-generated by ChatGPT 4, using the following prompt:

> Convert the following python code to CUDA C, keeping formatting and variable names the same where possible. You can remove `blockIdx, threadIdx, blockDim, shared` from the argument list, since they're already provided by CUDA. Change `syncb.wait()` to `__syncthreads`. Use `extern __shared__ float shared[]` to create the `shared` array. Use the C ternary operator to replace the Python equivalent where appropriate. If the Python code uses any non-standard functions, you can assume the same functions are also available to the translated C code with the same name and signature.

The generated code worked first time, although we did some minor cleanups afterwards (e.g. renaming `shared` to `ms`).

In [44]:
# Tiled matmul kernel using dynamically sized shared memory (`extern __shared__`).
# The shared buffer holds two tw x tw tiles back to back: `ms` (tile of m) and `ns` (tile of n, starting at tw*tw).
# Per phase, each thread loads one element of each tile (0 outside the matrix bounds), the block syncs,
# each thread accumulates `tw` products into `p`, and the block syncs again before the tiles are overwritten.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

In [45]:
# Host wrapper for the dynamic-shared-memory kernel, appended to the kernel source.
# It launches TW x TW (16 x 16) thread blocks and passes `size` bytes (two TW x TW float tiles) as the third
# launch parameter, which sets the size of the kernel's `extern __shared__` array.
# The commented-out C++ block is an unused alternative that would derive the sizes from the device properties.
cuda_src += r'''
torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    /*
    cudaDeviceProp devProp;
    CUDA_ERR(cudaGetDeviceProperties(&devProp, 0));
    int maxThreads = devProp.maxThreadsPerBlock;
    size_t requiredSize = static_cast<size_t>(maxThreads) * 2 * sizeof(float);
    size_t size = min(devProp.sharedMemPerBlock, requiredSize);
    int TW = std::sqrt(maxThreads);
    */
    int TW = 16;
    size_t size = TW*TW * 2 * sizeof(float);

    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks,tpb,size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [46]:
fname = 'matmul_dyn'

In [47]:
cpp_src = get_sig(fname, cuda_src)

In [48]:
module_matmul_dyn = load_cuda(cuda_src, cpp_src, [fname], opt=True)

In [49]:
torch.isclose(module_matmul_dyn.matmul_dyn(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [50]:
%%timeit -n 10
module_matmul_dyn.matmul_dyn(m1c,m2c)
torch.cuda.synchronize()

25.3 ms ± 2.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### CUDA static shared

In [51]:
# Tiled matmul with statically sized shared memory: the tile width `tw` is a compile-time constant (16),
# so the two tiles are declared as 2d `__shared__` arrays and indexed as `ms[tr][tc]` / `ns[tr][tc]`.
# The launch in `matmul_static` therefore passes no dynamic shared-memory size.
cuda_src = cuda_begin + r'''
constexpr int tw = 16;

__global__ void matmul_ks(float *m, float *n, float *out, int h, int w, int k) {
    __shared__ float ms[tw][tw], ns[tw][tw];
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    float p=0.0f;
    for (int ph=0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr][tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr][tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr][i] * ns[i][tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}

torch::Tensor matmul_static(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    dim3 tpb(tw,tw);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_ks<<<blocks,tpb>>>(m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [52]:
fname = 'matmul_static'
cpp_src = get_sig(fname, cuda_src)
matmul_static = load_cuda(cuda_src, cpp_src, [fname])
torch.isclose(matmul_static.matmul_static(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [53]:
%%timeit -n 10
matmul_static.matmul_static(m1c,m2c)
torch.cuda.synchronize()

20.4 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Static shared with linear index

The code below is adapted from the dynamic shared memory version, with two changes:
* Change from `extern __shared__ float ms[];` to `__shared__ float ms[16*16*2];`
* Change from 
```    
matmul_k<<<blocks,tpb,size>>>(
m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
```
to 
```
matmul_k<<<blocks,tpb>>>(
m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
```


In [54]:
# Same kernel as the dynamic shared memory version, except that the shared buffer is a statically sized 1d array.
# NOTE: the buffer is fixed at 16*16*2 floats while `tw` is still a runtime argument; the kernel indexes up to
# 2*tw*tw - 1, so it is only valid when called with tw <= 16.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    __shared__ float ms[16*16*2];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

In [55]:
# Host wrapper for the static, linear-index shared memory kernel, appended to the kernel source.
# The launch passes no dynamic shared-memory size, so the byte-count local that the dynamic version computed
# (and the commented-out device-property sizing code) is dropped here rather than left unused.
cuda_src += r'''
torch::Tensor matmul_dyn_modified(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    // Tile width, also used as the block edge length
    int TW = 16;

    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks,tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [56]:
fname = 'matmul_dyn_modified'
cpp_src = get_sig(fname, cuda_src)
module_matmul_dyn_modified = load_cuda(cuda_src, cpp_src, [fname])
torch.isclose(module_matmul_dyn_modified.matmul_dyn_modified(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [57]:
%%timeit -n 10
module_matmul_dyn_modified.matmul_dyn_modified(m1c,m2c)
torch.cuda.synchronize()

25.8 ms ± 3.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Experiment Three Variants

In [58]:
m1 = torch.rand(5120, 256)
m2 = torch.rand(256, 5120)


In [59]:
m1c,m2c = m1.clone().contiguous().cuda(),m2.clone().contiguous().cuda()  ## always make a copy for the input of matmul function

### Dynamic shared memory

In [60]:
%%timeit -n 10
module_matmul_dyn.matmul_dyn(m1c,m2c)
torch.cuda.synchronize()

24.5 ms ± 772 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Static shared memory

In [61]:
m1c,m2c = m1.clone().contiguous().cuda(),m2.clone().contiguous().cuda()

In [62]:
%%timeit -n 10
matmul_static.matmul_static(m1c,m2c)
torch.cuda.synchronize()

19.1 ms ± 676 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Static shared memory with linear index [simple change from dynamic shared memory code]

In [63]:
m1c,m2c = m1.clone().contiguous().cuda(),m2.clone().contiguous().cuda()

In [64]:
%%timeit -n 10
module_matmul_dyn_modified.matmul_dyn_modified(m1c,m2c)
torch.cuda.synchronize()

24.5 ms ± 742 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Numba

In [65]:
from numba import cuda
from numba.cuda import as_cuda_array as ca

In [66]:
@cuda.jit
def matmul_k_numba(m, n, out, tw):
    "Numba CUDA tiled matmul kernel: each thread accumulates one `out[r, c]` from `tw`-wide tiles staged in shared memory"
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tc,tr = tid.x,tid.y
    r,c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc
    h,k  = m.shape
    k2,w = n.shape  # k2 is not used in the kernel; the caller `matmul_2d_numba` asserts k==k2

    # Size 0 requests dynamically sized shared memory; the byte count is given in the launch configuration
    shar = cuda.shared.array(0, dtype=np.float32)
    # First tw*tw elements hold the tile of `m`, the next tw*tw hold the tile of `n`
    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]

    p = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        idx = ph*tw
        # Each thread loads one element of each tile, using 0 for positions outside the matrices
        ms[tr*tw+tc] = m[r, tc+idx] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[tr+idx, c] if c<w and idx+tr<k else 0.
        cuda.syncthreads()  # tiles must be fully loaded before they are read
        for i in range(tw): p += ms[tr*tw+i] * ns[i*tw+tc]
        cuda.syncthreads()  # all reads must finish before the next phase overwrites the tiles
    if r < h and c < w: out[r, c] = p

In [67]:
def matmul_2d_numba(m, n, tw=16):
    "Multiply 2d CUDA tensors `m` @ `n` by launching `matmul_k_numba` with `tw` x `tw` thread blocks"
    (h,k),(k2,w) = m.shape,n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    threads_per_block = (tw, tw)
    grid = (cdiv(w,tw), cdiv(h,tw))
    # two tw x tw tiles of 4-byte elements (the kernel's shared array is float32)
    shared_bytes = 2 * tw * tw * 4
    # launch configuration is [grid, block, stream, dynamic shared memory bytes]
    launch = matmul_k_numba[grid, threads_per_block, 0, shared_bytes]
    launch(ca(m), ca(n), ca(out), tw)
    return out

In [68]:
torch.isclose(matmul_2d_numba(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [69]:
%%timeit -n 10
matmul_2d_numba(m1c,m2c)
torch.cuda.synchronize()

56.9 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
