### set up

In [11]:
import math,os, sys,torch,re, numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

### cuda setup for paperspace
* pip install --disable-pip-version-check --root-user-action=ignore wurlitzer ninja

In [12]:
# Lightweight stand-in for CUDA's `dim3`: `x` is required, `y` and `z` default to 1.
dim3 = namedtuple('dim3', ('x', 'y', 'z'), defaults=(1, 1))

In [13]:
d = dim3(2,3)
d

dim3(x=2, y=3, z=1)

In [14]:
d.x,d.y,d.z

(2, 3, 1)

In [15]:
# Two decimals and wide lines for both libraries; torch additionally never
# switches to scientific notation.
np.set_printoptions(linewidth=140, precision=2)
torch.set_printoptions(sci_mode=False, linewidth=140, precision=2)

In [16]:
# Put the parent directory first on the import path; presumably that is where
# the `utills` module imported in the next cell lives — TODO confirm.
sys.path.insert(0, '..')

In [17]:
from utills import show_img,load_cuda, cuda_begin, cdiv

In [18]:
%load_ext wurlitzer

In [19]:
torch.manual_seed(42);



In [58]:
# Full-size operands: (5120 x 256) @ (256 x 5120).
m1 = torch.randn(5120, 256)
m2 = torch.randn(256, 5120)
# Small views (first 4 rows of m1, first 4 columns of m2) used by the
# pure-Python checks further down.
m1s, m2s = m1[:4], m2[:, :4]

#### Python 2D

In [21]:
def blk_kernel2d(f, blocks, threads, *args):
    """Sequentially emulate a 2D CUDA launch of the kernel body `f`.

    `f` is called once per (block, thread) pair as
    `f(blockIdx, threadIdx, threads, *args)`, where both indices are `dim3`
    values. Blocks are visited with y as the outer and x as the inner loop,
    and the threads of each block are visited in the same order.

    Args:
        f: kernel body to call.
        blocks: grid size; only `.x` and `.y` are read.
        threads: block size; only `.x` and `.y` are read here, and the object
            itself is forwarded to `f` as its block dimension.
        *args: forwarded unchanged to every call of `f`.
    """
    grid = ((bx, by) for by in range(blocks.y) for bx in range(blocks.x))
    for bx, by in grid:
        block_idx = dim3(bx, by)
        for ty in range(threads.y):
            for tx in range(threads.x):
                f(block_idx, dim3(tx, ty), threads, *args)

In [22]:
def matmal_bk(blockIdx, threadIdx, blockDim, m,n, out, h,w,k):
    """Kernel body of the naive matmul: one call fills one element of `out`.

    The output coordinate is derived from the block and thread indices the
    same way a CUDA kernel would (`blockIdx * blockDim + threadIdx`).

    Args:
        blockIdx, threadIdx, blockDim: objects exposing `.x` and `.y`.
        m: flat row-major `h x k` left operand.
        n: flat row-major `k x w` right operand.
        out: flat row-major `h x w` result; element `row*w + col` is overwritten.
        h, w, k: matrix dimensions.
    """
    row = threadIdx.y + blockIdx.y*blockDim.y
    col = threadIdx.x + blockIdx.x*blockDim.x

    # Threads that fall outside the h x w output do nothing.
    if row < h and col < w:
        # Left-to-right accumulation starting from the float 0.
        out[row*w + col] = sum((m[row*k + i] * n[i*w + col] for i in range(k)), 0.)

