# CUDA for DUM DUMS

openai/triton seems to be the new hotness for writing efficient GPU kernels

I tried reading the paper and barely understood a single word: prefetching, SRAM, linearization, blocks, warps, scheduling blocks on warps and torps on warps

In [2]:
import torch
import triton
import triton.language as tl

def add(x: torch.Tensor, y: torch.Tensor):
    """Return ``x + y`` computed element-wise by the Triton ``add_kernel``.

    Args:
        x: First input tensor.
        y: Second input tensor. Must have the same shape as ``x`` and live on
            the same device.

    Returns:
        A new tensor shaped like ``x`` that receives ``x + y``. The kernel is
        launched asynchronously, so the values may not be ready yet when this
        function returns (see the comment above ``return``).

    Raises:
        ValueError: If ``x`` and ``y`` differ in shape or device.
    """
    # add_kernel takes its element count from the output (i.e. from x) and
    # reads that many entries from *both* inputs, so a mismatched y would be
    # read out of bounds.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    if x.device != y.device:
        raise ValueError(
            f"x and y must be on the same device, got {x.device} and {y.device}"
        )
    # The kernel addresses every tensor as one flat buffer (`ptr + offsets`),
    # which only matches logical element order for contiguous tensors.
    # `.contiguous()` returns the tensor itself when it already is contiguous.
    x = x.contiguous()
    y = y.contiguous()

    # We need to preallocate the output
    output = torch.empty_like(x)
    n_elements = output.numel()
    # Nothing to compute: skip the launch rather than request a grid of zero
    # program instances.
    if n_elements == 0:
        return output

    # `grid` returns the launch grid: a 1-tuple holding the number of program
    # instances, ceil(n_elements / BLOCK_SIZE), so every element falls in
    # exactly one BLOCK_SIZE-wide chunk. It is a callable so that Triton can
    # pass in `meta`, the launch's meta-parameters (here BLOCK_SIZE=1024).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # Triton hands each torch tensor argument to the kernel as a pointer to its
    # data, which is why the kernel's parameters are named x_ptr, y_ptr, ...
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to output but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output

In [3]:
from rich import inspect

In [7]:
# Check what kind of object `add` is: it is not decorated, so this shows a
# plain Python function (output below).
type(add)

function

This is easy enough, and if we run this function through a debugger we'll know everything we need to know. `numel()` just returns the number of elements. We need an output of the same size as the input because that's what vector addition is.

Where things get confusing is the `add_kernel[]()` part, so it's a dictionary? That I'm indexing with a `grid` and that `grid` is a `lambda` function that divides the `n_elements` by a block size?

So let's step through this in a debugger by running something. Whoops, turns out we can't actually trigger a breakpoint inside of `add_kernel`, so what the heck is going on? What is a program? What's a block start? Offset? Masks. So confused :O


In [8]:
# Triton kernel computing output = x + y element-wise. Each program instance
# handles one contiguous chunk of BLOCK_SIZE element indices; indices at or
# past n_elements are masked off so they are neither loaded nor stored.
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
                 # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'program's processing different data. We identify which program
    # we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and a BLOCK_SIZE of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` is a block of integer element indices, not pointers;
    # adding it to a base pointer (e.g. `x_ptr + offsets` below) produces the
    # block of pointers that is actually loaded from / stored to.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size. Masked-off lanes are never written back,
    # because the store below uses the same mask.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)

In [9]:
# Unlike `add`, `add_kernel` is no longer a plain function: @triton.jit
# replaced it with a Triton JIT wrapper object (see output below).
type(add_kernel)

triton.code_gen.JITFunction

In [11]:
# `inspect` here is rich's object inspector (imported from `rich` above), not
# the stdlib module; methods=True asks it to list the object's methods too.
inspect(add_kernel, methods=True)

Ok, so what's a `JITFunction`?

It seems to have a few interesting properties

Some default `arg_names`, referring to the inputs `x_ptr` and `y_ptr`, the `output_ptr`, the number of elements `n_elements`, and the number of elements each program instance processes, i.e. `BLOCK_SIZE`

The function can be parsed

In [12]:
# JITFunction.parse() gives back the kernel's source as a Python `ast.Module`
# (see the output below and the ast.dump in the next cell).
add_kernel.parse()

<ast.Module at 0x7fc5a9e18430>

In [13]:
import ast

# Pretty-print the kernel's AST; indent=4 puts each nested node on its own
# indented lines so the structure is readable.
print(ast.dump(add_kernel.parse(), indent=4))

