In [2]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import *

In [146]:
@triton.jit
def matrix_mult_race_cond(A, B, C, M, K, N, BLOCK_SIZE_M : tl.constexpr, BLOCK_SIZE_K : tl.constexpr, BLOCK_SIZE_N : tl.constexpr):
    """
    C := A x B
    """
    pid_m = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)
    pid_n = tl.program_id(axis=2)

    # tl.store(C + BLOCK_SIZE_M * pid_m * N + pid_n * BLOCK_SIZE_N, pid_m)

    # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # current block m-wise
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    # current block n-wise
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # current block k-wise
    offs_k = (pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K

    # get m x k block in A
    a_ptrs = A + (offs_am[:, None] * K + offs_k[None, :] * 1)
    # get k x n block in B
    b_ptrs = B + (offs_k[:, None] * N + offs_bn[None, :] * 1)
    
    # Load the next block of A and B, generate a mask by checking the K dimension.
    # If it is out of bounds, set it to 0.
    # K - pid_k * BLOCK_SIZE_K
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0.0)

    c_start = C + BLOCK_SIZE_M * pid_m * N + pid_n * BLOCK_SIZE_N * 1

    this_block = c_start + tl.arange(0, BLOCK_SIZE_M)[:, None] * N + tl.arange(0, BLOCK_SIZE_N)[None, :] * 1

    # We accumulate along the K dimension.
    ans = tl.dot(a, b)
    
    tl.store(this_block, tl.load(this_block) + ans)

In [137]:
M = 64
N = 64
K = 32
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c = torch.zeros((M, N), device="cuda", dtype=torch.float16)

In [138]:
from math import ceil

In [139]:
BLOCK_SIZE_M = 16
BLOCK_SIZE_N = 16
BLOCK_SIZE_K = 16
matrix_mult[(ceil(M / BLOCK_SIZE_M), ceil(K / BLOCK_SIZE_K), ceil(N / BLOCK_SIZE_N))]\
    (a, b, c, M, K, N, BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N)

<triton.compiler.compiler.CompiledKernel at 0x7fbf1becb470>

In [140]:
c

tensor([[-10.3750,  -7.9922,   4.4023,  ...,   2.0254,   1.2500,   2.3535],
        [  9.4609,   7.5508,  -3.8223,  ...,  -0.6382,   3.8281,  -6.7539],
        [  1.5332,   1.6641,   1.4512,  ...,   3.9746,  -0.7754,  -1.7549],
        ...,
        [ -2.2402,  -0.9053,  -2.7188,  ...,  -1.3975,  -4.0938,   2.5879],
        [ -4.1914,  -2.3789,   3.7031,  ...,   3.0859,  -0.9536,  -0.1411],
        [  3.0781,   1.5752,   2.9355,  ...,   0.9121,   0.5483,   1.2021]],
       device='cuda:0', dtype=torch.float16)

In [141]:
torch.set_printoptions(profile="full")
print(c) # prints the whole tensor
torch.set_printoptions(profile="default") # reset
print(c) # prints the truncated tensor

