In [34]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import Float32, Int32

"""def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    z = torch.empty_like(x, device='cuda')
    add_kernel[(triton.cdiv(x.shape[0], 1024),)](x, z, x.shape[0], 1024)
    return z"""

def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    "This is the spec that you should implement. Uses typing to define sizes."
    return x + 10.

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid = tl.program_id(0)
    
    range = tl.arange(0, B0)
    x = tl.load(x_ptr + pid * B0 + range, mask=range < N0)
    z = x + 10
    tl.store(z_ptr + pid * B0 + range, z, mask=range < N0)

x = torch.zeros((10), device='cuda')
for i in range(10):
    x[i] = i
add_spec(x)

tensor([10., 11., 12., 13., 14., 15., 16., 17., 18., 19.], device='cuda:0')

In [35]:
test(add_kernel, add_spec, nelem={"N0": 32})

x: jaxtyping.Float32[Tensor, '32']
Results match: True
Correct!


In [21]:
import inspect

def test(puzzle, puzzle_spec, nelem={}, B={"B0": 32}):
    B = dict(B)
    if "N1" in nelem and "B1" not in B:
        B["B1"] = 32
    if "N2" in nelem and "B2" not in B:
        B["B2"] = 32

    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)
    args = {}
    for n, p in signature.parameters.items():
        print(p)
        args[n + "_ptr"] = ([d.size for d in p.annotation.dims], p)
    args["z_ptr"] = ([d.size for d in signature.return_annotation.dims], None)

    tt_args = []
    for k, (v, t) in args.items():
        tt_args.append(torch.rand(*v, device='cuda') - 0.5)
        if t is not None and t.annotation.dtypes[0] == "int32":
            tt_args[-1] = torch.randint(-100000, 100000, v)
    grid = lambda meta: (triton.cdiv(nelem["N0"], meta["B0"]),
                         triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)),
                         triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))

    #for k, v in args.items():
    #    print(k, v)
    puzzle[grid](*tt_args, **B, **nelem)
    z = tt_args[-1]
    tt_args = tt_args[:-1]
    z_ = puzzle_spec(*tt_args)
    match = torch.allclose(z, z_, rtol=1e-3, atol=1e-3)
    print("Results match:",  match)
    failures = False
    if not match or failures:
        print("Invalid Access:", failures)
        print("Yours:", z)
        print("Spec:", z_)
        print(torch.isclose(z, z_))
        return
    # PUPPIES!
    from IPython.display import HTML
    import random
    print("Correct!")
    pups = [
    "2m78jPG",
    "pn1e9TO",
    "MQCIwzT",
    "udLK6FS",
    "ZNem5o3",
    "DS2IZ6K",
    "aydRUz8",
    "MVUdQYK",
    "kLvno0p",
    "wScLiVz",
    "Z0TII8i",
    "F1SChho",
    "9hRi2jN",
    "lvzRF3W",
    "fqHxOGI",
    "1xeUYme",
    "6tVqKyM",
    "CCxZ6Wr",
    "lMW0OPQ",
    "wHVpHVG",
    "Wj2PGRl",
    "HlaTE8H",
    "k5jALH0",
    "3V37Hqr",
    "Eq2uMTA",
    "Vy9JShx",
    "g9I2ZmK",
    "Nu4RH7f",
    "sWp0Dqd",
    "bRKfspn",
    "qawCMl5",
    "2F6j2B4",
    "fiJxCVA",
    "pCAIlxD",
    "zJx2skh",
    "2Gdl1u7",
    "aJJAY4c",
    "ros6RLC",
    "DKLBJh7",
    "eyxH0Wc",
    "rJEkEw4"]
    return HTML("""
    <video alt="test" controls autoplay=1>
        <source src="https://openpuppies.com/mp4/%s.mp4"  type="video/mp4">
    </video>
    """%(random.sample(pups, 1)[0]))

In [15]:
i_range = torch.arange(0, 8)[:, None]
i_range

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7]])

In [16]:
j_range = torch.arange(0, 4)[None, :]
j_range

tensor([[0, 1, 2, 3]])

