In [6]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import Float32, Int32

"""def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    z = torch.empty_like(x, device='cuda')
    add_kernel[(triton.cdiv(x.shape[0], 1024),)](x, z, x.shape[0], 1024)
    return z"""

def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    "This is the spec that you should implement. Uses typing to define sizes."
    return x + 10.

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid = tl.program_id(0)
    
    range = tl.arange(0, B0)
    x = tl.load(x_ptr + pid * B0 + range, mask=range < N0)
    z = x + 10
    tl.store(z_ptr + pid * B0 + range, z, mask=range < N0)

x = torch.zeros((10), device='cuda')
for i in range(10):
    x[i] = i
add_spec(x)

tensor([10., 11., 12., 13., 14., 15., 16., 17., 18., 19.], device='cuda:0')

In [9]:
test(add_kernel, add_spec, nelem={"N0": 32})

x: jaxtyping.Float32[Tensor, '32']
x_ptr ([32], <Parameter "x: jaxtyping.Float32[Tensor, '32']">)
z_ptr ([32], None)
Results match: True
Correct!


In [138]:
import inspect

def test(puzzle, puzzle_spec, nelem={}, B={"B0": 32}):
    B = dict(B)
    if "N1" in nelem and "B1" not in B:
        B["B1"] = 32
    if "N2" in nelem and "B2" not in B:
        B["B2"] = 32

    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)
    args = {}
    for n, p in signature.parameters.items():
        print(p)
        args[n + "_ptr"] = ([d.size for d in p.annotation.dims], p)
    args["z_ptr"] = ([d.size for d in signature.return_annotation.dims], None)

    tt_args = []
    for k, (v, t) in args.items():
        tt_args.append(torch.rand(*v, device='cuda') - 0.5)
        if t is not None and t.annotation.dtypes[0] == "int32":
            tt_args[-1] = torch.randint(-100000, 100000, v)
    grid = lambda meta: (triton.cdiv(nelem["N0"], meta["B0"]),
                         triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)),
                         triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))

    for k, v in args.items():
        print(k, v)
    tt_args[-1] = torch.zeros(tt_args[-1].shape, device='cuda')
    puzzle[grid](*tt_args, **B, **nelem)
    z = tt_args[-1]
    tt_args = tt_args[:-1]
    z_ = puzzle_spec(*tt_args)

    print(z.shape, z_.shape)
    
    match = torch.allclose(z, z_, rtol=1e-3, atol=1e-3)
    print("Results match:",  match)
    failures = False
    if not match or failures:
        print("Invalid Access:", failures)
        print("Yours:", z)
        print("Spec:", z_)
        print(torch.isclose(z, z_))
        return z, z_, torch.isclose(z, z_)
    # PUPPIES!
    from IPython.display import HTML
    import random
    print("Correct!")
    pups = [
    "2m78jPG",
    "p, n1e9TO",
    "MQCIwzT",
    "udLK6FS",
    "ZNem5o3",
    "DS2IZ6K",
    "aydRUz8",
    "MVUdQYK",
    "kLvno0p",
    "wScLiVz",
    "Z0TII8i",
    "F1SChho",
    "9hRi2jN",
    "lvzRF3W",
    "fqHxOGI",
    "1xeUYme",
    "6tVqKyM",
    "CCxZ6Wr",
    "lMW0OPQ",
    "wHVpHVG",
    "Wj2PGRl",
    "HlaTE8H",
    "k5jALH0",
    "3V37Hqr",
    "Eq2uMTA",
    "Vy9JShx",
    "g9I2ZmK",
    "Nu4RH7f",
    "sWp0Dqd",
    "bRKfspn",
    "qawCMl5",
    "2F6j2B4",
    "fiJxCVA",
    "pCAIlxD",
    "zJx2skh",
    "2Gdl1u7",
    "aJJAY4c",
    "ros6RLC",
    "DKLBJh7",
    "eyxH0Wc",
    "rJEkEw4"]
    return HTML("""
    <video alt="test" controls autoplay=1>
        <source src="https://openpuppies.com/mp4/%s.mp4"  type="video/mp4">
    </video>
    """%(random.sample(pups, 1)[0]))

In [63]:
i_range = torch.arange(0, 8)[:, None]
i_range

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7]])

In [11]:
j_range = torch.arange(0, 4)[None, :]
j_range

tensor([[0, 1, 2, 3]])

In [12]:
i_range * 4 + j_range

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31]])

