### Setup

In [36]:
import math,os, sys,torch,re, numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

In [6]:
# CUDA-style 3-component index/size. Trailing components default to 1, so
# dim3(2, 3) == dim3(x=2, y=3, z=1).
dim3 = namedtuple('dim3', ('x', 'y', 'z'), defaults=(1, 1))

In [8]:
# Only x and y are given; z falls back to its default of 1.
d = dim3(x=2, y=3)
d

dim3(x=2, y=3, z=1)

In [10]:
(d.x, d.y, d.z)  # fields are accessible by name, like the members of CUDA's dim3

(2, 3, 1)

In [11]:
# Compact, fixed-point display (2 decimals, 140-char lines) for both numpy and torch.
np.set_printoptions(linewidth=140, precision=2)
torch.set_printoptions(sci_mode=False, linewidth=140, precision=2)

In [12]:
# Make the parent directory importable (presumably it holds the `utills` helper module
# imported below). Resolve it to an absolute path and insert it only if it is missing:
# re-running this cell stays idempotent, and a later os.chdir cannot change what it points to.
_parent_dir = os.path.abspath('..')
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)

In [13]:
from utills import show_img,load_cuda, cuda_begin, cdiv

In [14]:
# Load the wurlitzer IPython extension. The intent is presumably to show C-level
# stdout/stderr (e.g. printf output from the compiled CUDA extension) in the notebook.
# TODO confirm.
%load_ext wurlitzer

In [15]:
# Fix torch's RNG seed so the random test matrices below are reproducible.
# The trailing ';' suppresses the returned Generator's repr.
torch.manual_seed(seed=42);



In [17]:
# Full-size operands: (5120, 256) @ (256, 5120) -> (5120, 5120). m1 is drawn first.
m1, m2 = torch.randn(5120, 256), torch.randn(256, 5120)
# Small slices (views) for the slow pure-Python emulation: (4, 256) @ (256, 4) -> (4, 4).
m1s, m2s = m1[:4], m2[:, :4]

#### 2D matmul: pure-Python emulation of a CUDA grid

In [25]:
def blk_kernel2d(f, blocks, threads, *args):
    """Serially emulate a 2D CUDA kernel launch.

    Calls ``f(blockIdx, threadIdx, blockDim, *args)`` once for every (block, thread)
    pair. The grid has ``blocks.y`` x ``blocks.x`` blocks, and each block has
    ``threads.y`` x ``threads.x`` threads. Blocks are visited with y in the outer loop
    and x in the inner loop; threads within a block use the same order. ``blockDim``
    is ``threads`` itself. Only the x/y components of ``blocks``/``threads`` are read.
    """
    block_coords = ((bx, by) for by in range(blocks.y) for bx in range(blocks.x))
    for bx, by in block_coords:
        for ty in range(threads.y):
            for tx in range(threads.x):
                f(dim3(bx, by), dim3(tx, ty), threads, *args)

In [29]:
def matmal_bk(blockIdx, threadIdx, blockDim, m,n, out, h,w,k):
    """Body of one emulated CUDA thread for ``out = m @ n`` on flat row-major buffers.

    The thread's global (row, col) comes from its block and thread indices. A thread
    outside the ``h`` x ``w`` output does nothing. Otherwise it writes the dot product
    of row ``row`` of ``m`` (h x k) and column ``col`` of ``n`` (k x w) into
    ``out[row*w + col]``. Returns None.
    """
    row = blockIdx.y * blockDim.y + threadIdx.y
    col = blockIdx.x * blockDim.x + threadIdx.x
    # The launch grid may be larger than the matrix; edge threads must not write.
    if row < h and col < w:
        acc = 0.
        for i in range(k):
            acc += m[row*k + i] * n[i*w + col]
        out[row*w + col] = acc