Module(
    body=[
        FunctionDef(
            name='add_kernel',
            args=arguments(
                posonlyargs=[],
                args=[
                    arg(arg='x_ptr'),
                    arg(arg='y_ptr'),
                    arg(arg='output_ptr'),
                    arg(arg='n_elements'),
                    arg(
                        arg='BLOCK_SIZE',
                        annotation=Attribute(
                            value=Name(id='tl', ctx=Load()),
                            attr='constexpr',
                            ctx=Load()))],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[]),
            body=[
                Assign(
                    targets=[
                        Name(id='pid', ctx=Store())],
                    value=Call(
                        func=Attribute(
                            value=Name(id='tl', ctx=Load()),
                            attr='program_id',
                     

The AST we just printed is likely then optimized by some core engine but we'll get back to this in a second

In [15]:
# Try calling `.fn` (presumably the original, undecorated Python function the
# JITFunction wraps) directly with dummy arguments. As the output shows, this
# fails: tl.* builtins such as tl.program_id need a `_builder`, which is only
# supplied when the kernel goes through the JIT.
add_kernel.fn(5,5,5,5,1)

ValueError: Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)

In [16]:
# We still would like to understand what a program ID is, so let's just try
# calling tl.program_id directly, outside any @triton.jit function.
# This raises the same missing-`_builder` ValueError (see output below).

pid = tl.program_id(axis=0)

ValueError: Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)

Well, that's kinda sad. It feels like we're not going to be able to run this program, so let's at least try to understand the type of `pid` by looking at the source code

Just remember that `_builder` in the error, we'll get back to it

In [38]:
from inspect import getsource
import pprint
# Pretty-printer reused by later cells to show source strings returned by
# getsource(). Note that pprint renders a long string as a parenthesised run
# of string pieces (see output); plain print() would show it as normal source.
pp = pprint.PrettyPrinter(indent=4)

# tl.program_id can't be called outside the JIT, but its Python source can
# still be read.
pp.pprint(getsource(tl.program_id))

('@builtin\n'
 'def program_id(axis, _builder=None):\n'
 '    """\n'
 '    Returns the id of the current program instance along the given '
 ':code:`axis`.\n'
 '\n'
 '    :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or '
 '2.\n'
 '    :type axis: int\n'
 '    """\n'
 '    # if axis == -1:\n'
 '    #     pid0 = program_id(0, _builder)\n'
 '    #     pid1 = program_id(1, _builder)\n'
 '    #     pid2 = program_id(2, _builder)\n'
 '    #     npg0 = num_programs(0, _builder)\n'
 '    #     npg1 = num_programs(0, _builder)\n'
 '    #     return pid0 + pid1*npg0 + pid2*npg0*npg1\n'
 '    axis = _constexpr_to_value(axis)\n'
 '    return semantic.program_id(axis, _builder)\n')


Ok we already learnt a few things
1. `axis` needs to be one of `0, 1, 2` as it refers to the axes of a 3D grid. Makes sense if we're working with tensors
2. `program_id` returns the result of `semantic.program_id(axis, _builder)` and takes an optional `_builder` parameter

In our case `axis = 0` (the kernel uses a 1D launch grid), so let's go down to that last line of code and figure out what `semantic.program_id` is

In [41]:
# Follow the last line of tl.program_id into Triton's `semantic` layer.
pp.pprint(getsource(tl.semantic.program_id))

('def program_id(axis: int, builder: ir.builder) -> tl.tensor:\n'
 '    return tl.tensor(builder.create_get_program_id(axis), tl.int32)\n')


Interesting..

So `pid` is actually a `tl.tensor` (Triton's own tensor type, not a `torch.Tensor`) with a dtype of `tl.int32`

We've also now finally hit some compiled `C` code: `ir` (which provides `ir.builder`) is imported from `triton._C.libtriton.triton`

We have two paths to go down now, what's a `tl.tensor` and what's an `ir`


In [43]:
from triton._C.libtriton.triton import ir

In [45]:
# List what the compiled `ir` module exposes, including its methods.
inspect(ir, methods=True)

🤔 OK, this is getting interesting: we have lots of things related to caches and eviction policies. This kinda makes sense, since GPUs have a memory hierarchy:
* Local
* Shared
* Global

A big part of making GPU computations efficient is about making sure you're cleverly using your cache (or so I'm told)

Ok let's hold off on this part for a while, what's a tensor?

In [49]:
# Inspect Triton's own tensor class, which (per semantic.program_id above) is
# what tl.program_id returns.
inspect(tl.tensor, methods=True)

Just a single function, `to()`. Is this the same as the `torch.Tensor.to()` function commonly used to move data to the GPU? Also, why does `openai/triton` need its own `tensor` class? Why not just use `torch`?

In [52]:
# Same object again; all=True asks rich to show every attribute, including
# private and dunder ones.
inspect(tl.tensor, all=True)

In [74]:
# Back to our kernel

# Replay the kernel's next line: multiply `pid` (a tl.tensor inside the kernel)
# with the constexpr BLOCK_SIZE. Outside the JIT this raises NameError (output
# below): `pid` was never assigned, because the earlier
# `pid = tl.program_id(axis=0)` cell raised before the assignment happened.

block_start = pid * BLOCK_SIZE


NameError: name 'pid' is not defined

In [68]:
# Try to evaluate the kernel's offset computation eagerly for program 0
# (block_start = 0) by handing tl.arange a stand-in `_builder`.
BLOCK_SIZE = 1024
block_start = 0
# Fails with AttributeError (output below): tl.arange evidently looks up a
# `get_range` method on the builder it is given, so a bare lambda won't do.
offsets = block_start + tl.arange(0, BLOCK_SIZE, _builder= lambda : ())


AttributeError: 'function' object has no attribute 'get_range'

In [72]:
# How do you pass in a custom builder to debug code locally?
# This attempt fails before the builder is even touched: tl.arange needs both
# `start` and `end` (TypeError below). A Python `range` also has no
# `get_range` method, so it could not stand in for the builder anyway.
tl.arange(5, _builder= range(0,5))

TypeError: arange() missing 1 required positional argument: 'end'

In [66]:
# Read the source of tl.load, which the kernel uses to read x and y under a mask.
pp.pprint(getsource(tl.load))

('@builtin\n'
 'def load(pointer, mask=None, other=None, cache_modifier="", '
 'eviction_policy="", volatile=False, _builder=None):\n'
 '    """\n'
 '    Return a tensor of data whose values are, elementwise, loaded from '
 'memory at location defined by :code:`pointer`.\n'
 '\n'
 '    :code:`mask` and :code:`other` are implicitly broadcast to '
 ':code:`pointer.shape`.\n'
 '\n'
 '    :code:`other` is implicitly typecast to '
 ':code:`pointer.dtype.element_ty`.\n'
 '\n'
 '    :param pointer: Pointers to the data to be loaded.\n'
 '    :type pointer: Block of dtype=triton.PointerDType\n'
 '    :param mask: if mask[idx] is false, do not load the data at address '
 ':code:`pointer[idx]`.\n'
 '    :type mask: Block of triton.int1, optional\n'
 '    :param other: if mask[idx] is false, return other[idx]\n'
 '    :type other: Block, optional\n'
 '    :param cache_modifier: changes cache option in nvidia ptx\n'
 "    'type cache_modifier: str, optional\n"
 '    """\n'
 '    # mask, other can 

In [67]:
# Read the source of tl.store, which the kernel uses to write the masked result.
pp.pprint(getsource(tl.store))

('@builtin\n'
 'def store(pointer, value, mask=None, _builder=None):\n'
 '    """\n'
 '    Stores :code:`value` tensor of elements in memory, element-wise, at the '
 'memory locations specified by :code:`pointer`.\n'
 '\n'
 '    :code:`value` is implicitly broadcast to :code:`pointer.shape` and '
 'typecast to :code:`pointer.dtype.element_ty`.\n'
 '\n'
 '    :param pointer: The memory locations where the elements of :code:`value` '
 'are stored.\n'
 '    :type pointer: Block of dtype=triton.PointerDType\n'
 '    :param value: The tensor of elements to be stored.\n'
 '    :type value: Block\n'
 '    :param mask: If mask[idx] is false, do not store :code:`value[idx]` at '
 ':code:`pointer[idx]`.\n'
 '    :type mask: Block of triton.int1, optional\n'
 '    """\n'
 '    # value can be constexpr\n'
 '    value = _to_tensor(value, _builder)\n'
 '    if mask is not None:\n'
 '        mask = _to_tensor(mask, _builder)\n'
 '    return semantic.store(pointer, value, mask, _builder)\n')


In [73]:
# Re-definition of add_kernel, identical to the earlier one except for a debug
# `print(type(pid))`. `add` looks `add_kernel` up by name each time it is
# called, so later calls to `add` will use this version.
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
                 # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'program's processing different data. We identify which program
    # we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    
    # Debug probe: what Python type is `pid`? NOTE(review): this is Python's
    # built-in print, not a Triton op, so it presumably runs while Triton
    # compiles the kernel (i.e. only once the kernel is launched, once per
    # compilation rather than per program instance), not on the GPU. Confirm
    # for the installed Triton version; newer releases may reject bare `print`
    # in @triton.jit code in favour of tl.static_print / tl.device_print.
    print(type(pid))
    
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and a BLOCK_SIZE of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` is a block of integer element indices, not pointers;
    # adding it to a base pointer (e.g. `x_ptr + offsets` below) produces the
    # block of pointers that is actually loaded from / stored to.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size. Masked-off lanes are never written back,
    # because the store below uses the same mask.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)