## Setup

In [None]:
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns

In [None]:
# Compact array/tensor printing: 2 decimals and 140-char lines; torch additionally disables scientific notation.
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [None]:
sys.path.insert(0, '..')

In [None]:
from utils import show_img,load_cuda,cuda_begin

In [None]:
%load_ext wurlitzer

In [None]:
# Make CUDA kernel launches synchronous so an error is reported at the call that launched the failing kernel
# (slower, but easier to debug). It only takes effect if set before CUDA is initialised in this process.
os.environ['CUDA_LAUNCH_BLOCKING']='1'
torch.manual_seed(42);  # seed torch's RNG so the random matrices below are reproducible; `;` suppresses the output

In [None]:
m1 = torch.rand(50_000, 784)  # large 50,000 x 784 left operand, used with the CUDA kernels
m1s = m1[:8]                  # first 8 rows only: a small input for the element-by-element Python kernels
m2 = torch.rand(784,10)       # 784 x 10 right operand

## Reminder

### 2d Python kernel

In [None]:
def blk_kernel2d(f, blocks, threads, *args):
    """Serially emulate a 2D CUDA launch.

    Calls `f(blockidx, threadidx, blockdim, *args)` once per (block, thread) pair. Blocks are
    visited row-major (y outer, x inner), and so are the threads inside each block. `threads`
    itself is passed as the block dimension.
    """
    block_ids = ((by, bx) for by in range(blocks.y) for bx in range(blocks.x))
    for by, bx in block_ids:
        thread_ids = ((ty, tx) for ty in range(threads.y) for tx in range(threads.x))
        for ty, tx in thread_ids:
            f(ns(x=bx, y=by), ns(x=tx, y=ty), threads, *args)

In [None]:
def matmul_bk(blockidx, threadidx, blockdim, m, n, out, h, w, k):
    """Per-thread body of the naive matmul kernel.

    Derives this thread's output row and column from the block/thread indices. It then writes
    `out[row*w + col]` = dot(row `row` of `m`, column `col` of `n`). `m` (h x k), `n` (k x w)
    and `out` (h x w) are flat row-major buffers. Threads that fall outside the h x w output
    do nothing.
    """
    row = blockidx.y*blockdim.y + threadidx.y
    col = blockidx.x*blockdim.x + threadidx.x
    if row < h and col < w:
        acc = 0.
        for i in range(k):
            acc += m[row*k + i] * n[i*w + col]
        out[row*w + col] = acc

In [None]:
def matmul_2d(m, n):
    """Multiply `m` (h x k) by `n` (k x w) with a serial emulation of a 2D CUDA launch.

    Uses 16x16 thread blocks and enough blocks to cover the h x w output. `matmul_bk` runs for
    every emulated thread and writes into the flattened `output`.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k==k2, "Size mismatch!"
    threads_per_block = ns(x=16, y=16)
    grid = ns(x=math.ceil(w/threads_per_block.x), y=math.ceil(h/threads_per_block.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # flatten() of the freshly allocated (contiguous) output is a view, so kernel writes land in `output`.
    blk_kernel2d(matmul_bk, grid, threads_per_block,
                 m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output

In [None]:
torch.isclose(matmul_2d(m1s, m2), m1s@m2).all()

tensor(True)

### CUDA

In [None]:
# Naive CUDA matmul: each thread computes one element out[r, c] as a dot product read straight from global memory.
# `matmul` is the host wrapper: it validates inputs, allocates the h x w output and launches 16x16 blocks covering it.
# `cuda_begin` (from the local utils) is prepended; it presumably supplies the includes, CHECK_INPUT and cdiv — TODO confirm.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul'  # host function to expose from the extension; get_sig looks up its signature by this name

In [None]:
def get_sig(fname, src):
    """Return the C++ declaration for `fname` found in `src`, or None.

    Finds the first line of the form `<return type> fname(<args>)`, optionally followed by `{`.
    Returns it with a trailing ';' so it can serve as the forward declaration for the binding.
    """
    match = re.search(rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$', src, re.MULTILINE)
    if match is None:
        return None
    return match.group(1) + ';'

In [None]:
# Forward declaration of the host wrapper, extracted from the CUDA source, for the C++ side of the binding.
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [None]:
# Compile the source and expose `fname` as a Python callable. `load_cuda` comes from the local utils and is
# presumably a wrapper around torch's inline C++/CUDA extension loader — TODO confirm.
module = load_cuda(cuda_src, cpp_src, [fname])

In [None]:
# Contiguous GPU copies of the inputs. The wrapper's CHECK_INPUT (defined in `cuda_begin`, not shown here)
# presumably rejects non-CUDA or non-contiguous tensors.
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()

In [None]:
module.matmul(m1c,m2c).shape

torch.Size([50000, 10])

In [None]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

## Shared memory

### 2d Python kernel

In [None]:
# Split one buffer into two pieces, `b` = a[0:3] and `c` = a[3:5]. The emulated shared memory below is split into its two tiles the same way.
a = torch.zeros(5)
b,c = a[:3],a[3:]

In [None]:
# Writing through the slices changes `a` itself: basic slicing returns views that share a's storage.
b[1] = 2
c[0] = 6
a

tensor([0., 2., 0., 6., 0.])

In [None]:
# Reference body of a tiled CUDA kernel, kept as a plain string for reading only; it is not assigned or compiled here.
# It assumes square w x w matrices and relies on names (M, N, ms, ns, TW, r, c, tx, ty) that this snippet does not define.
"""
    float p = 0;
    for (int ph = 0; ph < ceil(w/(float)TW); ++ph) {
        ms[ty][tx] = ((r<w) && (ph*TW+tx)<w) ? M[ tx + ph*TW + r*w] : 0.0f;
        ns[ty][tx] = ((c<w) && (ph*TW+ty)<w) ? N[(ty + ph*TW)*w +c] : 0.0f;
        __syncthreads();
        for (int k = 0; k < TW; ++k) p += ms[ty][k] * ns[k][tx];
        __syncthreads();
    }
    if (r<w && c<w) out[r*w + c] = p;
