# Tiled Kernel
This notebook follows the GPU Mode lecture video:

[Lecture 5: Going Further with CUDA for Python Programmers](https://www.youtube.com/watch?v=wVsR-YhaHlM)

In [None]:
!pip install ninja

In [None]:
!nvidia-smi

In [None]:
def in_colab() -> bool:
    """Report whether the ``google.colab`` package can be imported.

    Returns:
        True when the import succeeds, False when it raises ImportError.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    else:
        return True


# Evaluated once so later cells can branch on the environment.
IN_COLAB = in_colab()

In [None]:
import numpy as np

In [None]:
# Fixed seed so the random test matrices are the same on every run.
np.random.seed(20250811)

# Test operands: a is 16 x 8 and b is 8 x 8.
a = np.random.rand(16, 8)
b = np.random.rand(8, 8)
# Reference product computed by NumPy; the hand-written kernel is checked against it.
ab = a @ b

In [None]:
from dataclasses import dataclass

@dataclass
class Dim3:
    """Three-component integer index or extent with fields ``x``, ``y`` and ``z``.

    Every component defaults to 0, so callers may pass only the axes they use
    (the code in this notebook sets ``x`` and ``y`` only).
    """

    x: int = 0
    y: int = 0
    z: int = 0

# Tuple-based alternative, left disabled:
# Dim3 = namedtuple("Dim3", ["x", "y", "z"])

In [None]:
from math import ceil
# Tile width: the kernel lays out its two scratch tiles as TW x TW each.
TW = 4
# Threads per block along each axis.
TS = 4
# Number of blocks needed to cover the output: x spans b's columns, y spans a's rows.
block_size = Dim3(
    x=ceil(b.shape[1] / TS),
    y=ceil(a.shape[0] / TS),
)
# Every block is TS x TS threads.
thread_size = Dim3(x=TS, y=TS)


def tiled_kernel(block_dim: Dim3, thread_dim: Dim3, thread_size: Dim3, signalb, m, n, res, r, c, k, mem):
    """Accumulate one element of ``res = m @ n`` using tiles staged in ``mem``.

    One call plays the role of a single thread of a block. All threads of the
    block share ``mem`` and ``signalb`` and must call this function
    concurrently, because each phase synchronises twice on ``signalb``.

    Args:
        block_dim: index of the block inside the grid (``x``/``y`` are used).
        thread_dim: index of the thread inside the block (``x``/``y`` are used).
        thread_size: number of threads per block along ``x`` and ``y``.
        signalb: barrier-like object; only ``wait()`` is called on it.
        m: flattened row-major ``r x k`` left operand.
        n: flattened row-major ``k x c`` right operand.
        res: flattened row-major ``r x c`` output, accumulated into in place.
        r: number of rows of ``m`` and ``res``.
        c: number of columns of ``n`` and ``res``.
        k: shared inner dimension.
        mem: flat scratch buffer; the first ``TW * TW`` entries hold the tile of
            ``m`` and the remainder the tile of ``n`` (stored transposed).

    Returns:
        None. The result is written into ``res``.

    Note:
        Tiles are indexed with the global ``TW`` while thread indices range over
        ``thread_size``, so ``thread_size.x`` and ``thread_size.y`` must not
        exceed ``TW``.
    """
    tx = thread_dim.x
    ty = thread_dim.y
    cc = block_dim.x * thread_size.x + tx
    cr = block_dim.y * thread_size.y + ty

    # Threads that fall outside the output must NOT return early: they still
    # have to reach every barrier, otherwise the rest of the block waits forever.
    active = cr < r and cc < c

    mm = mem[:TW * TW]
    mn = mem[TW * TW:]
    for ph in range(ceil(k / TW)):
        if active:
            # Stage row `ty` of the m-tile and column `tx` of the n-tile.
            # Threads sharing a row/column write identical values, so the
            # overlapping writes are harmless.
            for tw in range(TW):
                kk = ph * TW + tw
                in_range = kk < k  # zero-pad the last tile when k % TW != 0
                mm[ty * TW + tw] = m[cr * k + kk] if in_range else 0.0
                mn[tx * TW + tw] = n[kk * c + cc] if in_range else 0.0

        # Wait until both tiles are completely loaded.
        signalb.wait()
        if active:
            # res is r x c, so its row stride is c (not k).
            out_idx = cr * c + cc
            for i in range(TW):
                res[out_idx] += mm[ty * TW + i] * mn[tx * TW + i]
        # Wait until every thread is done reading before the tiles are overwritten.
        signalb.wait()
    return None

### Debug

In [None]:
# Flat scratch buffer large enough for two TW x TW tiles stored back to back.
mem = np.zeros((2 * TW, TW)).ravel()

# Slices of mem: the first tile is used for m, the second for n.
mm = mem[:TW*TW]
mn = mem[TW*TW:]

In [None]:
# Flattened operands and their dimensions: a is r x k, b is k x c.
m = a.ravel()
n = b.ravel()
r, k = a.shape
c = b.shape[1]
res = np.zeros((r, c)).ravel()


# Sequentially emulate block `bd` for the single phase `ph`.
bd = Dim3(x=0, y=0)
ts = Dim3(x=TS, y=TS)
by = bd.y * ts.y
bx = bd.x * ts.x
ph = 0

# Pass 1 (everything before the first barrier in the kernel): every thread
# stages its row of the m-tile and its column of the n-tile.
for _y in range(TS):
    for _x in range(TS):
        td = Dim3(x=_x, y=_y)
        ty = td.y
        cr = by + ty
        tx = td.x
        cc = bx + tx
        for tw in range(TW):
            mm[ty * TW + tw] = m[ph * TW + tw + cr * k]
            mn[tx * TW + tw] = n[(ph * TW + tw) * c + cc]

# Pass 2 (between the two barriers): every thread accumulates its dot product.
for _y in range(TS):
    for _x in range(TS):
        td = Dim3(x=_x, y=_y)
        ty = td.y
        cr = by + ty
        tx = td.x
        cc = bx + tx
        for tw in range(TW):
            # res is r x c, so its row stride is c (not k).
            res[cr * c + cc] += mm[ty * TW + tw] * mn[tx * TW + tw]



In [None]:
# Product of the two scratch tiles; mn holds its tile transposed, hence the .T.
mm.reshape(TW, TW) @ mn.reshape(TW, TW).T

In [None]:
# Top-left TW x TW corner of the accumulated result, for comparison with the tile product.
res.reshape(r, c)[:TW, :TW]

In [None]:
# Staged tile of m, viewed as TW x TW.
mm.reshape(TW, TW)

In [None]:
# Staged tile of n, viewed as TW x TW (each row holds one column of the source tile).
mn.reshape(TW, TW)

In [None]:
# Top-left TW x TW corner of a, to compare against the staged m-tile.
a[:TW, :TW]

In [None]:
# Top-left TW x TW corner of b, to compare against the (transposed) staged n-tile.
b[:TW, :TW]

In [None]:
# Slice of row cr of a covering phase ph; cr still holds the value from the last loop iteration.
a[cr, ph * TW:ph * TW + TW]

In [None]:
# Slice of column cc of b covering phase ph; cc still holds the value from the last loop iteration.
b[ph * TW: ph * TW + TW, cc]

## Example with barrier semaphore

In [None]:
import string
import time

from itertools import cycle, islice
from concurrent.futures import ThreadPoolExecutor
from threading import Barrier

def print_thread_id(letter, buffer, signalb, delay=1):
    """Store ``letter`` in a barrier-assigned slot of ``buffer`` and print the buffer once.

    Args:
        letter: string written into the buffer.
        buffer: shared list of strings with at least one slot per barrier party.
        signalb: barrier whose ``wait()`` returns a distinct index per waiting
            thread (``threading.Barrier`` behaves this way).
        delay: seconds to sleep before returning; defaults to the previous
            fixed value of 1.

    Returns:
        None.
    """
    # The first wait hands every participating thread its own slot index.
    idx = signalb.wait()
    buffer[idx] = letter
    # The second wait guarantees all slots are filled before anyone reads the buffer.
    signalb.wait()
    if idx == 0:
        print("".join(buffer))
    time.sleep(delay)

# Create a ThreadPoolExecutor with _nw worker threads
_nw = 16
with ThreadPoolExecutor(max_workers=_nw) as executor:
    # One barrier party per worker; submit 10 * _nw tasks so every barrier round is filled completely
    signalb = Barrier(_nw)
    buffer = _nw * [""]
    for letter in islice(cycle(string.ascii_letters), 10 * _nw):
        executor.submit(print_thread_id, letter, buffer, signalb)

## Runner

In [None]:
def threaded_2d_runner(block_size, thread_size, fn, *args):
    """Run ``fn`` once per (block, thread) pair, one Python thread per kernel thread.

    Blocks are executed one after another. Inside a block all threads run
    concurrently and share one barrier created for that block.

    Args:
        block_size: number of blocks along ``x`` and ``y``.
        thread_size: number of threads per block along ``x`` and ``y``.
        fn: kernel, called as
            ``fn(block_index, thread_index, thread_size, barrier, *args)``.
        *args: extra positional arguments forwarded to ``fn`` unchanged.

    Raises:
        Exception: whatever a kernel thread raised; it is re-raised after the
            block has finished instead of being dropped silently.
    """
    mw = thread_size.x * thread_size.y
    for by in range(block_size.y):
        for bx in range(block_size.x):
            # Fresh barrier per block, sized to the number of threads in it.
            # (The old code bound it to a misspelled name and passed the kernel
            # whatever global `signalb` happened to exist.)
            signalb = Barrier(mw)
            with ThreadPoolExecutor(max_workers=mw) as executor:
                futures = [
                    executor.submit(fn, Dim3(x=bx, y=by), Dim3(x=tx, y=ty), thread_size, signalb, *args)
                    for ty in range(thread_size.y)
                    for tx in range(thread_size.x)
                ]
            # Leaving the `with` block waits for all threads; surface any failure.
            for future in futures:
                future.result()

## Run Test

In [None]:
# Operand shapes: a is ar x ak, b is bk x bc.
ar, ak = a.shape
bk, bc = b.shape
# Output buffer that the kernel accumulates into.
res_ab = np.zeros((ar, bc))
# Scratch space for two TW x TW tiles.
mem = np.zeros((2 * TW, TW))

In [None]:
# Kernel arguments after the barrier are (m, n, res, r, c, k, mem).
# res_ab and mem are passed through ravel() so the kernel writes into the original arrays.
threaded_2d_runner(block_size, thread_size, tiled_kernel, a.flatten(), b.flatten(), res_ab.ravel(), ar, bc, ak, mem.ravel())

In [None]:
# The threaded emulation must agree with NumPy's reference product.
assert np.allclose(ab, res_ab)

## Translate to CUDA

We will reuse parts of the previous example.

## Numba

Numba is a Python library that lets you write CUDA kernels directly in Python and provides some extra tooling around them.

In [None]:
from math import ceil

import numpy as np
import torch

from numba import cuda
from numba.cuda import as_cuda_array as ca

In [None]:
@cuda.jit
def matmul_k(m, n, out, tw):
    """Tiled matrix-multiply kernel computing ``out = m @ n``.

    Each thread produces one element of ``out``. In every phase the block
    stages a ``tw x tw`` tile of ``m`` and of ``n`` in dynamic shared memory
    and each thread adds that tile's contribution to its running sum.

    Args:
        m: 2-D device array of shape (r, k).
        n: 2-D device array of shape (k, c).
        out: 2-D device array of shape (r, c) that receives the product.
        tw: tile width. The indexing assumes the kernel is launched with
            ``tw x tw`` threads per block and with dynamic shared memory for
            ``2 * tw * tw`` float32 values.
    """
    cbi, cbd, cti = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
    cr = cbi.y * cbd.y + cti.y
    cc = cbi.x * cbd.x + cti.x
    r = m.shape[0]
    # The inner dimension and the column count come from n, not m.
    k, c = n.shape

    # No early return for out-of-range threads: every thread of the block must
    # reach each syncthreads(). The loads below are zero-padded and the final
    # store is guarded instead.
    shared_mem = cuda.shared.array(0, dtype=np.float32)
    mm = shared_mem[:tw * tw]
    mn = shared_mem[tw * tw:2 * tw * tw]

    # The accumulator lives across all phases; resetting it inside the loop
    # would keep only the last tile's contribution.
    p = np.float32(0.0)
    for ph in range((k + tw - 1) // tw):
        idx = tw * ph
        mm[tw * cti.y + cti.x] = m[cr, cti.x + idx] if cr < r and idx + cti.x < k else 0.0
        # The n-tile is stored transposed so the inner loop below walks both
        # tiles with the same index pattern.
        mn[tw * cti.x + cti.y] = n[cti.y + idx, cc] if cc < c and idx + cti.y < k else 0.0
        # Both tiles must be fully loaded before anyone reads them.
        cuda.syncthreads()
        for i in range(tw):
            p += mm[cti.y * tw + i] * mn[cti.x * tw + i]
        # Everyone must be done reading before the next phase overwrites the tiles.
        cuda.syncthreads()
    if cr < r and cc < c:
        out[cr, cc] = p
    return None


def matmul(m, n, tw=16):
    """Multiply two 2-D tensors on the GPU with the tiled ``matmul_k`` kernel.

    Args:
        m: left operand of shape (r, k); it is passed through
            ``as_cuda_array``, so it is expected to live on a CUDA device.
        n: right operand of shape (k, c), on the same device.
        tw: tile width; the kernel is launched with ``tw x tw`` threads per block.

    Returns:
        A new tensor of shape (r, c) with the dtype and device of ``m``.

    Raises:
        AssertionError: if the inner dimensions of ``m`` and ``n`` differ.
    """
    r, mk = m.shape
    k, c = n.shape
    assert mk == k, "Incompatible dimensions"
    out = torch.zeros((r, c), dtype=m.dtype, device=m.device)
    block_size = (ceil(c / tw), ceil(r / tw))
    thread_size = (tw, tw)
    # The launch configuration takes the dynamic shared memory size in BYTES.
    # The kernel keeps two tw x tw float32 tiles, so scale the element count
    # by the float32 item size.
    dynamic_shared_mem_size = 2 * tw * tw * np.dtype(np.float32).itemsize
    matmul_k[block_size, thread_size, 0, dynamic_shared_mem_size](ca(m), ca(n), ca(out), tw)
    return out

In [None]:
# Rebinds a, b and ab (previously NumPy arrays) to torch tensors for the CUDA test.
a = torch.randn(1024, 64)
b = torch.randn(64, 128)
# Reference product computed by torch.
ab = a @ b

In [None]:
# Run the Numba kernel on GPU copies of the operands.
# NOTE(review): nab is not compared against ab in the visible cells — confirm a check follows.
nab = matmul(a.cuda(), b.cuda())