In [23]:
def matmal2d(m,n):
    """Multiply two 2D tensors with the emulated naive kernel `matmal_bk`.

    Args:
        m: tensor of shape (h, k).
        n: tensor of shape (k, w).

    Returns:
        A new (h, w) tensor with `m`'s dtype.

    Raises:
        AssertionError: if the inner dimensions of `m` and `n` differ.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k2==k, 'Size mismatch !'
    # Fixed 16x16 thread blocks; the grid is rounded up so it covers the whole output.
    threads = dim3(16,16)
    grid = dim3(cdiv(w, threads.x), cdiv(h, threads.y))
    result = torch.zeros(h, w, dtype=m.dtype)
    # The kernel writes through `result.flatten()`; this relies on flatten()
    # returning a view of the freshly allocated (contiguous) tensor.
    blk_kernel2d(matmal_bk, grid, threads,
                 m.flatten(), n.flatten(), result.flatten(), h, w, k)
    return result

In [24]:
# NOTE(review): isclose runs with its default tolerances here. float32 sums
# accumulated in a different order than `@` can differ by more than that on
# entries close to zero, which may be why this prints False — consider passing
# an explicit `atol`. TODO confirm.
torch.isclose(matmal2d(m1s,m2s), m1s@m2s).all()

tensor(False)

In [25]:
# CUDA/C++ source for the naive matmul, kept as a raw string.
#  - `matmal_k` (kernel): each thread computes one element out[r, c] as the dot
#    product of row r of m and column c of n, indexing all three as flat
#    row-major float buffers; threads with r >= h or c >= w return immediately.
#  - `matmul` (host wrapper): checks the inputs (CHECK_INPUT and the
#    inner-dimension TORCH_CHECK), allocates a zeroed h x w result with m's
#    options, and launches the kernel with 16x16 thread blocks on a grid sized
#    with cdiv.
# `cuda_begin` is imported from `utills`; presumably it provides the includes
# plus the CHECK_INPUT and cdiv helpers used below — TODO confirm.
cuda_src = cuda_begin + r'''
__global__ void matmal_k(float* m,float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;
    
    if (r>=h || c>=w) return;
    float o = 0;
    for(int i=0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n){
    CHECK_INPUT(m);CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "SIZE MISMATCH!");
    auto output = torch::zeros({h,w}, m.options());
    
    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmal_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [26]:
# Name of the host-side function defined in `cuda_src`; used below both to
# extract its signature and in the list of names passed to `load_cuda`.
fname = 'matmul'

In [27]:
def get_sig(fname,src):
    """Return the declaration of `fname` found in `src`, or None if absent.

    Searches `src` (in MULTILINE mode) for the first place where some text and
    whitespace are followed by `fname`, takes everything up to the end of that
    line minus trailing whitespace and an optional opening brace, and appends
    ';'. `fname` is interpolated into the pattern as-is (not escaped).
    """
    first = re.search(rf'^(.+\s+{fname}.*?)\s*{{?\s*$', src, re.MULTILINE)
    if first is None: return None
    return first.group(1) + ';'

In [28]:
# Declaration-only source: just the signature of `fname` pulled out of
# `cuda_src`, terminated with ';'.
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [29]:
# `load_cuda` comes from `utills`; presumably it compiles `cuda_src`/`cpp_src`
# into an extension module exposing the listed function names — TODO confirm.
module = load_cuda(cuda_src, cpp_src, [fname])

In [30]:
# CUDA copies of the full-size operands. `.contiguous()` comes first because the
# kernel indexes its inputs as flat row-major buffers (m[r*k+i], n[i*w+c]).
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()



In [31]:
module.matmul(m1c,m2c).shape

torch.Size([5120, 5120])

In [32]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [33]:
%%timeit -n 10
module.matmul(m1c,m2c)
torch.cuda.synchronize()

33.3 ms ± 301 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [43]:
def blk_kernel2D_shrd(f,blocks, threads,sh_sz,*args, **kwargs):
    """Emulate a 2D block launch where every block gets its own scratch buffer.

    For each block (y outer, x inner) a fresh zero-filled 1D tensor of
    `sh_sz` elements is allocated and `f` is called once as
    `f(blockIdx, threads, shared, *args, **kwargs)`. Iterating over the threads
    of the block is left to `f`.

    Args:
        f: per-block function.
        blocks: grid size; only `.x` and `.y` are read.
        threads: forwarded to `f` unchanged as the block dimension.
        sh_sz: number of elements in each block's scratch buffer.
        *args, **kwargs: forwarded unchanged to `f`.
    """
    block_ids = (dim3(bx, by) for by in range(blocks.y) for bx in range(blocks.x))
    for block_idx in block_ids:
        scratch = torch.zeros(sh_sz)
        f(block_idx, threads, scratch, *args, **kwargs)

In [48]:
def tiled_matmal_blk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Emulate one thread block of a tiled matmul, accumulating into `out`.

    The block owns the output tile whose top-left corner is
    (blockIdx.y*blockDim.y, blockIdx.x*blockDim.x). The shared `k` dimension is
    walked in `cdiv(k, tw)` phases; each phase first copies one tile of `m` and
    one tile of `n` into `shared` (zero-padded past the matrix edges) and only
    then adds that phase's partial dot products to `out`.

    Args:
        blockIdx: block coordinates (`.x`, `.y`).
        blockDim: threads per block (`.x`, `.y`); the tile indexing below
            uses `tw` as the row stride, so both are expected to equal `tw`.
        shared: flat scratch buffer of at least `2*tw*tw` elements; the first
            `tw*tw` hold the tile of `m`, the rest the tile of `n`.
        m: flat row-major `h x k` left operand.
        n: flat row-major `k x w` right operand.
        out: flat row-major `h x w` result, updated in place with `+=`, so the
            caller is expected to pass it zero-initialised.
        h, w, k: matrix dimensions.
        tw: tile width.
    """
    shar_sz = tw*tw
    ms, ns = shared[:shar_sz], shared[shar_sz:]  # split the scratch buffer into the m tile and the n tile

    for ph in range(cdiv(k, tw)):
        idx = ph*tw  # offset of this phase's tile along the k dimension

        # Phase step 1: every "thread" loads one element of each tile.
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r = blockIdx.y*blockDim.y + tr  # global output row
                c = blockIdx.x*blockDim.x + tc  # global output column

                ms[tr*tw+tc] = m[tc+idx + r*k] if r<h and idx+tc<k else 0.
                ns[tr*tw+tc] = n[(tr+idx)*w + c] if c<w and idx+tr<k else 0.

        # Phase step 2: runs only after BOTH tiles are completely filled
        # (the emulated equivalent of a barrier between load and compute).
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r, c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                # Check row and column separately: a flat-index test would let
                # columns past `w` wrap around into the next row.
                if r>=h or c>=w: continue
                for i in range(tw):
                    out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