In [13]:
def add_vec_spec(x: Float32[Tensor, "32"], y: Float32[Tensor, "32"]) -> Float32[Tensor, "32 32"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, B0))
    y = tl.load(y_ptr + tl.arange(0, B1))

    z = x[None, :] + y[:, None]

    tl.store(z_ptr + tl.arange(0, B1)[:, None] * B0 + tl.arange(0, B0)[None, :], z)

test(add_vec_kernel, add_vec_spec, nelem={"N0": 32, "N1": 32})

x: jaxtyping.Float32[Tensor, '32']
y: jaxtyping.Float32[Tensor, '32']
x_ptr ([32], <Parameter "x: jaxtyping.Float32[Tensor, '32']">)
y_ptr ([32], <Parameter "y: jaxtyping.Float32[Tensor, '32']">)
z_ptr ([32, 32], None)
Results match: True
Correct!


In [14]:
torch.arange(0, 2)[:, None] * 4 + torch.arange(0, 4)[None, :]

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

In [15]:
x = torch.arange(4, device='cuda')
y = torch.arange(2, device='cuda')
print(x)
print(y)
add_vec_spec(x, y)

tensor([0, 1, 2, 3], device='cuda:0')
tensor([0, 1], device='cuda:0')


tensor([[0, 1, 2, 3],
        [1, 2, 3, 4]], device='cuda:0')

In [16]:
z = torch.zeros((2,4), device='cuda')
add_vec_kernel[(1,1,1,)](x, y, z, 4,2,4,2)
z

tensor([[0., 1., 2., 3.],
        [1., 2., 3., 4.]], device='cuda:0')

## Puzzle 4: Outer Vector Add Block

In [167]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import Float32, Int32

#100, 90
n = 64
m = 64
def add_vec_block_spec(x: Float32[Tensor, str(n)], y: Float32[Tensor, str(m)]) -> Float32[Tensor, str(m) + " " + str(n)]:
    return x[None, :] + y[:, None]

import os
os.environ["TRITON_INTERPRET"] = "1"

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    
    n_x_ptr = x_ptr + pid_0 * B0
    n_y_ptr = y_ptr + pid_1 * B1
    n_z_ptr = z_ptr + pid_1 * B1 * N0 + pid_0 * B0 * 1

    x_rg = tl.arange(0, B0)
    y_rg = tl.arange(0, B1)
    x = tl.load(x_ptr + B0 * pid_0 + x_rg, mask=B0 * pid_0 + x_rg < N0)
    y = tl.load(y_ptr + B1 * pid_1 + y_rg, mask=B1 * pid_1 + y_rg < N1)

    z = x[None, :] + y[:, None]

    i_rng = tl.arange(0, B1)[:, None]
    j_rng = tl.arange(0, B0)[None, :]

    i_mask_rng = (tl.arange(0, B1) + B1 * pid_1)[:, None]
    j_mask_rng = (tl.arange(0, B0) + B0 * pid_0)[None, :]

    #nums = tl.broadcast_to((tl.arange(0, 32))[:, None], (32, 32))

    gridblock = i_rng * B0 + j_rng
    
    tl.store(n_z_ptr + gridblock, z)#, mask=(i_mask_rng < N1) & (j_mask_rng < N0))
    
z, z_, allclose = test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": n, "N1": m})#, nelem={"N0": 100, "N1": 90})

#for i in range(m):
#    print(z[i])