In [17]:
i_range * 4 + j_range

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31]])

In [54]:
def add_vec_spec(x: Float32[Tensor, "32"], y: Float32[Tensor, "32"]) -> Float32[Tensor, "32 32"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, B0))
    y = tl.load(y_ptr + tl.arange(0, B1))

    z = x[None, :] + y[:, None]

    tl.store(z_ptr + tl.arange(0, B1)[:, None] * B0 + tl.arange(0, B0)[None, :], z)

test(add_vec_kernel, add_vec_spec, nelem={"N0": 32, "N1": 32})

x: jaxtyping.Float32[Tensor, '32']
y: jaxtyping.Float32[Tensor, '32']
Results match: True
Correct!


In [51]:
torch.arange(0, 2)[:, None] * 4 + torch.arange(0, 4)[None, :]

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

In [41]:
x = torch.arange(4, device='cuda')
y = torch.arange(2, device='cuda')
print(x)
print(y)
add_vec_spec(x, y)

tensor([0, 1, 2, 3], device='cuda:0')
tensor([0, 1], device='cuda:0')


tensor([[0, 1, 2, 3],
        [1, 2, 3, 4]], device='cuda:0')

In [45]:
z = torch.zeros((2,4), device='cuda')
add_vec_kernel[(1,1,1,)](x, y, z, 4,2,4,2)
z

tensor([[0., 1., 2., 3.],
        [4., 0., 0., 0.]], device='cuda:0')

## Puzzle 4: Outer Vector Add Block

In [68]:
def add_vec_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    
    n_x_ptr = x_ptr + pid_0 * B0
    n_y_ptr = y_ptr + pid_1 * B1
    n_z_ptr = z_ptr + pid_0 * N0 * B1 + pid_1 * B1

    x_rg = tl.arange(0, B0)
    y_rg = tl.arange(0, B1)
    x = tl.load(x_ptr + x_rg)#, mask=x_rg < N0)
    y = tl.load(y_ptr + y_rg)#, mask=y_rg < N1)

    z = x[None, :] + y[:, None]

    i_rng = tl.arange(0, B1)[:, None]
    j_rng = tl.arange(0, B0)[None, :]
    
    tl.store(n_z_ptr + i_rng * B0 + j_rng, z)#, mask=(i_rng < N1) & (j_rng < N0))

test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": 100, "N1": 90})

x: jaxtyping.Float32[Tensor, '100']
y: jaxtyping.Float32[Tensor, '90']
Results match: False
Invalid Access: False
Yours: tensor([[ 0.3713,  0.4889, -0.0028,  ..., -0.1533, -0.6451,  0.2701],
        [ 0.2759,  0.1267, -0.2550,  ...,  0.5607,  0.1790,  0.5842],
        [-0.0070,  0.6736, -0.1177,  ...,  0.2689, -0.5223, -0.5654],
        ...,
        [-0.4654, -0.0915, -0.4728,  ...,  0.3232,  0.1355,  0.2032],
        [ 0.2527, -0.4758,  0.3153,  ...,  0.3500, -0.2044,  0.0783],
        [ 0.3800, -0.1854, -0.0115,  ...,  0.2143,  0.0270, -0.3052]],
       device='cuda:0')
Spec: tensor([[ 0.3713,  0.4889, -0.0028,  ...,  0.2145,  0.6114,  0.7605],
        [ 0.1900,  0.3076, -0.1841,  ...,  0.0332,  0.4301,  0.5793],
        [-0.1319, -0.0143, -0.5060,  ..., -0.2887,  0.1082,  0.2573],
        ...,
        [-0.2529, -0.1353, -0.6270,  ..., -0.4097, -0.0128,  0.1364],
        [ 0.2798,  0.3974, -0.0943,  ...,  0.1230,  0.5199,  0.6691],
        [-0.0330,  0.0846, -0.4071,  ..., -0.1898,  

In [70]:
B0 = 1
B1 = 1
for i in range(4):
    for j in range(3):
        print(i * B0 * B1 + j * B0)

0
1
2
1
2
3
2
3
4
3
4
5