In [30]:
def matmal2d(m,n):
    """Compute ``m @ n`` using the pure-Python emulation of the 2D CUDA kernel.

    ``m`` is (h, k) and ``n`` is (k, w). The result is a new (h, w) tensor with ``m``'s
    dtype. ``matmal_bk`` is run once per (block, thread) via ``blk_kernel2d``, over
    16x16-thread blocks.

    Raises AssertionError('Size mismatch !') (via ``assert``) when the inner
    dimensions differ.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k2==k, 'Size mismatch !'
    tpb = dim3(16,16)                               # threads per block (x, y)
    blocks = dim3(cdiv(w, tpb.x), cdiv(h,tpb.y))    # cdiv presumably rounds up to cover all of w x h
    output = torch.zeros(h,w, dtype=m.dtype)
    # `output` is freshly allocated and contiguous, so output.flatten() is a view, and
    # writes made through it land in `output`. The inputs may be flattened into copies,
    # which is fine because they are only read.
    blk_kernel2d(matmal_bk, blocks, tpb, m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output

In [31]:
matmal2d(m1s, m2s).isclose(m1s @ m2s).all()  # element-wise match against torch's matmul

tensor(True)

In [39]:
# CUDA/C++ source for the extension, appended to the `cuda_begin` prelude (from utills).
# - matmal_k: the kernel. Each thread handles one output element (r, c) and writes the
#   dot product of row r of m (h x k) and column c of n (k x w). All three are indexed
#   as flat row-major buffers. Threads outside the h x w output return early.
# - matmul: the host wrapper exported to Python. CHECK_INPUT and cdiv are not defined
#   here; presumably they come from `cuda_begin` (TODO confirm). It launches enough
#   16x16-thread blocks to cover the output and returns a new (h, w) tensor. Data is
#   read via data_ptr<float>(), so inputs are expected to be float32.
# Keep the `matmul` signature on a single line: get_sig parses it with a line-based regex.
cuda_src = cuda_begin + r'''
__global__ void matmal_k(float* m,float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;
    
    if (r>=h || c>=w) return;
    float o = 0;
    for(int i=0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n){
    CHECK_INPUT(m);CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "SIZE MISMATCH!");
    auto output = torch::zeros({h,w}, m.options());
    
    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmal_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [40]:
fname = "matmul"  # name of the host function in cuda_src to export from the extension

In [41]:
def get_sig(fname,src):
    """Return a C++ forward declaration for function ``fname`` defined in ``src``.

    Finds the first line of the form ``<return type> fname(<params>)``, optionally
    followed by ``{``. It returns that text with a trailing ``;``, usable as the
    ``cpp_src`` prototype for the extension. Returns None if no such line exists.

    The name must be followed by ``(``. This keeps other identifiers that share the
    prefix (e.g. a kernel named ``fname_k``) from being mistaken for the target.
    ``fname`` is regex-escaped so it is matched literally.
    """
    pattern = rf'^(.+\s+{re.escape(fname)}\s*\(.*?\))\s*{{?\s*$'
    res = re.findall(pattern, src, re.MULTILINE)
    return res[0]+';' if res else None

In [42]:
# Derive the C++ prototype for the exported function from the CUDA source. Fail loudly
# here, rather than somewhere inside load_cuda, if the signature cannot be found.
cpp_src = get_sig(fname, cuda_src)
assert cpp_src is not None, f'No signature for {fname!r} found in cuda_src'
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [43]:
# Compile cuda_src (with the cpp_src prototype) into an extension that exposes `matmul`
# (used below as module.matmul). load_cuda comes from the local utills module; the
# argument meanings are inferred from this call. TODO confirm.
module = load_cuda(cuda_src, cpp_src, [fname])



In [44]:
# Contiguous GPU copies of the full-size operands; the kernel indexes them as flat buffers.
m1c, m2c = (t.contiguous().cuda() for t in (m1, m2))



In [50]:
module.matmul(m1c, m2c).size()  # expected torch.Size([5120, 5120])

torch.Size([5120, 5120])

In [51]:
module.matmul(m1c, m2c).isclose(m1c @ m2c).all()  # CUDA kernel vs. torch's matmul

tensor(True, device='cuda:0')

In [52]:
%%timeit -n 10
# CUDA kernel launches are asynchronous, so synchronize to make each timed loop
# include the kernel's completion and not just its launch.
module.matmul(m1c,m2c)
torch.cuda.synchronize()

33.3 ms ± 95.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
