### Setup

In [1]:
import math,os, sys,torch,re, numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

### CUDA setup for Paperspace
* pip install --disable-pip-version-check --root-user-action=ignore wurlitzer ninja

In [2]:
# CPU stand-in for CUDA's dim3: three named coordinates, with y and z defaulting to 1.
dim3 = namedtuple('dim3', 'x y z', defaults=(1, 1))

In [3]:
d = dim3(x=2, y=3)  # z omitted -> falls back to its default of 1
d

dim3(x=2, y=3, z=1)

In [4]:
# fields are readable by name; the unspecified z shows its default
d.x,d.y,d.z

(2, 3, 1)

In [5]:
# Compact display: 2 decimals, 140-char lines; torch additionally avoids scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
np.set_printoptions(linewidth=140, precision=2)

In [6]:
# put the parent directory first on the import path (presumably where the local `utills` module lives)
sys.path[:0] = ['..']

In [7]:
from utills import show_img,load_cuda, cuda_begin, cdiv

In [8]:
# wurlitzer's IPython extension forwards C-level stdout/stderr (e.g. printf from compiled code) into cell output
%load_ext wurlitzer

In [9]:
torch.manual_seed(seed=42);  # fixed seed so the random test matrices below are reproducible



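### N.B
- The apparent bug came from using `torch.randn` instead of `torch.rand`. With `randn` data, the Python matmul (and the Python tiled matmul) failed the `torch.isclose` check against `@`, even though both sides compute the same product on the same data. The CUDA kernel passed the same check.
- This is most likely a tolerance issue, not a kernel bug. `randn` values are zero-mean, so many dot products land close to 0. Float32 sums accumulated in a different order can then differ by more than `isclose` allows near 0, because its default `atol` is only `1e-8`.
- With `rand`, all values are positive, so the results are far from 0 and the check passes. Passing an explicit `atol` (e.g. `1e-4`) is a quick way to confirm this.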

In [79]:
# Full-size operands for the CUDA kernel, plus tiny slices (4 rows / 4 cols) for the slow Python simulations.
m1, m2 = torch.rand(5120, 256), torch.rand(256, 5120)
m1s, m2s = m1[:4], m2[:, :4]

#### Python 2D

In [11]:
def blk_kernel2d(f, blocks, threads, *args):
    """Emulate a 2D CUDA launch on the CPU, one call at a time.

    Calls f(blockIdx, threadIdx, blockDim, *args) for every block of `blocks` and
    every thread of `threads`, iterating row-major (y before x) at both levels.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            block_idx = dim3(bx, by)
            for ty in range(threads.y):
                for tx in range(threads.x):
                    f(block_idx, dim3(tx, ty), threads, *args)

In [12]:
def matmal_bk(blockIdx, threadIdx, blockDim, m,n, out, h,w,k):
    """One simulated CUDA thread of a naive matmul.

    Computes the single output element out[row, col] = sum_i m[row, i] * n[i, col]
    on flattened row-major buffers (m: h x k, n: k x w, out: h x w).
    Threads whose (row, col) falls outside the h x w output do nothing.
    """
    row = blockIdx.y * blockDim.y + threadIdx.y
    col = blockIdx.x * blockDim.x + threadIdx.x
    if row < h and col < w:
        acc = 0.
        for i in range(k):
            acc += m[row*k + i] * n[i*w + col]
        out[row*w + col] = acc

In [13]:
def matmal2d(m,n):
    """Naive matmul of 2D tensors `m` (h x k) and `n` (k x w) via the CPU kernel simulator.

    Launches `matmal_bk` over a grid of 16x16 blocks covering the h x w output and
    returns a new (h, w) tensor with `m`'s dtype.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k2==k, 'Size mismatch !'
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(16, 16)
    grid = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    # `output` is freshly allocated and contiguous, so flatten() returns a view and the
    # kernel's writes land in `output`.
    blk_kernel2d(matmal_bk, grid, tpb, m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output

In [14]:
# compare the simulated naive kernel with PyTorch's matmul on the small slices
# NOTE(review): the False below was recorded at In[14], before the data cell In[79]; it likely reflects older (randn) data — rerun to confirm
torch.isclose(matmal2d(m1s,m2s), m1s@m2s).all()

tensor(False)

In [15]:
# Naive CUDA matmul: matmal_k computes one output element (r, c) per thread; out-of-range threads return early.
# matmul() is the host wrapper: validates inputs with CHECK_INPUT (presumably defined in cuda_begin),
# allocates the h x w output and launches matmal_k on 16x16 blocks over a cdiv-sized grid.
cuda_src = cuda_begin + r'''
__global__ void matmal_k(float* m,float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;
    
    if (r>=h || c>=w) return;
    float o = 0;
    for(int i=0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n){
    CHECK_INPUT(m);CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "SIZE MISMATCH!");
    auto output = torch::zeros({h,w}, m.options());
    
    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmal_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [16]:
fname = "matmul"  # host-side function in cuda_src to declare and expose

In [17]:
def get_sig(fname,src):
    """Build a C++ prototype for function `fname` from its definition in `src`.

    Finds the first line consisting of something, whitespace, then `fname` (e.g.
    `torch::Tensor matmul(...){`), drops an optional trailing `{` and whitespace,
    and appends `;`. Returns None when no such line exists.
    """
    pattern = rf'^(.+\s+{fname}.*?)\s*{{?\s*$'
    found = re.search(pattern, src, re.MULTILINE)
    if found is None:
        return None
    return found.group(1) + ';'

In [18]:
cpp_src = get_sig(fname=fname, src=cuda_src)  # prototype needed on the C++ side of the extension
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [19]:
# compile the extension with utills.load_cuda: CUDA source, C++ prototype(s), and the function names to expose
# (presumably built on torch.utils.cpp_extension.load_inline — TODO confirm in utills)
module = load_cuda(cuda_src, cpp_src, [fname])

In [20]:
# the CUDA wrapper's CHECK_INPUT requires contiguous CUDA tensors
m1c, m2c = (t.contiguous().cuda() for t in (m1, m2))



In [21]:
# (5120, 256) @ (256, 5120) -> expect (5120, 5120)
module.matmul(m1c,m2c).shape

torch.Size([5120, 5120])

In [22]:
# validate the CUDA kernel against PyTorch's matmul on the same GPU tensors
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [23]:
%%timeit -n 10
module.matmul(m1c,m2c)
# wait for the GPU so the timing includes the (asynchronously launched) kernel, not just the launch
torch.cuda.synchronize()

33.3 ms ± 242 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [67]:
def blk_kernel2D_shrd(f,blocks, threads,sh_sz,*args, **kwargs):
    """Run a block-level kernel `f` once per block, each with a fresh zeroed scratch buffer.

    Blocks are visited row-major (blocks.y outer, blocks.x inner). `f` is called as
    f(blockIdx, blockDim, shared, *args, **kwargs) and loops over its threads itself.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            f(dim3(bx, by), threads, torch.zeros(sh_sz), *args, **kwargs)

In [68]:
def tiled_matmal_blk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Compute one tile of `out = m @ n` for a single block, simulating its threads with loops.

    Each phase is two complete passes over the block's threads. Pass one fills the
    shared tiles. Pass two uses them, and only starts after every thread has finished
    loading (the role `__syncthreads()` plays on the GPU).

    Args:
        blockIdx: block coordinate with `.x` (column) and `.y` (row).
        blockDim: block size; callers pass (tw, tw).
        shared: 1D scratch buffer of 2*tw*tw elements. The first tw*tw hold the current
            tile of `m`, the rest the tile of `n`.
        m, n, out: flattened row-major buffers of shapes (h, k), (k, w), (h, w).
            `out` is accumulated into (+=), so it must start at zero.
        h, w, k: matrix dimensions.
        tw: tile width.
    """
    shar_sz = tw*tw
    ms, ns = shared[:shar_sz], shared[shar_sz:]  # split the scratch buffer into the m and n tiles

    for ph in range((k + tw - 1) // tw):  # ceil(k / tw) phases along the shared k dimension
        idx = ph*tw  # offset along k where this phase's tiles start

        # pass 1: each thread loads one element of each tile, zero-padding reads outside the matrices
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r = blockIdx.y*blockDim.y + tr
                c = blockIdx.x*blockDim.x + tc
                ms[tr*tw+tc] = m[r*k + idx+tc] if r<h and idx+tc<k else 0.
                ns[tr*tw+tc] = n[(idx+tr)*w + c] if c<w and idx+tr<k else 0.

        # pass 2: only once both tiles are complete, accumulate each thread's partial dot product.
        # Row and column are bounds-checked separately: a flat `r*w+c < len(out)` test also
        # passes for c >= w and would write into row r+1.
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r = blockIdx.y*blockDim.y + tr
                c = blockIdx.x*blockDim.x + tc
                if r<h and c<w:
                    for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

In [69]:
def matmul_2d(m,n, tw=16):
    """Tiled matmul of `m` (h x k) by `n` (k x w) using the single-threaded block simulator.

    Launches a grid of tw x tw blocks. Each block gets a fresh 2*tw*tw scratch buffer
    (one tile of `m` plus one tile of `n`), and `tiled_matmal_blk` runs on it.

    Returns:
        A new (h, w) tensor with `m`'s dtype.
    """
    h,k = m.shape
    k2,w = n.shape
    assert k==k2, 'Size mismatch !'
    output = torch.zeros(h,w, dtype=m.dtype)
    tpb = dim3(tw,tw)
    blocks = dim3(cdiv(w, tpb.x), cdiv(h,tpb.y))
    # tiled_matmal_blk matches blk_kernel2D_shrd's f(blockIdx, blockDim, shared, ...) convention.
    # matmul_tiled_bk takes (blockIdx, threadIdx, blockDim, shared, barrier, ...) and is only
    # defined further down, so passing it here fails on a fresh kernel.
    # `output` is freshly allocated and contiguous, so flatten() returns a view the kernel writes into.
    blk_kernel2D_shrd(tiled_matmal_blk, blocks, tpb, tw*tw*2,
                m.flatten(), n.flatten(), output.flatten(),
                h,w,k, tw=tw)
    return output

In [70]:
# the small slices used by the Python simulations: (4, 256) @ (256, 4)
m1s.shape, m2s.shape

(torch.Size([4, 256]), torch.Size([256, 4]))

In [71]:
# compare the tiled Python simulation with PyTorch's matmul (was `matmul_2d_`, a name that is never defined)
torch.isclose(matmul_2d(m1s,m2s,tw=16), m1s@m2s).all()

tensor(True)

### Python threads refactoring

In [88]:
def run_threads(f, blockDim, *args, **kwargs):
    """Call f(row, col, *args, **kwargs) for every thread of a block, sequentially and row-major.

    The whole pass finishes before this returns, so consecutive calls act like
    phases separated by a `__syncthreads()`.
    """
    for row in range(blockDim.y):
        for col in range(blockDim.x):
            f(row, col, *args, **kwargs)

In [89]:
def tiled_matmul2D_blk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Tiled matmul for one block, with the simulated per-thread work factored into helpers.

    In each phase, every thread first loads one element of the current `m` tile and
    `n` tile into `shared`. Then each thread adds its partial dot product to `out`.
    `run_threads` finishes a whole pass before returning, so all loads complete
    before any thread reads the tiles.

    Args:
        blockIdx, blockDim: coordinates with `.x`/`.y`; callers pass blockDim == (tw, tw).
        shared: buffer of 2*tw*tw elements; first half = tile of `m`, second half = tile of `n`.
        m, n, out: flattened row-major buffers of shapes (h, k), (k, w), (h, w).
            `out` is accumulated into (+=), so it must start at zero.
        h, w, k: matrix dimensions.
        tw: tile width.
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    def get_rc(tr,tc):
        "Global (row, col) of the output element owned by thread (tr, tc)."
        return blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc

    def filled_shareMem(tr, tc, ph):
        "Load this thread's element of the phase-`ph` tiles, zero-padding reads outside the matrices."
        r,c = get_rc(tr,tc)
        idx = ph*tw
        ms[tr*tw+tc] = m[r*k + idx+tc] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[(idx+tr)*w + c] if c<w and idx+tr<k else 0.

    def dotprod_tld(tr,tc):
        "Add this thread's partial dot product over the current tiles to out[r, c]."
        r,c = get_rc(tr,tc)
        # Check row and column separately: a flat `r*w+c < len(out)` test also passes for
        # c >= w and would write into row r+1 (turning into NaN if the row holds an inf).
        if r>=h or c>=w: return
        for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

    for ph in range((k + tw - 1) // tw):  # ceil(k / tw) phases, in exact integer arithmetic
        run_threads(filled_shareMem,blockDim, ph)
        run_threads(dotprod_tld, blockDim)

In [90]:
def matmul_2d(m,n, tw=16):
    """Tiled matmul of `m` (h x k) by `n` (k x w) using the refactored block kernel `tiled_matmul2D_blk`.

    Blocks run one after another through blk_kernel2D_shrd, each with a fresh zeroed
    scratch buffer of 2*tw*tw elements. Returns a new (h, w) tensor with `m`'s dtype.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k==k2, 'Size mismatch !'
    out = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(tw, tw)
    n_blocks = dim3(x=cdiv(w, tpb.x), y=cdiv(h, tpb.y))
    # `out` is contiguous, so its flatten() is a view the kernel writes through
    flat_args = (m.flatten(), n.flatten(), out.flatten(), h, w, k)
    blk_kernel2D_shrd(tiled_matmul2D_blk, n_blocks, tpb, 2*tw*tw, *flat_args, tw=tw)
    return out

In [91]:
# same check against PyTorch's matmul, now with the refactored per-thread helpers
torch.isclose(matmul_2d(m1s,m2s,tw=16), m1s@m2s).all()

tensor(True)

### Barrier synchronization with Python threads

In [92]:
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor

In [93]:
def g(x,sb):
    """Print x, then -x, then x*10, waiting on barrier `sb` between consecutive prints.

    When several threads share `sb`, no thread prints its next value until every
    thread has printed its previous one. Order within a stage is arbitrary, but
    stages never interleave.
    """
    for stage, value in enumerate((x, -x, x*10)):
        if stage > 0:
            sb.wait()
        print(value)
    

In [95]:
num = 3
sb = Barrier(num)  # all `num` workers must reach each wait() before any of them continues
with ThreadPoolExecutor(max_workers=num) as ex:
    list(ex.map(lambda i: g(i, sb), range(1, num + 1)))

1
2
3
-3
-1
-2
10
30
20


In [114]:
def blk_kernel_shared(f, blocks,tpb, shr_sz, *args, **kwargs):
    """Simulate a CUDA launch with real concurrency inside each block.

    Blocks run one after another (row-major over blocks.y, blocks.x). Within a block,
    each thread of `tpb` gets its own `threading.Thread`. All of them share a zeroed
    buffer of `shr_sz` elements and a `Barrier` sized to the block, which `f` uses in
    place of `__syncthreads()`. `f` is called as
    f(blockIdx, threadIdx, blockDim, shared, barrier, *args, **kwargs).

    Raises:
        The first exception raised by any simulated thread. The barrier is aborted when
        that happens, so threads blocked in `wait()` get `BrokenBarrierError` instead of
        waiting forever. The exception is re-raised after every thread of the block has
        been joined.
    """
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            shar = torch.zeros(shr_sz)
            syncba = Barrier(tpb.y*tpb.x)
            errors = []

            # bind this block's barrier / error list explicitly rather than via the loop variables
            def run(*targs, _bar=syncba, _errs=errors):
                try:
                    f(*targs, **kwargs)
                except Exception as e:
                    _errs.append(e)  # record before aborting so the triggering error comes first
                    _bar.abort()     # release threads stuck in wait(); otherwise join() below hangs

            threads = [Thread(target=run, args=(dim3(i1,i0), dim3(p,o), tpb, shar, syncba, *args))
                       for o in range(tpb.y) for p in range(tpb.x)]
            for tr in threads: tr.start()
            for tr in threads: tr.join()
            if errors: raise errors[0]
                                                               

In [115]:
def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncba, m, n, out, h, w, k, tw):
    """One simulated CUDA thread of the shared-memory tiled matmul.

    Runs concurrently with the other threads of its block. `syncba` (a barrier shared
    by the whole block) stands in for `__syncthreads()`. The thread owns output element
    (r, c), keeps its running sum in a local, and writes it once after all phases.

    Args:
        blockIdx, threadIdx, blockDim: coordinates with `.x`/`.y`; callers pass blockDim == (tw, tw).
        shared: buffer of 2*tw*tw elements shared by the block; first tw*tw = tile of `m`,
            the rest = tile of `n`.
        syncba: barrier every thread of the block waits on.
        m, n, out: flattened row-major buffers of shapes (h, k), (k, w), (h, w).
        h, w, k: matrix dimensions.
        tw: tile width.
    """
    tc,tr = threadIdx.x,threadIdx.y
    r = blockIdx.y*blockDim.y + tr
    c = blockIdx.x*blockDim.x + tc

    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    p = 0.
    # Phases step along the shared k dimension one tile (tw) at a time: ceil(k / tw) of them.
    # The count depends on the tile width, not on the output width w.
    for ph in range((k + tw - 1) // tw):
        idx = ph*tw
        ms[tr*tw+tc] = m[r*k + idx+tc] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[(idx+tr)*w + c] if c<w and idx+tr<k else 0.
        syncba.wait()  # both tiles fully loaded before anyone reads them
        for i in range(tw): p += ms[tr*tw+i] * ns[tw*i+tc]
        syncba.wait()  # everyone done reading before the next phase overwrites the tiles

    # Write once, after all phases. Out-of-range threads still had to join every barrier above.
    if r<h and c<w: out[r*w+c] = p
    

In [116]:
def matmul_2d(m,n, tw=16):
    """Tiled matmul of `m` (h x k) by `n` (k x w), with one Python thread per simulated CUDA thread.

    Each tw x tw block runs `matmul_tiled_bk` in tw*tw threads. The threads share a
    2*tw*tw buffer and a barrier (see `blk_kernel_shared`). Returns a new (h, w) tensor
    with `m`'s dtype.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, 'Size mismatch !'
    result = torch.zeros(h, w, dtype=m.dtype)
    block = dim3(tw, tw)
    grid = dim3(cdiv(w, block.x), cdiv(h, block.y))
    # `result` is contiguous, so flatten() yields a view the threads write through
    blk_kernel_shared(matmul_tiled_bk, grid, block, 2*tw*tw,
                      m.flatten(), n.flatten(), result.flatten(), h, w, k, tw=tw)
    return result

In [117]:
# threaded simulation with 8x8 tiles, checked against PyTorch's matmul
torch.isclose(matmul_2d(m1s,m2s,tw=8), m1s@m2s).all()

tensor(True)

In [118]:
# Shared-memory tiled CUDA matmul kernel. Each thread computes out[r, c]; the block cooperatively stages
# tw x tw tiles of m and n in dynamic shared memory, which must be at least 2*tw*tw floats at launch.
# Fixes vs. the previous draft:
# - declarations are written as `int a = .., b = ..;` (repeating `int` is invalid C++);
# - the missing `;` after `idx` is added;
# - the second load fills `ns` from `n` (it overwrote `ms` from `m`);
# - ceil(k/tw) is computed inline in device code.
# NOTE(review): this cell defines only the kernel; a host wrapper launching it is still needed.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw){
    int tc = threadIdx.x, tr = threadIdx.y;
    int r = blockIdx.y*blockDim.y+tr, c = blockIdx.x*blockDim.x+tc;

    // dynamic shared memory: first tw*tw floats hold the m tile, the next tw*tw the n tile
    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    // ceil(k / tw) phases, computed inline so the kernel does not rely on cdiv being callable on the device
    for (int ph = 0; ph < (k + tw - 1)/tw; ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[tc + idx + r*k] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw+i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w+c] = p;
}
'''