tensor([[-1.0375e+01, -7.9922e+00,  4.4023e+00, -9.6863e-02, -5.2295e-01,
         -4.2310e-01,  2.0508e+00,  1.7041e-01, -1.1047e+01,  4.2617e+00,
         -9.3906e+00,  5.4150e-01, -3.8892e-01,  2.2148e+00,  6.7810e-02,
         -3.4648e+00, -4.3750e+00,  2.4268e-01,  3.5234e+00,  2.4648e+00,
         -6.3008e+00,  7.6953e+00,  9.5947e-01, -1.8799e+00, -3.6680e+00,
         -7.0312e-01,  2.7773e+00, -4.7485e-01,  1.9092e+00, -1.8086e+00,
          7.1211e+00,  4.0000e+00,  1.0695e+01, -5.5391e+00, -6.5625e-01,
          7.6123e-01,  4.5742e+00,  5.1992e+00,  5.1680e+00,  7.7852e+00,
         -6.8555e+00, -5.3516e+00, -2.6172e+00,  2.4746e+00, -2.3184e+00,
         -5.6211e+00,  2.6523e+00,  1.1855e+00,  6.1797e+00, -8.9355e-01,
         -5.8203e-01,  5.3076e-01,  1.1328e+00, -2.0664e+00, -3.1396e-01,
          3.3242e+00,  7.1143e-01,  3.6602e+00,  1.1539e+01, -3.2695e+00,
         -1.4629e+00,  2.0254e+00,  1.2500e+00,  2.3535e+00],
        [ 9.4609e+00,  7.5508e+00, -3.8223e+00, -1

In [142]:
torch.allclose(c, torch.matmul(a, b))

False

In [143]:
torch.matmul(a, b)[:5] - (c)[:5]

tensor([[  0.5625,   2.4805,  -5.6328,  -8.6328,   4.5430,   1.4785,  -1.8965,
          -0.5020,   8.2344,   0.4062,  -6.4609,  -2.2207,  -2.4844,  -3.2109,
           1.2588,  -5.5273,   0.8906,  -2.1719,   1.9531,   4.3984,   3.3105,
          -4.0234,   0.8892,  -1.1709,   3.3203,   6.6289,   0.4062,  -0.1863,
          -1.6172,  -1.5410,   4.8945,  -1.7676,   5.9297,   2.0977,   5.3594,
           0.8647,   2.3164,  -7.7344,   4.9180,  -2.7031,  -7.6914,  -2.0586,
           1.2891,   3.2363,   0.9893,   5.4336,   1.1445,   6.2812,  -0.4453,
           4.7734,   0.1580,  -9.9297,   4.5586,   1.0879,   3.9238,   4.5195,
          -4.1211,  -0.7070,   6.9297,   3.5117,  -1.3281,  -4.7266,   8.7422,
          -6.5078],
        [  0.5156,   0.4883,   0.4355,   4.4375,   1.6328,   0.0508,  -3.5508,
          -6.3516,   2.0078,   4.2227,   5.4766,   0.9141,  -4.7500,  -1.2734,
          -1.1279,   4.7930,  -3.2207,  -1.2461,   3.7891,   3.2266,  -2.9141,
           3.2285,  -6.2266,  -1

In [144]:
torch.matmul(a, b)

tensor([[-9.8125, -5.5117, -1.2324,  ..., -2.7031,  9.9922, -4.1523],
        [ 9.9766,  8.0391, -3.3867,  ...,  0.3862,  5.9609, -5.4844],
        [ 2.2344,  1.5381, -2.3457,  ..., -4.0234,  5.0820, -2.9922],
        ...,
        [-4.4844,  1.8584, -9.3984,  ..., -5.2344, -7.8867, -0.8086],
        [-1.5791, -6.6406,  1.6533,  ..., 13.1094,  6.0352,  0.3352],
        [-0.9771,  0.1312,  5.2539,  ..., -0.2913,  9.3672,  0.9204]],
       device='cuda:0', dtype=torch.float16)

In [165]:
@triton.jit
def matrix_mult(A, B, C, M, K, N, BLOCK_SIZE_M : tl.constexpr, BLOCK_SIZE_K : tl.constexpr, BLOCK_SIZE_N : tl.constexpr):
    """
    C := A x B
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # tl.store(C + BLOCK_SIZE_M * pid_m * N + pid_n * BLOCK_SIZE_N, pid_m)

    # current block m-wise
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    # current block n-wise
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    
    # Load the next block of A and B, generate a mask by checking the K dimension.
    # If it is out of bounds, set it to 0.
    # K - pid_k * BLOCK_SIZE_K

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # We accumulate along the K dimension.
    for pid_k in range(0, (K + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K):
        # current block k-wise
        offs_k = (pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K

        # get m x k block in A
        a_ptrs = A + (offs_am[:, None] * K + offs_k[None, :] * 1)
        # get k x n block in B
        b_ptrs = B + (offs_k[:, None] * N + offs_bn[None, :] * 1)
        
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0.0)
        # We accumulate along the K dimension.
        acc += tl.dot(a, b)

    c_start = C + BLOCK_SIZE_M * pid_m * N + pid_n * BLOCK_SIZE_N * 1

    this_block = c_start + tl.arange(0, BLOCK_SIZE_M)[:, None] * N + tl.arange(0, BLOCK_SIZE_N)[None, :] * 1

    acc = acc.to(tl.float16)
    tl.store(this_block, acc)

In [166]:
M = 64
N = 64
K = 32

a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c = torch.zeros((M, N), device="cuda", dtype=torch.float16)

BLOCK_SIZE_M = 16
BLOCK_SIZE_N = 16
BLOCK_SIZE_K = 16
matrix_mult[(ceil(M / BLOCK_SIZE_M), ceil(N / BLOCK_SIZE_N))]\
    (a, b, c, M, K, N, BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N)

torch.allclose(c, torch.matmul(a, b))

False

In [167]:
torch.matmul(a, b)[10:] - (c)[10:]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [168]:
torch.sum(torch.matmul(a, b) - c)

tensor(0.0006, device='cuda:0', dtype=torch.float16)

In [169]:
torch.matmul(a, b)

tensor([[-5.7031, -1.2734, -0.7969,  ...,  1.4121, -2.0078,  0.3867],
        [ 3.4062, -2.0176,  4.2148,  ...,  0.4951, -2.7363,  3.5020],
        [ 3.3340,  1.5645, -3.2871,  ..., -0.1345, -0.8730,  2.8906],
        ...,
        [-3.0469,  1.6611, -2.2480,  ..., -1.0303,  0.9399,  2.1172],
        [-0.9287, -3.5898,  1.2148,  ..., -0.4834, -0.7793,  2.5137],
        [ 7.0547,  5.0469,  1.6035,  ..., -8.1094,  4.8906, -6.0781]],
       device='cuda:0', dtype=torch.float16)