In [1]:
import torch
import triton
import triton.language as tl

This notebook is adapted from the Triton [tutorial](https://triton-lang.org/master/getting-started/tutorials/01-vector-add.html) for vector addition. 

# Vector Addition

In [2]:
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y

**Concepts used in `add_kernel`.**

* **program.** In Triton, a *program* is more or less an instance of a kernel that runs against a specific slice of the data. A program is the smallest unit of work in the Triton framework. Accordingly, the number of programs required to run a kernel against a specific set of inputs depends on the size of the inputs and the block size of each program. For example, vector addition for vectors of length 256 with a block size of 64 would require `256/64=4` programs, and each program would access a different range of elements in the input vectors, i.e. `[0:64]`, `[64:128]`, `[128:192]`, `[192:256]`.

* **pointer.** A pointer is a variable that stores the memory address of another variable. For example, assume `x=1` and `x_ptr` is a pointer to `x`. In this case, the value of `x_ptr` is the memory address of `x`. If you retrieve the value at `x_ptr`, then you have *dereferenced* `x_ptr`.

* **DRAM.** Dynamic RAM. The most common but also slowest RAM in GPUs today. GPU code is usually IO bound (a disproportionate amount of time is spent on read/write operations), and optimization often reduces to minimizing DRAM IO.
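
The partitioning described under **program** can be sketched in plain Python. This is only an illustration of the arithmetic, not Triton code; `program_ranges` is a hypothetical helper that does not exist in Triton.

```python
def program_ranges(n_elements: int, block_size: int) -> list:
    """Return the half-open (start, end) range each program would cover."""
    # Ceiling division: enough programs to cover every element.
    n_programs = (n_elements + block_size - 1) // block_size
    ranges = []
    for pid in range(n_programs):
        start = pid * block_size
        # The last block may be partial when sizes do not divide evenly.
        end = min(start + block_size, n_elements)
        ranges.append((start, end))
    return ranges


print(program_ranges(256, 64))
# [(0, 64), (64, 128), (128, 192), (192, 256)]
```

When the vector length is not a multiple of the block size, the last range is simply shorter, which is the case the kernel's mask handles.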

In [102]:
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # """
    # A Triton kernel for vector addition. 

    # Note: Each torch.tensor object is implicitly converted to a pointer to its first element.

    # Parameters
    # ----------
    # x_ptr,
    #     A pointer for the first input vector. 
    # y_ptr,
    #     A pointer for the second input vector. 
    # output_ptr,
    #     A pointer for the output vector.
    # n_elements,
    #     Size of the vector space being added. Vector addition
    #     requires uniform dimensions across vectors, so 
    #     we only need one representation of vector size.
    # BLOCK_SIZE,
    #     Number of elements each program should process. `constexpr` is used so that
    #     it can be set as a shape value.
    # """
    # There are multiple programs processing different slices of the input data. 
    # We identify which program we are via `tl.program_id`, which returns
    # the id of the current program instance along the given axis. Axis, here, refers
    # to the axis of the 3D launch grid and, accordingly, it must be 0, 1, or 2. 
    # As you'll see later, we use a 1D launch grid, so we specify `axis=0`.
    pid = tl.program_id(axis=0)

    # This program will process inputs that are offset from the input data.
    # `block_start` specifies where the current block starts. For example, 
    # for the first block, `pid=0`, so `block_start=0`. But, for the second
    # block, `block_start = 1 * BLOCK_SIZE = BLOCK_SIZE`.
    block_start = pid * BLOCK_SIZE

    # Here, we create a block of integer offsets (element indices) for this
    # program's slice of the data. We start at the beginning of the block and
    # step through the size of the block. Note that `offsets` holds indices, not
    # pointers; the pointers are formed below by adding `offsets` to `x_ptr`, etc.
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Now, we create a mask to guard memory operations against out-of-bounds accesses. 
    # `mask[i]` is True only where `offsets[i]` is a valid index into the vectors, so
    # loads and stores skip any positions past the end of the data.
    mask = offsets < n_elements

    # Load x and y from DRAM, masking out any extra elements in case the input length is
    # not a multiple of the block size. For example, if we have `n_elements=255` and `BLOCK_SIZE=256`,
    # we would mask out the final position (offset 255).
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y

    # Write x + y back to DRAM, masking any extra elements in case the output length is not 
    # a multiple of the block size. Here, if `mask[idx]==False` then `value[idx]` is not 
    # stored at `pointer[idx]`.
    tl.store(output_ptr + offsets, output, mask=mask)
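
To make the offset and mask logic concrete, here is a minimal plain-Python sketch of what each program instance does. It is an emulation under simplifying assumptions (Python lists instead of GPU memory, a sequential loop instead of parallel programs); `add_kernel_emulated` and `add_emulated` are hypothetical names, not part of Triton.

```python
def add_kernel_emulated(x, y, output, n_elements, block_size, pid):
    """Plain-Python stand-in for one program instance of `add_kernel`."""
    block_start = pid * block_size
    # Equivalent in spirit to `block_start + tl.arange(0, BLOCK_SIZE)`.
    offsets = [block_start + i for i in range(block_size)]
    # True only for offsets that fall inside the vectors.
    mask = [off < n_elements for off in offsets]
    for off, keep in zip(offsets, mask):
        if keep:  # masked-off lanes neither load nor store
            output[off] = x[off] + y[off]


def add_emulated(x, y, block_size=4):
    """Run every program in turn; on a GPU they would run in parallel."""
    n_elements = len(x)
    output = [None] * n_elements
    n_programs = (n_elements + block_size - 1) // block_size
    for pid in range(n_programs):
        add_kernel_emulated(x, y, output, n_elements, block_size, pid)
    return output


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = [10.0] * 6
result = add_emulated(x, y)
print(result)
# [11.0, 12.0, 13.0, 14.0, 15.0, 16.0]
```

With six elements and a block size of four, the second program covers offsets 4 through 7, and the mask disables offsets 6 and 7 so nothing is read or written past the end of the lists.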

**Concepts used in `add` function.**

* **Launch grid.** In Triton, a launch grid denotes the number of kernel instances that run in parallel. It is analogous to CUDA launch grids. It can be either `Tuple[int]` or `Callable(metaparameters) -> Tuple[int]`. The dimensionality of the launch grid must match the axis specification in your Triton kernel's `program_id` calls. 

* **`triton.cdiv`.** `triton.cdiv` performs ceiling division for positive integers. It is [defined](https://github.com/openai/triton/blob/9626c8e944d37f4a62ba2901c90bc7a704111fc1/python/triton/utils.py#L6) as: 
    ```
    def cdiv(x, y):
        return (x + y - 1) // y 
    ```
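
As a quick sanity check on the formula above, the sketch below compares a local copy of `cdiv` with `math.ceil` for a few positive integer inputs. It assumes positive integers only and redefines `cdiv` locally so that it runs without Triton installed.

```python
import math


def cdiv(x: int, y: int) -> int:
    # Local copy of the formula quoted above.
    return (x + y - 1) // y


for n_elements, block_size in [(256, 64), (255, 64), (257, 64), (50, 3)]:
    n_programs = cdiv(n_elements, block_size)
    # Same result as rounding the true quotient up.
    assert n_programs == math.ceil(n_elements / block_size)
    # Enough programs to cover every element, with none to spare.
    assert n_programs * block_size >= n_elements
    assert (n_programs - 1) * block_size < n_elements
    print(n_elements, block_size, n_programs)
```

This is why the launch grid uses `cdiv` rather than `//`: floor division would drop the final partial block whenever `n_elements` is not a multiple of the block size.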

In [91]:
torch.empty(1)

tensor([4.4645e-11])

In [105]:
def add(
    x: torch.Tensor, 
    y: torch.Tensor,
    block_size: int = 1024
):
    """
    A helper function that: 
        (1) Allocates an output tensor
        (2) Enqueues the `add_kernel` with appropriate grid and block sizes.
    
    Parameters
    ----------
    x: 
        An input vector that will be added to `y`
    y:
        An input vector that will be added to `x`
    block_size:
        Block size that should be used to execute the kernel.
    """

    # Preallocate the output tensor
    output = torch.empty_like(x)

    # Make sure each tensor is on GPU
    assert x.is_cuda and y.is_cuda and output.is_cuda

    # Define `n_elements` as the number of elements in our output tensor.
    # Note, this would not be a good idea if our input and output tensors had
    # different numbers of elements ;).
    n_elements = output.numel()

    # Now we specify our launch grid. In this case, we use a 1D grid 
    # where the size is the number of blocks.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # Here, `grid` is a callable that returns a 1D tuple with the required number of blocks (i.e. programs),
    # given `n_elements` and the `BLOCK_SIZE` meta-parameter. 
    # For example, if `n_elements=256` and `meta['BLOCK_SIZE']=256`, then 
    # `grid(meta)==(1,)`. However, if `n_elements=512` and 
    # `meta['BLOCK_SIZE']=256`, then `grid(meta)==(2,)`.


    # NOTE:
    #  - each torch.tensor object is implicitly converted to a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
    #  - don't forget to pass meta-parameters as keyword arguments

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)

    return output
    

In [107]:
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
    f'The maximum difference between torch and triton is '
    f'{torch.max(torch.abs(output_torch - output_triton))}'
)

tensor([1.1864, 0.6511, 0.2439,  ..., 0.2770, 1.7019, 0.2928], device='cuda:0')
tensor([1.1864, 0.6511, 0.2439,  ..., 0.2770, 1.7019, 0.2928], device='cuda:0')
The maximum difference between torch and triton is 0.0


In [108]:
output_triton = add(x, y, block_size=2)

In [109]:
output_triton = add(x, y, block_size=1024)
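
The two launches above differ only in block size. As a rough sketch of what that means for the launch grid, the plain-Python snippet below computes the number of programs and how many lanes of the last program are active or masked. `launch_geometry` is a hypothetical helper used only for illustration; it says nothing about actual performance.

```python
def launch_geometry(n_elements, block_size):
    """Return (n_programs, active lanes in last program, masked lanes in last program)."""
    n_programs = (n_elements + block_size - 1) // block_size
    active_in_last = n_elements - (n_programs - 1) * block_size
    masked_in_last = block_size - active_in_last
    return n_programs, active_in_last, masked_in_last


size = 98432
for block_size in (2, 1024):
    print(block_size, launch_geometry(size, block_size))
# 2 (49216, 2, 0)
# 1024 (97, 128, 896)
```

With `block_size=2` every program is full, while with `block_size=1024` the last of the 97 programs processes only 128 elements and relies on the mask for the remaining 896 lanes.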

In [53]:
offsets = torch.arange(0, 10)
offsets < 5

tensor([ True,  True,  True,  True,  True, False, False, False, False, False])

In [3]:
help(triton.cdiv)

Help on function cdiv in module triton.utils:

cdiv(x, y)



In [4]:
help(torch.cdiv)

AttributeError: module 'torch' has no attribute 'cdiv'

In [60]:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

n_elements = 30
BLOCK_SIZE = 10
meta = {'BLOCK_SIZE': BLOCK_SIZE}

grid(meta)

(3,)

In [6]:
grid

<function __main__.<lambda>(meta)>

In [44]:
n_elements = 20
BLOCK_SIZE = 10
meta = {'BLOCK_SIZE': BLOCK_SIZE}

grid(meta)

(2,)

In [19]:
grid({'BLOCK_SIZE':3})

(17,)

In [20]:
3*17

51

In [21]:
50/3

16.666666666666668

In [41]:
def cdiv(x,y):
    print('x + y: ', x + y)
    print('x + y - 1: ', x + y - 1)
    print('y: ', y)

    return (x + y - 1) // y

In [42]:
cdiv(2,11)

x + y:  13
x + y - 1:  12
y:  11


1
