In [2]:
import triton
import torch
import triton.language as tl
from triton.runtime import driver

In [3]:
# Torch device backing Triton's currently active driver (here: cuda:0).
DEVICE = driver.active.get_active_torch_device()
DEVICE

device(type='cuda', index=0)

Shapes:
- x: (a × m)
- y: (m × n)
- z = x @ y: (a × n)

In [4]:
def matmul(x, y, z, x_row, y_col):
    """Naive reference matmul: fill ``z`` with ``x @ y`` using two Python loops.

    Each output element ``z[row, col]`` is the dot product of row ``row`` of ``x``
    and column ``col`` of ``y``. Every iteration dispatches separate tensor ops,
    so this is a slow, readable baseline rather than a fast implementation.

    Args:
        x: 2-D tensor with at least ``x_row`` rows and inner dimension m.
        y: 2-D tensor with m rows and at least ``y_col`` columns.
        z: preallocated output of shape (x_row, y_col); written in place.
        x_row: number of output rows to compute.
        y_col: number of output columns to compute.

    Returns:
        ``z`` (the same tensor object that was passed in), now holding the product.

    Raises:
        AssertionError: if ``z`` is not (x_row, y_col) or the inner dimensions
            of ``x`` and ``y`` differ.
    """
    assert z.shape[0] == x_row and z.shape[1] == y_col, "z must have shape (x_row, y_col)"
    # Without this check, an inner dimension of size 1 on either side would be
    # silently broadcast by the elementwise product below, producing wrong values.
    assert x.shape[1] == y.shape[0], "inner dimensions of x and y must match"
    for row in range(x_row):
        for col in range(y_col):
            z[row, col] = (x[row, :] * y[:, col]).sum()

    return z

In [5]:
# Small 3x3 example sizes shared by the cells below.
a, m, n = 3, 3, 3
# CPU tensors: x is (a, m), y is (m, n), z is the (a, n) output buffer.
x = torch.ones((a, m))
y = torch.ones((m, n))
z = torch.zeros((a, n))

# Time the loop-based reference matmul on the CPU.
%timeit matmul(x, y, z, a, n)

299 μs ± 76.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
# Same shapes and values as the CPU cell, but allocated on DEVICE.
x = torch.ones((a, m), device=DEVICE)
y = torch.ones((m, n), device=DEVICE)
z = torch.zeros((a, n), device=DEVICE)

# NOTE(review): CUDA ops run asynchronously and nothing here calls
# torch.cuda.synchronize(), so this likely measures host-side dispatch of the
# many tiny per-element ops more than actual GPU execution time.
%timeit matmul(x, y, z, a, n)

1.04 ms ± 119 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


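Why is it ~3.5x slower on the GPU (1.04 ms vs 299 μs)? The two Python for-loops still run on the CPU. Every iteration issues several tensor ops (slicing, multiply, sum, and the element write into `z`), each with its own host-side dispatch. The arithmetic ones also launch a separate CUDA kernel. Nothing in the loop copies a value back to the host, so the cost is this per-op launch overhead rather than a GPU/CPU sync. For a 3×3 problem that overhead dwarfs the actual arithmetic.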

Idea for the matmul kernel:
* launch a grid of (x_rows, y_cols) programs — one per output element — which replaces the two Python for-loops
* each program loads one row of x and one column of y and reduces their elementwise product to a single element of z

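Better idea (not implemented below; the kernel below follows the one-element-per-program idea):
* tile x and y into blocks and load a whole tile at a time, so a single program computes a block of z (e.g. a quarter of it) instead of one element, reusing each loaded value across many outputs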

In [56]:
@triton.jit
def fused_matmul(x_ptr, y_ptr, z_ptr, x_rows, y_cols, common_dim, block_size: tl.constexpr):
    """Compute one element of z = x @ y per program instance.

    Intended to be launched on a 2-D grid of (x_rows, y_cols) programs: program
    (i, j) computes z[i, j] as the dot product of row i of x and column j of y,
    walking the shared dimension in chunks of ``block_size`` elements.

    The pointer arithmetic hard-codes contiguous row-major layouts:
    x is (x_rows, common_dim), y is (common_dim, y_cols), z is (x_rows, y_cols).
    ``x_rows`` is accepted for symmetry but not read. ``block_size`` must be a
    power of two (required by ``tl.arange``).
    """
    x_row = tl.program_id(0)
    y_col = tl.program_id(1)

    # Start of this program's row of x and column of y. These are loop-invariant,
    # so they are computed once instead of on every iteration.
    x_row_start = x_ptr + x_row * common_dim
    y_col_start = y_ptr + y_col

    accumulator = 0.0
    for k in range(0, common_dim, block_size):
        offsets = k + tl.arange(0, block_size)
        mask = offsets < common_dim
        # Masked-out lanes must read as 0: without `other` their values are
        # undefined and would be added into the sum whenever common_dim is not
        # a multiple of block_size.
        x_vals = tl.load(x_row_start + offsets, mask=mask, other=0.0)
        # Consecutive elements of a column are y_cols apart in row-major y.
        y_vals = tl.load(y_col_start + offsets * y_cols, mask=mask, other=0.0)
        accumulator += tl.sum(x_vals * y_vals, axis=0)

    tl.store(z_ptr + x_row * y_cols + y_col, accumulator)

In [57]:
def smart_mul(x, y, max_block_size=1024):
    """Multiply two 2-D tensors with the ``fused_matmul`` Triton kernel.

    Launches one program per output element on an (x_rows, y_cols) grid.

    Args:
        x: 2-D tensor of shape (x_rows, common_dim) on a Triton-capable device.
        y: 2-D tensor of shape (common_dim, y_cols) on the same device.
        max_block_size: power-of-two upper bound on how many shared-dimension
            elements each kernel loop iteration loads. The kernel loops over the
            shared dimension, so capping this keeps large ``common_dim`` values
            launchable (otherwise num_warps would exceed the hardware limit).

    Returns:
        A new tensor of shape (x_rows, y_cols) with torch's default float dtype,
        holding ``x @ y``.

    Raises:
        ValueError: if the inputs are not 2-D, their inner dimensions differ,
            or they live on different devices.
    """
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError("smart_mul expects 2-D tensors")
    x_rows, common_dim = x.shape
    if y.shape[0] != common_dim:
        raise ValueError(
            f"inner dimensions differ: x is {tuple(x.shape)}, y is {tuple(y.shape)}"
        )
    if x.device != y.device:
        raise ValueError(f"x is on {x.device} but y is on {y.device}")
    y_cols = y.shape[1]

    # The kernel indexes with hard-coded row-major strides, so transposed or
    # sliced views must be made contiguous first (no-op if already contiguous).
    x = x.contiguous()
    y = y.contiguous()

    z = torch.zeros((x_rows, y_cols), device=x.device)
    if z.numel() == 0 or common_dim == 0:
        # Nothing to launch: either an empty output or an empty sum (all zeros).
        return z

    # tl.arange needs a power-of-two length; cap it since the kernel already
    # walks the shared dimension in block_size chunks.
    block_size = min(triton.next_power_of_2(common_dim), max_block_size)

    # Warps per program instance (not threads per warp, which the hardware fixes).
    num_warps = max(1, block_size // 32)

    # The regular launch path compiles once per specialization and caches it,
    # so there is no need to call warmup() / _init_handles() on every call.
    fused_matmul[(x_rows, y_cols)](
        x, y, z, x_rows, y_cols, common_dim, block_size,
        num_warps=num_warps, num_stages=1,
    )
    return z

In [61]:
# Fresh all-ones GPU inputs for the Triton version, so every entry of x @ y equals m.
x, y = (torch.ones(shape, device=DEVICE) for shape in ((a, m), (m, n)))
x, y

(tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]], device='cuda:0'),
 tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]], device='cuda:0'))

In [62]:
# Time the Triton path end to end. NOTE(review): without torch.cuda.synchronize()
# in the timed statement, this measures host-side overhead of smart_mul plus
# kernel launch, not kernel completion on the GPU.
%timeit smart_mul(x, y)


90.6 μs ± 4.37 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


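Huge win: from ~300 μs (CPU loop) down to ~90 μs with the Triton kernel — yippee! (Caveat: `%timeit` without `torch.cuda.synchronize()` mostly measures host-side launch overhead, not how long the kernel actually takes on the GPU.)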

In [60]:
# NOTE(review): smart_mul allocates its own local z, so this displays the
# module-level z from an earlier cell (presumably the one filled by the loop
# matmul on the GPU), not the Triton kernel's output. To verify the kernel,
# make sure smart_mul returns its z, then compare that result to x @ y.
z

tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]], device='cuda:0')