In [56]:
def matmul_2d(m,n, tw=16):
    """Multiply two 2D tensors with the emulated tiled kernel `tiled_matmal_blk`.

    Note: a later cell of this notebook defines `matmul_2d` again (driving
    `tiled_matmul2D_blk` instead); once that cell runs it replaces this one.

    Args:
        m: tensor of shape (h, k).
        n: tensor of shape (k, w).
        tw: tile width; blocks are `tw x tw` threads.

    Returns:
        A new (h, w) tensor with `m`'s dtype.

    Raises:
        AssertionError: if the inner dimensions of `m` and `n` differ.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, 'Size mismatch !'
    result = torch.zeros(h, w, dtype=m.dtype)
    threads = dim3(tw, tw)
    grid = dim3(cdiv(w, threads.x), cdiv(h, threads.y))
    shared_elems = tw*tw*2  # one tw x tw tile for each operand
    # The kernel writes through `result.flatten()`; this relies on flatten()
    # returning a view of the freshly allocated (contiguous) tensor.
    blk_kernel2D_shrd(tiled_matmal_blk, grid, threads, shared_elems,
                      m.flatten(), n.flatten(), result.flatten(),
                      h, w, k, tw=tw)
    return result

In [54]:
m1s.shape, m2s.shape

(torch.Size([4, 256]), torch.Size([256, 4]))

In [64]:
torch.isclose(matmul_2d(m1s,m2s,tw=16), m1s@m2s).all()

tensor(False)

In [60]:
def run_threads(f, blockDim, *args, **kwargs):
    """Call `f(row, col, *args, **kwargs)` for every thread position of a block.

    Rows run over `range(blockDim.y)` (outer loop) and columns over
    `range(blockDim.x)` (inner loop), one call at a time.
    """
    positions = ((ty, tx) for ty in range(blockDim.y) for tx in range(blockDim.x))
    for ty, tx in positions: f(ty, tx, *args, **kwargs)

In [62]:
def tiled_matmul2D_blk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Emulate one thread block of a tiled matmul using `run_threads`.

    Same algorithm as a hand-written nested-loop version, but each step is a
    per-thread function that `run_threads` calls for every (tr, tc) of the
    block. The `k` dimension is processed in `ceil(k / tw)` phases; in each
    phase all threads first load the shared tiles, then all threads add their
    partial dot product to `out`.

    Args:
        blockIdx: block coordinates (`.x`, `.y`).
        blockDim: threads per block (`.x`, `.y`); the tile indexing uses `tw`
            as the row stride, so both are expected to equal `tw`.
        shared: flat scratch buffer of at least `2*tw*tw` elements; the first
            `tw*tw` hold the tile of `m`, the rest the tile of `n`.
        m: flat row-major `h x k` left operand.
        n: flat row-major `k x w` right operand.
        out: flat row-major `h x w` result, updated in place with `+=`, so the
            caller is expected to pass it zero-initialised.
        h, w, k: matrix dimensions.
        tw: tile width.
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    def get_rc(tr,tc):
        # Global (row, col) of the output element handled by thread (tr, tc).
        return blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc

    def filled_shareMem(tr, tc, ph):
        # Load one element of each operand tile, zero-padding past the edges.
        r,c = get_rc(tr,tc)
        ms[tr*tw+tc] = m[tc + ph*tw + r*k] if r<h and (ph*tw+tc)<k else 0.
        ns[tr*tw+tc] = n[(tr + ph*tw)*w + c] if c<w and (ph*tw+tr)<k else 0.

    def dotprod_tld(tr,tc):
        r,c = get_rc(tr,tc)
        # Check row and column separately: a flat-index test would let columns
        # past `w` wrap around into the next row.
        if r>=h or c>=w: return
        # Accumulate (+=): every phase contributes a partial sum to the same element.
        for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

    for ph in range(math.ceil(k/tw)):
        run_threads(filled_shareMem, blockDim, ph)
        run_threads(dotprod_tld, blockDim)

In [63]:
def matmul_2d(m,n, tw=16):
    """Multiply two 2D tensors with the emulated tiled kernel `tiled_matmul2D_blk`.

    Note: this re-defines the `matmul_2d` from an earlier cell; the code is the
    same except for the per-block function it drives.

    Args:
        m: tensor of shape (h, k).
        n: tensor of shape (k, w).
        tw: tile width; blocks are `tw x tw` threads.

    Returns:
        A new (h, w) tensor with `m`'s dtype.

    Raises:
        AssertionError: if the inner dimensions of `m` and `n` differ.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, 'Size mismatch !'
    result = torch.zeros(h, w, dtype=m.dtype)
    threads = dim3(tw, tw)
    grid = dim3(cdiv(w, threads.x), cdiv(h, threads.y))
    shared_elems = tw*tw*2  # one tw x tw tile for each operand
    # The kernel writes through `result.flatten()`; this relies on flatten()
    # returning a view of the freshly allocated (contiguous) tensor.
    blk_kernel2D_shrd(tiled_matmul2D_blk, grid, threads, shared_elems,
                      m.flatten(), n.flatten(), result.flatten(),
                      h, w, k, tw=tw)
    return result

In [65]:
torch.isclose(matmul_2d(m1s,m2s,tw=16), m1s@m2s).all()

tensor(False)