x: jaxtyping.Float32[Tensor, '64']
y: jaxtyping.Float32[Tensor, '64']
x_ptr ([64], <Parameter "x: jaxtyping.Float32[Tensor, '64']">)
y_ptr ([64], <Parameter "y: jaxtyping.Float32[Tensor, '64']">)
z_ptr ([64, 64], None)
torch.Size([64, 64]) torch.Size([64, 64])
Results match: False
Invalid Access: False
Yours: tensor([[ 0.3713,  0.4889, -0.0028,  ...,  0.4874,  0.0382,  0.4938],
        [ 0.5404,  0.7213, -0.1596,  ...,  0.1490, -0.0902, -0.0299],
        [-0.2665, -0.1488, -0.6406,  ...,  0.0747, -0.3745,  0.0811],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
Spec: tensor([[ 0.3713,  0.4889, -0.0028,  ...,  0.6522,  0.4130,  0.4733],
        [ 0.1900,  0.3076, -0.1841,  ...,  0.4709,  0.2318,  0.2920],
        [-0.1319, -0.0143, -0.5060,  ...,  0.1490, -0.0902, -0.0299],
        ...,
  

In [137]:
print(z[0])
print(z[32])
print(z[64])

tensor([ 3.7129e-01,  4.8892e-01, -2.8284e-03,  9.1232e-01,  9.1810e-01,
         7.6897e-01,  3.8726e-01,  7.9250e-01,  2.0128e-01,  8.8183e-01,
         9.0578e-02,  4.7464e-02,  3.8146e-01,  9.3231e-01,  1.8154e-01,
         1.6619e-01,  8.6319e-01,  4.1091e-01,  3.2922e-01,  5.1761e-01,
         8.0216e-01,  1.8218e-01,  7.4066e-01,  4.0123e-01,  1.8391e-01,
         6.3280e-01,  1.3760e-01,  3.9723e-01,  9.6491e-01,  6.6867e-01,
         2.1943e-01,  6.7505e-01,  1.9002e-01,  3.0765e-01, -1.8410e-01,
         7.3105e-01,  7.3682e-01,  5.8770e-01,  2.0598e-01,  6.1123e-01,
         2.0013e-02,  7.0056e-01, -9.0694e-02, -1.3381e-01,  2.0019e-01,
         7.5104e-01,  2.7096e-04, -1.5079e-02,  6.8191e-01,  2.2964e-01,
         1.4795e-01,  3.3634e-01,  6.2089e-01,  9.0653e-04,  5.5939e-01,
         2.1996e-01,  2.6426e-03,  4.5152e-01, -4.3669e-02,  2.1596e-01,
         7.8364e-01,  4.8739e-01,  3.8154e-02,  4.9378e-01], device='cuda:0')
tensor([-0.4582, -0.3405, -0.8323,  0.0829,  0

IndexError: index 64 is out of bounds for dimension 0 with size 64

In [159]:
N0 = 64
N1 = 64
B0 = 32
B1 = 32
pid_0 = 1
pid_1 = 1

x_rg = torch.arange(0, B0)
y_rg = torch.arange(0, B1)

B0 * pid_0 + x_rg

tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

In [160]:
B1 * pid_1 + y_rg

tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

In [161]:
pid_1 * B1 * N0 + pid_0 * B0 * 1

2080

In [166]:
i_rng = torch.arange(0, B1)[:, None]
j_rng = torch.arange(0, B0)[None, :]

i_rng * B0 + j_rng

tensor([[   0,    1,    2,  ...,   29,   30,   31],
        [  32,   33,   34,  ...,   61,   62,   63],
        [  64,   65,   66,  ...,   93,   94,   95],
        ...,
        [ 928,  929,  930,  ...,  957,  958,  959],
        [ 960,  961,  962,  ...,  989,  990,  991],
        [ 992,  993,  994,  ..., 1021, 1022, 1023]])

In [132]:
for i in range(65, 90):
    print(i, z[i])

IndexError: index 65 is out of bounds for dimension 0 with size 64

In [119]:
B1 = 32
pid_1 = 2
(torch.arange(0, B1) + B1 * pid_1)[:, None] < 90

tensor([[ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False]])

In [None]:
B1 = 32
pid_1 = 2
(torch.arange(0, B1) + B1 * pid_1)[:, None] < 90

In [110]:
torch.broadcast_to((torch.arange(0, 32))[:, None], (32, 32))

tensor([[ 0,  0,  0,  ...,  0,  0,  0],
        [ 1,  1,  1,  ...,  1,  1,  1],
        [ 2,  2,  2,  ...,  2,  2,  2],
        ...,
        [29, 29, 29,  ..., 29, 29, 29],
        [30, 30, 30,  ..., 30, 30, 30],
        [31, 31, 31,  ..., 31, 31, 31]])

In [87]:
B0 = 1
B1 = 2
N0 = 4
N1 = 3
for i in range(triton.cdiv(N0, B0)):
    for j in range(triton.cdiv(N1, B1)):
        idx = i * N1 + j * B1
        print(idx, idx // N1, idx % N1)

0 0 0
2 0 2
3 1 0
5 1 2
6 2 0
8 2 2
9 3 0
11 3 2


In [None]:
def add_vec_block_kernel(pid_0, pid_1, x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    n_x_ptr = x_ptr + pid_0 * B0
    n_y_ptr = y_ptr + pid_1 * B1
    n_z_ptr = z_ptr + pid_0 * N1 + pid_1 * B1

    x_rg = torch.arange(0, B0)
    y_rg = torch.arange(0, B1)
    x = torch.load(x_ptr + x_rg)#, mask=x_rg < N0)
    y = torch.load(y_ptr + y_rg)#, mask=y_rg < N1)

    z = x[None, :] + y[:, None]

    i_rng = torch.arange(0, B1)[:, None]
    j_rng = torch.arange(0, B0)[None, :]
    
    tl.store(n_z_ptr + i_rng * B0 + j_rng, z)#, mask=(i_rng < N1) & (j_rng < N0))