"""

In [None]:
def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
    """Serially emulate a 2D CUDA grid whose blocks use shared memory.

    Calls `f(blockidx, threads, shared, *args, **kwargs)` once per block, with blocks in
    row-major order; `f` loops over the block's threads itself. Every call receives a freshly
    zeroed `shared` tensor of `sh_sz` elements, so nothing carries over between blocks.
    """
    block_ids = ((by, bx) for by in range(blocks.y) for bx in range(blocks.x))
    for by, bx in block_ids:
        block_shared = torch.zeros(sh_sz)
        f(ns(x=bx, y=by), threads, block_shared, *args, **kwargs)

In [None]:
def run_threads(f, blockdim, *args, **kwargs):
    "Call `f(ty, tx, *args, **kwargs)` for every thread of one block, y outer and x inner (row-major)."
    thread_ids = ((ty, tx) for ty in range(blockdim.y) for tx in range(blockdim.x))
    for ty, tx in thread_ids:
        f(ty, tx, *args, **kwargs)

In [None]:
def matmul_tiled_bk(blockidx, blockdim, shared, m, n, out, h, w, k, tw, mm=None):
    """Emulate one thread block of a shared-memory tiled matmul, `out = m @ n`.

    - `m` (h x k), `n` (k x w) and `out` (h x w) are flat row-major buffers.
    - `out` must be zero-initialised and must alias the real result, because each phase
      *adds* its partial dot products into it.
    - `shared` needs at least 2*tw*tw elements; it is split into a tw x tw tile of `m`
      followed by a tw x tw tile of `n`.
    - Assumes blockdim.x == blockdim.y == tw.
    - `mm` is not used; it is accepted (and now optional) so existing callers keep working.
    """
    sh_sz = tw*tw
    m_sh, n_sh = shared[:sh_sz], shared[sh_sz:]  # views into `shared` (renamed to avoid shadowing `ns`)

    def get_rc(ty, tx): return blockidx.y*blockdim.y + ty, blockidx.x*blockdim.x + tx

    def fill_shared_tk(ty, tx, ph):
        # Each thread copies one element of this phase's m-tile and n-tile, zero-padding past the matrix edges.
        r,c = get_rc(ty, tx)
        m_sh[ty*tw+tx] = m[r*k + ph*tw + tx] if r<h and (ph*tw+tx)<k else 0.
        n_sh[ty*tw+tx] = n[(ph*tw + ty)*w + c] if c<w and (ph*tw+ty)<k else 0.

    def dotprod_tk(ty, tx):
        # Accumulate this phase's partial dot product from the *shared tiles*, indexed by the thread's tile coords.
        r,c = get_rc(ty, tx)
        if r>=h or c>=w: return  # thread lies outside the output (also prevents aliasing into other rows)
        for i in range(tw):
            out[r*w+c] += m_sh[ty*tw+i] * n_sh[i*tw+tx]

    # Every thread fills the tiles before any thread reads them; this plays the role of __syncthreads().
    for ph in range(math.ceil(k/tw)):
        run_threads(fill_shared_tk, blockdim, ph)
        run_threads(dotprod_tk, blockdim)

In [None]:
def matmul_2d(m, n, tw=16):
    """Multiply `m` (h x k) by `n` (k x w) with a serial emulation of a tiled, shared-memory CUDA launch.

    `tw` is the tile width. It sets the thread-block shape (tw x tw), the per-block shared buffer
    size (two tw x tw tiles) and the kernel's tiling, so the three cannot drift apart.
    """
    assert tw > 0, "Tile width must be positive!"
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = ns(x=tw,y=tw)
    blocks = ns(x=math.ceil(w/tpb.x), y=math.ceil(h/tpb.y))
    # view(-1) is guaranteed to alias `output` (it raises rather than silently copying), so the kernel's
    # in-place accumulation is visible in the returned tensor. matmul_tiled_bk does not use `mm`;
    # it is passed for signature compatibility only.
    blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, 2*tw*tw,
                      m.flatten(), n.flatten(), output.view(-1),
                      h, w, k, tw=tw, mm=m)
    return output

In [None]:
m1s.shape

torch.Size([8, 784])

In [None]:
torch.isclose(matmul_2d(m1s, m2), m1s@m2).all()

tensor(False)

In [None]:
matmul_2d(m1s, m2)

tensor([[185.04, 211.74, 207.05, 175.68, 184.61, 165.85, 200.46, 180.15, 159.06, 219.29],
        [246.11, 374.62, 354.15, 383.29, 338.41, 357.94, 128.71, 131.15, 122.34, 165.20],
        [249.40, 265.87, 349.05, 316.15, 296.39, 316.05, 182.69, 135.67, 161.97, 164.37],
        [294.56, 383.04, 415.73, 360.61, 368.37, 368.36, 209.56, 154.00, 186.31, 203.71],
        [329.01, 409.81, 457.99, 411.94, 386.65, 459.32, 230.68, 191.58, 222.18, 248.16],
        [331.02, 380.52, 393.87, 374.61, 372.91, 369.44, 196.03, 154.32, 164.01, 203.76],
        [322.04, 330.92, 366.76, 310.33, 362.50, 343.97, 189.47, 133.25, 189.82, 200.85],
        [255.43, 362.17, 379.27, 294.94, 339.76, 320.28, 164.46, 144.77, 157.93, 187.12]])

In [None]:
m1s@m2

tensor([[187.96, 191.44, 184.27, 196.76, 197.56, 195.40, 194.36, 192.61, 184.84, 186.31],
        [206.28, 205.22, 192.29, 210.71, 202.66, 204.08, 200.01, 201.26, 196.18, 199.55],
        [204.15, 210.09, 195.36, 206.41, 215.57, 207.06, 210.58, 207.05, 195.90, 197.03],
        [201.26, 203.92, 189.29, 198.65, 202.47, 198.09, 203.28, 198.94, 192.32, 193.05],
        [190.47, 194.77, 181.35, 190.37, 194.37, 194.06, 191.73, 191.55, 182.49, 185.08],
        [203.07, 205.65, 193.66, 204.50, 205.94, 199.65, 204.20, 198.63, 192.76, 200.65],
        [205.57, 203.16, 191.63, 209.40, 209.39, 201.50, 208.93, 201.71, 195.01, 199.23],
        [195.23, 198.80, 189.03, 197.50, 199.72, 199.72, 198.44, 190.80, 188.84, 197.59]])

### CUDA

In [None]:
# Tiled CUDA matmul using *dynamically sized* shared memory. The host passes the byte count as the third launch
# parameter, and the kernel splits the single `extern __shared__` buffer into a tw x tw tile of `m` followed by a
# tw x tw tile of `n`. Tiles are zero-padded at the matrix edges, so h, w and k need not be multiples of tw.
cuda_src = cuda_begin + r'''
#define TW 16
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k, int tw) {
    // Assumes blockDim.x == blockDim.y == tw (the host launches tw x tw blocks).
    int tc = threadIdx.x, tr = threadIdx.y;
    int r = blockIdx.y*blockDim.y + tr;
    int c = blockIdx.x*blockDim.x + tc;

    extern __shared__ float ms_ns[];
    float *ms = ms_ns;          // tile of m: tw*tw floats
    float *ns = ms_ns + tw*tw;  // tile of n: the next tw*tw floats

    float p = 0.0f;
    for (int ph = 0; ph < (k + tw - 1)/tw; ++ph) {
        int idx = ph*tw;
        // Load one element of each tile, writing 0 past the edges of m or n.
        ms[tr*tw + tc] = (r < h && idx + tc < k) ? m[r*k + idx + tc] : 0.0f;
        ns[tr*tw + tc] = (c < w && idx + tr < k) ? n[(idx + tr)*w + c] : 0.0f;
        __syncthreads();
        for (int i = 0; i < tw; ++i) p += ms[tr*tw + i] * ns[i*tw + tc];
        __syncthreads();
    }
    if (r < h && c < w) out[r*w + c] = p;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(TW,TW);
    size_t shared_bytes = 2 * TW * TW * sizeof(float);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb, shared_bytes>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul'  # host function to expose from the extension; get_sig looks up its signature by this name

In [None]:
def get_sig(fname, src):
    """Extract a forward declaration for `fname` from CUDA/C++ source.

    Scans `src` line by line for `<return type> fname(<args>)`, optionally followed by `{`.
    Returns the first hit with a trailing ';', or None when nothing matches.
    """
    pattern = re.compile(rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$', re.MULTILINE)
    found = pattern.search(src)
    return found.group(1) + ';' if found else None

In [None]:
# Forward declaration of the host wrapper, extracted from the CUDA source, for the C++ side of the binding.
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [None]:
# Compile the source and expose `fname` as a Python callable. `load_cuda` comes from the local utils and is
# presumably a wrapper around torch's inline C++/CUDA extension loader — TODO confirm.
module = load_cuda(cuda_src, cpp_src, [fname])

In [None]:
# Contiguous GPU copies of the inputs. The wrapper's CHECK_INPUT (defined in `cuda_begin`, not shown here)
# presumably rejects non-CUDA or non-contiguous tensors.
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()

In [None]:
module.matmul(m1c,m2c).shape

torch.Size([50000, 10])

In [None]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')