# Tiled Kernel
This notebook follows the GPU MODE lecture video:

[Lecture 5: Going Further with CUDA for Python Programmers](https://www.youtube.com/watch?v=wVsR-YhaHlM)

In [None]:
!pip install ninja

In [None]:
# Show which NVIDIA GPU(s) the runtime exposes, with driver and CUDA versions.
!nvidia-smi

In [None]:
def in_colab() -> bool:
    """Report whether this kernel runs inside Google Colab.

    Returns:
        True if ``google.colab`` can be imported, False otherwise.
    """
    try:
        import google.colab  # noqa: F401 - imported only to probe availability
    except ImportError:
        return False
    else:
        return True


IN_COLAB = in_colab()

In [1]:
import numpy as np

In [89]:
# Seed NumPy's legacy global RNG so the operands (and the saved outputs below)
# are reproducible.
np.random.seed(20250811)

# Test problem: A is 16 x 8 and B is 8 x 8.
a = np.random.rand(16, 8)
b = np.random.rand(8, 8)
# Reference product that the emulated kernel is checked against.
ab = np.matmul(a, b)

In [90]:
from dataclasses import dataclass

@dataclass
class Dim3:
    """Three-component index or extent, modelled on CUDA's ``dim3``.

    Used both for block/thread indices and for block/grid sizes. Every
    component defaults to 0 (CUDA's ``dim3`` defaults unspecified
    components to 1).
    """

    x: int = 0  # column (fastest-varying) component
    y: int = 0  # row component
    z: int = 0  # not read by the 2-D code in this notebook

In [174]:
from math import ceil
# TW: side of each square tile held in the emulated shared memory.
# TS: side of the square thread block. The kernel's indexing and the
# shared-memory buffer (2 * TW * TW elements) assume the two are equal.
TW = 4
TS = 4
assert TW == TS, "tile width TW must equal the thread-block side TS"

# Number of blocks needed to cover the (a.shape[0], b.shape[1]) output.
block_size = Dim3(
    x=ceil(b.shape[1] / TS),
    y=ceil(a.shape[0] / TS),
)
thread_size = Dim3(x=TS, y=TS)


def tiled_kernel(block_dim: Dim3, thread_dim: Dim3, thread_size: Dim3, signalb, m, n, res, r, c, k, mem):
    """Emulate one CUDA thread of a shared-memory tiled matrix multiply.

    The thread adds element ``(cr, cc)`` of ``m @ n`` into ``res``. ``m`` is
    (r, k), ``n`` is (k, c) and ``res`` is (r, c); all three are flat
    row-major arrays.

    The tile width is the block side ``thread_size.x``, so the block must be
    square. In each phase, every thread copies ONE element of the current
    ``m`` tile and ONE element of the current ``n`` tile into ``mem``, which
    acts as the block's shared memory. It then waits on ``signalb`` (the
    ``__syncthreads()`` stand-in), accumulates its partial dot product, and
    waits again before the tiles are overwritten.

    Threads outside the output still load zeros and reach every barrier.
    Returning early would leave the rest of the block stuck in
    ``signalb.wait()``. Elements past the end of ``k`` are zero-padded, so
    ``k`` need not be a multiple of the tile width.

    Args:
        block_dim: block index (x = column block, y = row block).
        thread_dim: thread index inside the block.
        thread_size: block dimensions; ``x`` must equal ``y``.
        signalb: barrier shared by exactly the threads of this block.
        m, n: flat row-major operands.
        res: flat row-major output, accumulated into with ``+=``.
        r, c, k: rows of ``m``, columns of ``n``, shared inner dimension.
        mem: flat scratch buffer shared by the block, with at least
            ``2 * thread_size.x ** 2`` elements. The first half holds the
            ``m`` tile and the second half holds the ``n`` tile.

    Returns:
        None; the result is written into ``res``.
    """
    tile = thread_size.x
    assert thread_size.y == tile, "tiled_kernel needs a square thread block"

    tx, ty = thread_dim.x, thread_dim.y
    cr = block_dim.y * thread_size.y + ty  # global output row
    cc = block_dim.x * thread_size.x + tx  # global output column

    mm = mem[:tile * tile]                 # mm[ty * tile + i] = m[cr, ph * tile + i]
    mn = mem[tile * tile:2 * tile * tile]  # mn[i * tile + tx] = n[ph * tile + i, cc]

    acc = 0.0
    for ph in range(ceil(k / tile)):
        # Each thread loads one element of each tile; zero-pad past the edges.
        col_m = ph * tile + tx
        row_n = ph * tile + ty
        mm[ty * tile + tx] = m[cr * k + col_m] if cr < r and col_m < k else 0.0
        mn[ty * tile + tx] = n[row_n * c + cc] if row_n < k and cc < c else 0.0

        signalb.wait()  # both tiles are fully loaded
        for i in range(tile):
            acc += mm[ty * tile + i] * mn[i * tile + tx]
        signalb.wait()  # all reads done before the next phase overwrites the tiles

    if cr < r and cc < c:
        # The (r, c) output has row stride c.
        res[cr * c + cc] += acc
    return None

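### Debug

Replay phase 0 of block (0, 0) one thread at a time, with no threads or barriers, to check the tile loads and the partial products.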

In [149]:
# Flat scratch buffer standing in for shared memory. The first TW*TW entries
# hold the tile of `a`; the remaining TW*TW hold the tile of `b`.
mem = np.zeros(2 * TW * TW)
mm, mn = mem[:TW * TW], mem[TW * TW:]

In [159]:
# Flatten the operands the way the kernel receives them (row-major 1-D).
m = a.ravel()
n = b.ravel()
r, k = a.shape
c = b.shape[1]
res = np.zeros((r, c)).ravel()

# Replay phase 0 of block (0, 0) serially. First run every thread's tile
# loads, then every thread's partial dot product, mirroring the load/compute
# split of the kernel's phase loop.
bd = Dim3(x=0, y=0)
ts = Dim3(x=TS, y=TS)
by = bd.y * ts.y
bx = bd.x * ts.x
ph = 0

# Loads: mm row ty <- a[by + ty, phase cols]; mn row tx <- b[phase rows, bx + tx]
# (so mn holds the b tile transposed).
for ty in range(TS):
    cr = by + ty
    for tx in range(TS):
        cc = bx + tx
        for tw in range(TW):
            mm[ty * TW + tw] = m[ph * TW + tw + cr * k]
            mn[tx * TW + tw] = n[(ph * TW + tw) * c + cc]

# Partial products; the (r, c) result has row stride c.
for ty in range(TS):
    cr = by + ty
    for tx in range(TS):
        cc = bx + tx
        for tw in range(TW):
            res[cr * c + cc] += mm[ty * TW + tw] * mn[tx * TW + tw]


In [161]:
# NumPy reference for the phase-0 partial tile product. mn holds the b tile
# transposed (row tx is column bx + tx of b), hence the .T.
mm.reshape(TW, TW) @ mn.reshape(TW, TW).T

array([[1.06554023, 1.11686943, 1.63775982, 0.6493336 ],
       [1.10830571, 0.80540097, 1.42044116, 0.4881563 ],
       [1.19189501, 1.26205455, 1.44170766, 0.57801085],
       [0.8931771 , 0.41487311, 0.93973794, 0.31899231]])

In [162]:
# Top-left TW x TW block of the replayed result; should equal the cell above.
res.reshape(r, c)[:TW, :TW]

array([[1.06554023, 1.11686943, 1.63775982, 0.6493336 ],
       [1.10830571, 0.80540097, 1.42044116, 0.4881563 ],
       [1.19189501, 1.26205455, 1.44170766, 0.57801085],
       [0.8931771 , 0.41487311, 0.93973794, 0.31899231]])

In [154]:
# The a tile as loaded: row ty holds a[by + ty, ph*TW : ph*TW + TW].
mm.reshape(TW, TW)

array([[0.95283869, 0.19909693, 0.61565322, 0.45036136],
       [0.71062673, 0.22816761, 0.77238406, 0.09294921],
       [0.20585563, 0.93649271, 0.38760235, 0.48786585],
       [0.21317821, 0.15344169, 0.75795885, 0.04850594]])

In [153]:
# The b tile as loaded, stored transposed: row tx holds b[ph*TW : ph*TW + TW, bx + tx].
mn.reshape(TW, TW)

array([[0.27989922, 0.73731592, 0.93999189, 0.16283601],
       [0.60145317, 0.82840753, 0.17167709, 0.60652322],
       [0.78299314, 0.75942617, 0.83356969, 0.50471555],
       [0.26046591, 0.23095293, 0.27438072, 0.41354838]])

In [155]:
# Source tile of a for block (0, 0), phase 0; should match mm above.
a[:TW, :TW]

array([[0.95283869, 0.19909693, 0.61565322, 0.45036136],
       [0.71062673, 0.22816761, 0.77238406, 0.09294921],
       [0.20585563, 0.93649271, 0.38760235, 0.48786585],
       [0.21317821, 0.15344169, 0.75795885, 0.04850594]])

In [156]:
# Source tile of b for block (0, 0), phase 0; should match mn above, transposed.
b[:TW, :TW]

array([[0.27989922, 0.60145317, 0.78299314, 0.26046591],
       [0.73731592, 0.82840753, 0.75942617, 0.23095293],
       [0.93999189, 0.17167709, 0.83356969, 0.27438072],
       [0.16283601, 0.60652322, 0.50471555, 0.41354838]])

In [145]:
# Row segment of a read in phase ph for output row cr (cr and ph are left over
# from the debug replay cell).
# NOTE(review): the saved output comes from In [145], which ran before the
# replay cell In [159], so it may not match a fresh Restart & Run All.
a[cr, ph * TW:ph * TW + TW]

array([0.99740093, 0.37985005, 0.32674019, 0.38896543])

In [146]:
# Column segment of b read in phase ph for output column cc (cc and ph are left
# over from the debug replay cell).
# NOTE(review): the saved output comes from In [146], which ran before the
# replay cell In [159], so it may not match a fresh Restart & Run All.
b[ph * TW: ph * TW + TW, cc]

array([0.10424338, 0.629821  , 0.29719875, 0.98563026])

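## Example with a thread barrier

`threading.Barrier` blocks each thread in `wait()` until all parties have arrived. It is the CPU stand-in for CUDA's `__syncthreads()`, which the kernel above emulates with `signalb.wait()`.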

In [92]:
import string
import time

from itertools import cycle, islice
from concurrent.futures import ThreadPoolExecutor
from threading import Barrier

def print_thread_id(letter, buffer, signalb, delay=1.0):
    """Store ``letter`` in a buffer slot chosen by a barrier, then print the line.

    The index returned by ``signalb.wait()`` is unique per thread within a
    barrier generation, and it selects the slot of ``buffer`` to fill. A
    second wait guarantees every slot is written before the thread that got
    index 0 prints the assembled line.

    Args:
        letter: character to store.
        buffer: list shared by all threads, with one slot per barrier party.
            Assumes ``len(buffer) == signalb.parties``.
        signalb: ``threading.Barrier`` shared by the cooperating threads.
        delay: seconds to sleep after the round (default 1, as before).
    """
    idx = signalb.wait()
    buffer[idx] = letter
    signalb.wait()  # every slot has been written once all threads pass here
    if idx == 0:
        print("".join(buffer))
    time.sleep(delay)

# Run 10 rounds of 16 tasks on a pool of exactly 16 workers. The barrier makes
# each round assemble one 16-letter line, which the index-0 thread prints.
_nw = 16
signalb = Barrier(_nw)
buffer = _nw * [""]
letters = islice(cycle(string.ascii_letters), 10 * _nw)
with ThreadPoolExecutor(max_workers=_nw) as executor:
    for letter in letters:
        executor.submit(print_thread_id, letter, buffer, signalb)

abcdefghijklmnop
qrstuvwxyzABCDEF
GHIJKLMNOPQRSTUV
WXYZabcdefghijkl
mnopqrstuvwxyzAB
CDEFGHIJKLMNOPQR
STUVWXYZabcdefgh
ijklmnopqrstuvwx
yzABCDEFGHIJKLMN
OPQRSTUVWXYZabcd


## Runner

In [176]:
def threaded_2d_runner(block_size, thread_size, fn, *args):
    """Emulate a 2-D CUDA launch of ``fn`` with CPU threads.

    Blocks run one after another. Inside a block, every ``(tx, ty)`` thread
    runs concurrently on its own worker, and all of them share one fresh
    ``Barrier`` (the stand-in for ``__syncthreads()``). ``fn`` is called as
    ``fn(Dim3(x=bx, y=by), Dim3(x=tx, y=ty), thread_size, barrier, *args)``.

    An exception in any thread is re-raised here instead of being silently
    kept in its future. The failing thread also aborts the barrier, so its
    siblings get ``BrokenBarrierError`` instead of waiting forever.
    """
    from threading import BrokenBarrierError

    n_threads = thread_size.x * thread_size.y

    def _run_thread(block_idx, thread_idx, barrier):
        # Break the barrier on failure so the rest of the block cannot hang.
        try:
            return fn(block_idx, thread_idx, thread_size, barrier, *args)
        except BaseException:
            barrier.abort()
            raise

    for by in range(block_size.y):
        for bx in range(block_size.x):
            # One barrier per block, with one party per thread in the block.
            signalb = Barrier(n_threads)
            with ThreadPoolExecutor(max_workers=n_threads) as executor:
                futures = [
                    executor.submit(_run_thread, Dim3(x=bx, y=by), Dim3(x=tx, y=ty), signalb)
                    for ty in range(thread_size.y)
                    for tx in range(thread_size.x)
                ]
            # Leaving the `with` block waited for every thread of this block.
            errors = [exc for exc in (f.exception() for f in futures) if exc is not None]
            if errors:
                # Report the root cause, not the BrokenBarrierError it caused in siblings.
                raise next((e for e in errors if not isinstance(e, BrokenBarrierError)), errors[0])

## Run Test

In [178]:
# Problem sizes from the operands; the inner dimensions must agree.
ar, ak = a.shape
bk, bc = b.shape
assert ak == bk, f"inner dimensions differ: a is {a.shape}, b is {b.shape}"

res_ab = np.zeros((ar, bc))   # output of a @ b, written through a flat view
mem = np.zeros((2 * TW, TW))  # emulated shared memory: room for two TW x TW tiles

In [179]:
# The kernel accumulates into res_ab with +=, so reset it first; this keeps the
# cell safe to re-run. res_ab.ravel() is a view of the C-contiguous array,
# so the kernel's writes land in res_ab.
res_ab.fill(0.0)
threaded_2d_runner(
    block_size, thread_size, tiled_kernel,
    a.flatten(), b.flatten(), res_ab.ravel(), ar, bc, ak, mem.ravel(),
)

In [180]:
# Compare the emulated tiled kernel against NumPy's reference product; on
# failure, report how far off it is.
assert np.allclose(ab, res_ab), f"tiled kernel mismatch, max abs error {np.abs(ab - res_ab).max():.3e}"

## Translate to CUDA

We will reuse parts of the previous example.

## Numba

The Numba Python library lets you write CUDA kernels in Python, with some extra tools.