# Intro to Triton: softmax

## Summary
In this example, we will look at softmax, written in three different ways:
- built-in Pytorch function
- Pytorch implementation, optimized by Torch JIT
- Triton implementation

We will also compare performance on an Nvidia V100 GPU.

Here is the calculation we will implement. For a 2D tensor X, compute softmax on each row independently (the row index is denoted "r" and the column index "c").

For each r, compute:
$$\mathrm{rowmax}(X_r) = \max_c X_{r,c}$$
$$\mathrm{softmax}(X_r)_c = \frac{\exp\left(X_{r,c} - \mathrm{rowmax}(X_r)\right)}{\sum_{c'} \exp\left(X_{r,c'} - \mathrm{rowmax}(X_r)\right)}$$
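To make the formula concrete, here is a minimal pure-Python reference for a single row. It is only an illustration, not how Pytorch implements softmax, and the names `softmax_row` and `softmax_row_unshifted` are hypothetical. It also shows why the row max is subtracted: the result is mathematically unchanged, but `exp` can no longer overflow on large inputs.

```python
import math

def softmax_row(row):
    """Softmax of one row, following the formula above."""
    row_max = max(row)                                # rowmax(X_r)
    numerator = [math.exp(v - row_max) for v in row]  # exp(X_{r,c} - rowmax(X_r))
    denominator = sum(numerator)                      # sum over all columns c'
    return [n / denominator for n in numerator]

def softmax_row_unshifted(row):
    """The same formula without subtracting the max (for comparison only)."""
    numerator = [math.exp(v) for v in row]
    denominator = sum(numerator)
    return [n / denominator for n in numerator]

row0 = [-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920]  # first row of x (as printed, rounded)
print([round(p, 4) for p in softmax_row(row0)])  # close to the first row of torch.softmax below

# Subtracting the max leaves the result unchanged (up to rounding)...
assert all(abs(a - b) < 1e-12 for a, b in zip(softmax_row(row0), softmax_row_unshifted(row0)))

# ...but it keeps exp() from overflowing: math.exp(1000.0) raises OverflowError.
print(softmax_row([1000.0, 1000.0]))  # [0.5, 0.5]
```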

## Pytorch built-in

```python
import torch

torch.manual_seed(0)
x = torch.randn((4, 6))
print(x)

torch_softmax_result = torch.softmax(x, dim=1)
print(torch_softmax_result)
```

```
tensor([[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920],
        [-0.3160, -2.1152,  0.4681, -0.1577,  1.4437,  0.2660],
        [ 0.1665,  0.8744, -0.1435, -0.1116,  0.9318,  1.2590],
        [ 2.0050,  0.0537,  0.6181, -0.4128, -0.8411, -2.3160]])
tensor([[0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121],
        [0.0825, 0.0136, 0.1806, 0.0966, 0.4791, 0.1476],
        [0.1036, 0.2103, 0.0760, 0.0785, 0.2227, 0.3089],
        [0.6442, 0.0915, 0.1609, 0.0574, 0.0374, 0.0086]])
```


## Pytorch implementation, optimized by Torch JIT

```python
@torch.jit.script
def torch_jit_softmax(x):
    # for each row, take the max
    # shape: R
    x_max = x.max(dim=1)[0]

    # for each row, subtract the row's max from each element.
    # (broadcast x_max to all the columns)
    # shape: R x C
    z = x - x_max[:, None]

    # for each row, complete the numerator
    # shape: R x C
    numerator = torch.exp(z)

    # for each row, get the sum of the numerator
    # shape: R
    denominator = numerator.sum(dim=1)

    # for each row, complete the softmax
    # (broadcast denominator to all the columns)
    ret = numerator / denominator[:, None]

    return ret

torch_jit_softmax_result = torch_jit_softmax(x)
print(torch_jit_softmax_result)
assert torch.allclose(torch_softmax_result, torch_jit_softmax_result)
```

```
tensor([[0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121],
        [0.0825, 0.0136, 0.1806, 0.0966, 0.4791, 0.1476],
        [0.1036, 0.2103, 0.0760, 0.0785, 0.2227, 0.3089],
        [0.6442, 0.0915, 0.1609, 0.0574, 0.0374, 0.0086]])
```


## Triton implementation
We are going to write a softmax in Triton step-by-step. In each step, we'll have a working Triton program. Each version will introduce a new concept.
Here are the steps we'll take:
1. Compute softmax on a single row
2. Compute softmax on all the rows, with each row computed in parallel
3. Generalize to inputs whose number of columns (`shape[1]`) is not a power of two

## Version 1: Compute softmax on a single row

Here is one way to compute the softmax using Triton, with the caveat that it computes softmax on just the first row.

```python
import triton
import triton.language as tl

# this annotation marks the function as a Triton program;
# interpret=True runs it in the Triton interpreter instead of compiling it
@triton.jit(interpret=True)
def softmax_kernel_v1(
    output, # pointer to the Pytorch tensor where the result will go
    input,  # pointer to the Pytorch tensor containing the input
    NUM_COLUMNS: tl.constexpr # number of columns in the input and output tensors; must be a power of 2 when compiled (see Version 3)
):
    # form an array of pointers to the elements of the input tensor
    col_offsets = tl.arange(0, NUM_COLUMNS)
    input_pointer_array = input + col_offsets

    # load the row from the input tensor into a block
    row = tl.load(input_pointer_array)

    ############################
    # compute softmax on the row
    # Triton uses similar syntax for operations on blocks
    # as is used for numpy arrays and Pytorch tensors
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_of_row = numerator / denominator
    ############################

    # form an array of pointers to the elements of the output tensor
    output_pointer_array = output + col_offsets

    # store the result block to the output tensor
    tl.store(output_pointer_array, softmax_of_row)
```

The softmax calculation itself (in the middle) looks similar to the Pytorch version. However, there are some additional details in the rest of the function: pointers, load, and store.

To understand why the additional details are necessary, it is helpful to look at a diagram representing a typical GPU (or CPU or NPU for that matter). 

<img src="img/architecture_load_store1.png" alt="Diagram: device memory connected to processing units, each with its own local memory; load and store move data between them" width="1024"/>

The "device memory" is where the Pytorch tensors live. There are also a number of "processing units" that do the computation; we'll focus on a single processing unit for now. Before a processing unit can compute on data, it must first "load" that data from device memory into its "local memory". To save the results, the processing unit must "store" them from local memory back into device memory.

Triton provides `tl.load` and `tl.store` for these two operations. In the program above, `row` is a "block" of data residing in local memory: a copy of the first row of the input Pytorch tensor.

To tell the system where in device memory to load from or store to, we provide a "pointer". A Triton pointer refers to the starting location of some data in device memory (i.e., a location within a Pytorch tensor).
A Triton program does not return values directly; instead, it stores the result into a Pytorch tensor that was created in device memory before the kernel was called.

One more detail: the pointers passed into the function refer only to the start of each Pytorch tensor, but `tl.load`/`tl.store` take an array of pointers. By "array" we mean a regular, N-dimensional structure, much like a numpy array or Pytorch tensor. To create this array of pointers, we add the column indices to the starting pointer. In this case we want all the columns (indices 0 to `NUM_COLUMNS - 1`). Pointer arithmetic counts elements, not bytes: a Triton pointer knows its element type, so `input + 1` points to the next element. Here is an example:

- `input: 400` (representing the starting location of the input tensor living in device memory)
- `NUM_COLUMNS: 6`
- `col_offsets = [0, 1, 2, 3, 4, 5]`
- `input_pointer_array = [400, 401, 402, 403, 404, 405]`
- `row` will be a 1D block of 6 elements, containing a copy of the data at device memory locations 400, 401, ..., 405

To ground this in a more familiar analogy: if this were purely numpy/Pytorch code, and `input` were a 2D array/tensor (rather than a pointer), the code would be `row = input[0][0:NUM_COLUMNS]`.
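The pointer arithmetic above can be mimicked in plain Python by modeling device memory as a flat list indexed by address. This is only a simulation for intuition: `device_memory`, `load`, and `store` are hypothetical stand-ins for real device memory and for `tl.load`/`tl.store`, and addresses count elements, not bytes.

```python
# Model device memory as a flat Python list; an "address" is an index into it.
device_memory = [0.0] * 1024

# Place a 4 x 6 row-major input tensor starting at address 400 (as in the example).
INPUT = 400
x_rows = [
    [-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920],
    [-0.3160, -2.1152, 0.4681, -0.1577, 1.4437, 0.2660],
    [0.1665, 0.8744, -0.1435, -0.1116, 0.9318, 1.2590],
    [2.0050, 0.0537, 0.6181, -0.4128, -0.8411, -2.3160],
]
for r, values in enumerate(x_rows):
    for c, value in enumerate(values):
        device_memory[INPUT + r * 6 + c] = value

def load(pointer_array):
    """Hypothetical stand-in for tl.load: copy each addressed element into a local 'block'."""
    return [device_memory[p] for p in pointer_array]

def store(pointer_array, block):
    """Hypothetical stand-in for tl.store: write each block element to its address."""
    for p, value in zip(pointer_array, block):
        device_memory[p] = value

NUM_COLUMNS = 6
col_offsets = list(range(NUM_COLUMNS))                    # like tl.arange(0, NUM_COLUMNS)
input_pointer_array = [INPUT + off for off in col_offsets]
print(input_pointer_array)  # [400, 401, 402, 403, 404, 405]

row = load(input_pointer_array)
assert row == x_rows[0]  # same as input[0][0:NUM_COLUMNS]

# The result goes to an output tensor "allocated" beforehand at another address.
OUTPUT = 600
store([OUTPUT + off for off in col_offsets], row)
print(device_memory[OUTPUT:OUTPUT + NUM_COLUMNS])
```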

### How to call a Triton kernel

Since a Triton kernel passes its result back through a pointer, calling it involves an additional step: we first need to allocate a Pytorch tensor to hold the result.

To hide this detail (and others, as we'll see), it is customary to wrap the call to a Triton kernel in a Python function.

```python
# Compute softmax of x (limitation: just the first row)
def softmax_v1(x):
    # Allocate output
    y = torch.empty_like(x)

    n_cols = x.shape[1]

    # Launch the kernel
    softmax_kernel_v1[(1,)](y, x, NUM_COLUMNS=n_cols)

    return y

softmax_v1_result = softmax_v1(x)
print(softmax_v1_result[0])
assert torch.allclose(torch_softmax_result[0], softmax_v1_result[0])
```

```
tensor([0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121])
```


Ignore the `[(1,)]` for now. It has to do with how Triton introduces parallelism, which we'll look at next!

### Taking a step back: what have we gained so far?

One of the key ingredients that makes Triton programs fast is the compiler's ability to "fuse" a sequence of tensor operations so that they are completed together, more efficiently. In general, Pytorch completes tensor operations one at a time: for example, the max of every row is computed before any of the subtractions begin. If the tensor is too large to fit in local memory (true for all but the smallest models), this is inefficient on a GPU, NPU, or CPU, because the data travels back and forth between device memory and local memory for every Pytorch operation.

Pytorch addresses this issue with features like "graph mode" and `torch.jit.script`. However, these optimizers cannot always produce the most efficient code, particularly for more complex algorithms, and that is where Triton shines. The key is that Triton gives the programmer a bit more control, namely over the movement of data between device memory and local memory, which makes it easier to keep data in local memory.
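A back-of-the-envelope count makes the cost concrete. The sketch below assumes that every operation in `torch_jit_softmax` reads its inputs from device memory and writes its result back (no fusion, no caching), while a fused kernel reads each input element once and writes each output element once. Under those assumptions an R x C input costs 5RC + 2R element reads and 3RC + 2R element writes unfused, versus RC reads and RC writes fused. These are idealized counts, not measurements, and the helper names are hypothetical.

```python
def unfused_traffic(R, C):
    """(reads, writes) in elements, assuming each op round-trips through device memory."""
    ops = [
        # (elements read, elements written)
        (R * C, R),          # x_max = x.max(dim=1)
        (R * C + R, R * C),  # z = x - x_max[:, None]
        (R * C, R * C),      # numerator = torch.exp(z)
        (R * C, R),          # denominator = numerator.sum(dim=1)
        (R * C + R, R * C),  # ret = numerator / denominator[:, None]
    ]
    return sum(r for r, _ in ops), sum(w for _, w in ops)

def fused_traffic(R, C):
    """(reads, writes) in elements, assuming one kernel keeps each row in local memory."""
    return R * C, R * C

R, C = 4096, 1024  # hypothetical input size
(ur, uw), (fr, fw) = unfused_traffic(R, C), fused_traffic(R, C)
assert (ur, uw) == (5 * R * C + 2 * R, 3 * R * C + 2 * R)
print(f"unfused: {ur + uw} elements moved, fused: {fr + fw}, ratio ~{(ur + uw) / (fr + fw):.1f}x")
```

The ratio works out to (8RC + 4R) / 2RC = 4 + 2/C, so under these assumptions fusion cuts memory traffic by roughly 4x for wide rows.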

## Version 2: Compute softmax on all the rows, with each row computed in parallel

The function `softmax_v1` only computes softmax on the first row of `x`. To compute softmax for all the rows, we can have Triton call the kernel once per row.
The tuple in the `[]` is called the "launch grid", and it says how many calls to make. Passing the number of rows (`x.shape[0]`) requests one call per row:

```
softmax_kernel_v1[(x.shape[0],)](y, x, NUM_COLUMNS=x.shape[1])
```

The calls run in parallel, up to the number of processing units available on the device.

We'll also need to make a modification to the kernel. As is, it would just compute the first row over and over again.

```python
@triton.jit(interpret=True)
def softmax_kernel_v2(
    output,
    input,
    input_row_stride, # number of elements to get to the next row in input tensor
    output_row_stride, # number of elements to get to the next row in output tensor
    NUM_COLUMNS: tl.constexpr
):
    # find out which row we are assigned
    row_index = tl.program_id(0)

    # calculate the pointer to the start of our assigned row
    row_start_pointer = input + row_index * input_row_stride

    col_offsets = tl.arange(0, NUM_COLUMNS)
    # use the row_start_pointer now instead of input (which would always be pointing to row 0)
    input_pointer_array = row_start_pointer + col_offsets

    row = tl.load(input_pointer_array)

    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # calculate the pointer to the start of our assigned row in the output tensor
    output_row_start_pointer = output + row_index * output_row_stride
    # use the output_row_start_pointer now instead of output (which would always be pointing to row 0)
    output_pointer_array = output_row_start_pointer + col_offsets

    tl.store(output_pointer_array, softmax_output)
```

### program_id

Triton uses a Single Program, Multiple Data (SPMD) programming model: every call runs the same kernel, but each call accesses a different subset of the input and output, which is what allows the calls to run in parallel.
It is the programmer's job to decide which data belongs to each call. A call can find out which one it is with `tl.program_id(0)`; in the modified kernel, we use it to pick which row to process.
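One way to build intuition for the launch grid and `program_id` is to emulate them sequentially in plain Python. Everything below is a hypothetical simulation: `launch` and both kernel functions are not Triton APIs, and the program id is passed as an argument instead of being read with `tl.program_id(0)`. A real launch runs the program instances in parallel and in no guaranteed order, which is why each instance must write only its own part of the output. The sketch also shows why Version 1, launched with one instance per row, would just recompute row 0.

```python
import math

def softmax_row(row):
    """Reference softmax of one row (a list of floats)."""
    row_max = max(row)
    numerator = [math.exp(v - row_max) for v in row]
    denominator = sum(numerator)
    return [n / denominator for n in numerator]

def launch(grid, kernel, *args):
    """Hypothetical sequential stand-in for kernel[grid](*args) with a 1D grid."""
    (n_programs,) = grid
    for program_id in range(n_programs):
        kernel(program_id, *args)

def kernel_ignoring_program_id(program_id, output, input_rows):
    # Like softmax_kernel_v1: every instance reads row 0 and writes row 0.
    output[0] = softmax_row(input_rows[0])

def kernel_using_program_id(program_id, output, input_rows):
    # Like softmax_kernel_v2: instance `program_id` handles only that row.
    output[program_id] = softmax_row(input_rows[program_id])

rows = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [5.0, 1.0, 1.0]]

y1 = [None] * len(rows)
launch((len(rows),), kernel_ignoring_program_id, y1, rows)
print(y1[1:])  # [None, None]: rows 1 and 2 were never computed

y2 = [None] * len(rows)
launch((len(rows),), kernel_using_program_id, y2, rows)
assert all(y2[r] == softmax_row(rows[r]) for r in range(len(rows)))
```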

To seek to the assigned input and output row, we'll need to change the pointer arithmetic to use `row_index`. To do so, we need one more piece of information: stride.

### Stride

To calculate the pointer to the start of a specific row, we need to know how far apart consecutive rows are in memory. Device memory is a 1D array, so every address is a single number.
Along a given dimension of a tensor, the distance in memory (in elements) from one element to the next is that dimension's "stride".

`row_start_pointer = input + row_index * input_row_stride`

Here's an example:

```
tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797],
        [-0.6229, -0.3284, -1.0745, -0.3631, -1.6711,  2.2655],
        [ 0.3117, -0.1842,  1.2866,  1.1820, -0.1271,  1.2169],
        [ 1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973]])
```

* `input: 400` (representing the starting location of the input tensor living in device memory)
* `program_id(0): 2` (which says this call is assigned row 2)
* `NUM_COLUMNS: 6`
* `input_row_stride: 6` (number of elements from the start of one row to the start of the next)
* `row_index = program_id(0) = 2`
* `row_start_pointer = 400 + 2 * 6 = 412`  
* `col_offsets = [0, 1, 2, 3, 4, 5]`
* `input_pointer_array = [412, 413, 414, 415, 416, 417]`
* `row` will be a 1D block of 6 elements, containing a copy of the data at device memory locations 412, 413, 414, 415, 416, 417

which is `[ 0.3117, -0.1842,  1.2866,  1.1820, -0.1271,  1.2169]`
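The same arithmetic can be checked in plain Python with a hypothetical helper, `row_pointers`. The second call covers a case where the row stride is larger than the number of columns: for example, in Pytorch, `x[:, :6]` taken from a 4 x 8 tensor has `stride(0) == 8`, because consecutive rows are still 8 elements apart in memory. This is why the stride is passed to the kernel separately instead of being assumed equal to `NUM_COLUMNS`.

```python
def row_pointers(base, row_index, row_stride, num_columns):
    """Addresses of one row's elements, in element units (assumes column stride 1)."""
    row_start = base + row_index * row_stride
    return [row_start + c for c in range(num_columns)]

# The example above: a contiguous 4 x 6 tensor starting at address 400.
print(row_pointers(400, 2, 6, 6))  # [412, 413, 414, 415, 416, 417]

# A hypothetical 4 x 6 view of the first 6 columns of a 4 x 8 row-major buffer:
# the row stride is 8, so row 2 starts at 400 + 2 * 8 = 416.
print(row_pointers(400, 2, 8, 6))  # [416, 417, 418, 419, 420, 421]
```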

### Calling the new kernel

Here is how to call the new version of the kernel.

```python
# Compute softmax of x
def softmax_v2(x):
    # Allocate output
    y = torch.empty_like(x)

    n_rows, n_cols = x.shape

    # Call n_rows copies of the kernel
    # Triton will make sure each has a unique program_id between 0 and n_rows-1
    softmax_kernel_v2[(n_rows,)](y, x,
                                 input_row_stride=x.stride(0),
                                 output_row_stride=y.stride(0),
                                 NUM_COLUMNS=n_cols)

    return y

softmax_v2_result = softmax_v2(x)
print(softmax_v2_result)
assert torch.allclose(torch_softmax_result, softmax_v2_result)
```

```
tensor([[0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121],
        [0.0825, 0.0136, 0.1806, 0.0966, 0.4791, 0.1476],
        [0.1036, 0.2103, 0.0760, 0.0785, 0.2227, 0.3089],
        [0.6442, 0.0915, 0.1609, 0.0574, 0.0374, 0.0086]])
```


#### Launch grid
The tuple given within `[(n_rows,)]` is called the "launch grid". In this example, we are using a 1d launch grid, `(n_rows,)` which means: call n_rows copies of the kernel, with `tl.program_id(0)` being 0, 1, ..., n_rows-1. 

Aside: Launch grid is a tuple rather than an integer because Triton also accepts 2d and 3d launch grids, which is convenient for slicing the data in multiple dimensions.
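As an illustration of a 2D launch grid, the hypothetical sketch below tiles an `n_rows x n_cols` matrix into `block_m x block_n` tiles and checks that every element is owned by exactly one `(pid0, pid1)` pair. In a real 2D launch, `tl.program_id(0)` and `tl.program_id(1)` would supply those two indices; the loops here only emulate that sequentially. `cdiv` mirrors the ceiling division that Triton provides as `triton.cdiv`. Softmax itself sticks to a 1D grid, since each row's max and sum need the whole row.

```python
def cdiv(a, b):
    """Ceiling division: number of blocks of size b needed to cover a."""
    return (a + b - 1) // b

def tile_owners(n_rows, n_cols, block_m, block_n):
    """Map each (row, col) to the (pid0, pid1) program that would own it in a 2D grid."""
    grid = (cdiv(n_rows, block_m), cdiv(n_cols, block_n))
    owner = {}
    for pid0 in range(grid[0]):
        for pid1 in range(grid[1]):
            for r in range(pid0 * block_m, min((pid0 + 1) * block_m, n_rows)):
                for c in range(pid1 * block_n, min((pid1 + 1) * block_n, n_cols)):
                    assert (r, c) not in owner  # no element is handled twice
                    owner[(r, c)] = (pid0, pid1)
    return grid, owner

grid, owner = tile_owners(n_rows=5, n_cols=6, block_m=2, block_n=4)
print(grid)           # (3, 2)
print(owner[(4, 5)])  # (2, 1)
assert len(owner) == 5 * 6  # every element is covered
```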

#### Strides
We also had to pass in the input and output strides. Pytorch tensors provide `stride(dim)` to get the stride of each dimension; here we want dim 0, the row dimension.

Technicality: we should also pass in strides for the column dimension, but we are assuming the data is stored row-major, that is, `stride(1) == 1` and `stride(0) >= n_cols`.
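To see what the row-major assumption rules out, the hypothetical helper below computes the strides of a contiguous row-major tensor (matching what Pytorch reports for a freshly created tensor, e.g. `(6, 1)` for shape `(4, 6)`), and then the strides after a transpose, which swaps them without moving any data. A transposed input has `stride(1) != 1`, so the kernel would read the wrong elements. One possible remedy, our suggestion rather than part of the original wrapper, is to call `x.contiguous()` before launching.

```python
def row_major_strides(shape):
    """Strides (in elements) of a contiguous row-major tensor with the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

def satisfies_kernel_assumption(strides, n_cols):
    """The kernel assumes each row is a contiguous run of n_cols elements."""
    row_stride, col_stride = strides
    return col_stride == 1 and row_stride >= n_cols

shape = (4, 6)
strides = row_major_strides(shape)
print(strides)  # (6, 1)
assert satisfies_kernel_assumption(strides, n_cols=shape[1])

# Transposing swaps shape and strides: a (6, 4) view with strides (1, 6).
t_shape, t_strides = shape[::-1], strides[::-1]
print(t_strides)  # (1, 6)
assert not satisfies_kernel_assumption(t_strides, n_cols=t_shape[1])
```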

## Version 3: Generalize to inputs whose number of columns (`shape[1]`) is not a power of two

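When compiling, Triton requires the dimensions of a loaded/stored block to be powers of two (for example, `tl.arange(0, N)` needs a power-of-two `N`). This has to do with how the GPU compiler currently generates efficient code.
The examples above still work with `NUM_COLUMNS=6` because we are running them in the Triton interpreter (`interpret=True`) rather than compiling them. In more recent Triton releases, the interpreter is enabled with the `TRITON_INTERPRET=1` environment variable instead of the `interpret=True` argument.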
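To choose a power-of-two block that covers a whole row, a wrapper needs the smallest power of two that is at least `n_cols`. Triton provides `triton.next_power_of_2` for this; the pure-Python function below is our own illustrative equivalent for positive inputs, not Triton's source.

```python
def next_power_of_2(n):
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()

for n_cols in [1, 6, 8, 240, 1000]:
    print(n_cols, "->", next_power_of_2(n_cols))
# 1 -> 1, 6 -> 8, 8 -> 8, 240 -> 256, 1000 -> 1024
```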

Triton has a concept called "masks" for fitting any shape of data into fixed-sized blocks. This concept is illustrated in the following version of softmax.

```python
@triton.jit(interpret=True)
def softmax_kernel_v3(
    output,
    input,
    input_row_stride,
    output_row_stride,
    n_cols,                  # the actual number of columns in input and output
    BLOCK_SIZE: tl.constexpr # the length of a loaded/stored block; a power of 2 >= n_cols
):
    row_index = tl.program_id(0)

    row_start_pointer = input + row_index * input_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_pointer_array = row_start_pointer + col_offsets

    # do not load elements past the end of the last column;
    # for those remaining elements, set the element in the block to negative infinity
    row = tl.load(input_pointer_array, mask=col_offsets < n_cols, other=-float('inf'))

    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    output_row_start_pointer = output + row_index * output_row_stride
    output_pointer_array = output_row_start_pointer + col_offsets

    # only store those elements from the block that are within bounds of the output tensor
    tl.store(output_pointer_array, softmax_output, mask=col_offsets < n_cols)
```

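The `mask` argument of `tl.load` is a boolean condition evaluated for every element of the block. Where the mask is `True`, the element is loaded from the input tensor; where it is `False`, the element is filled with the value given by `other`. Filling with `-inf` is safe for softmax: it never becomes the row max, and `exp(-inf)` is 0, so the padding contributes nothing to the denominator. `tl.store` accepts a `mask` as well, so that elements outside the output tensor are not written.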

Here's an example:

```
tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797],
        [-0.6229, -0.3284, -1.0745, -0.3631, -1.6711,  2.2655],
        [ 0.3117, -0.1842,  1.2866,  1.1820, -0.1271,  1.2169],
        [ 1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973]])
```

* `input: 400` (representing the starting location of the input tensor living in device memory)
* `program_id(0): 2` (which says this call is assigned row 2)
* `n_cols: 6`
* `BLOCK_SIZE: 8`
* `input_row_stride: 6` (number of elements from the start of one row to the start of the next)
* `row_index = program_id(0) = 2`
* `row_start_pointer = 400 + 2 * 6 = 412`
* `col_offsets = [0, 1, 2, 3, 4, 5, 6, 7]`
* `input_pointer_array = [412, 413, 414, 415, 416, 417, 418, 419]`
* `row` will be a 1D block of 8 elements, containing `[0.3117, -0.1842,  1.2866,  1.1820, -0.1271,  1.2169, -inf, -inf]`
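The masked load can be simulated in plain Python to check why `-inf` is a safe padding value. `masked_load` and `masked_store` are hypothetical stand-ins for `tl.load`/`tl.store` with `mask=`/`other=`, operating on a flat list that models device memory. For brevity the sketch writes the result back over the input row; the real kernel writes to a separate output tensor.

```python
import math

def masked_load(memory, pointers, mask, other):
    """Hypothetical stand-in for tl.load(pointers, mask=mask, other=other)."""
    return [memory[p] if m else other for p, m in zip(pointers, mask)]

def masked_store(memory, pointers, block, mask):
    """Hypothetical stand-in for tl.store(pointers, block, mask=mask)."""
    for p, value, m in zip(pointers, block, mask):
        if m:
            memory[p] = value

n_cols, BLOCK_SIZE = 6, 8
row_2 = [0.3117, -0.1842, 1.2866, 1.1820, -0.1271, 1.2169]
# Hypothetical memory: row 2 occupies addresses 412..417, followed by
# unrelated data (99.0) that must be neither read nor overwritten.
memory = [0.0] * 412 + row_2 + [99.0] * 6

col_offsets = list(range(BLOCK_SIZE))          # like tl.arange(0, BLOCK_SIZE)
pointers = [412 + off for off in col_offsets]  # row_start_pointer + col_offsets
mask = [off < n_cols for off in col_offsets]   # like col_offsets < n_cols

row = masked_load(memory, pointers, mask, other=-math.inf)
print(row[-2:])  # [-inf, -inf], not the 99.0 values after the row

row_max = max(row)                                # padding never wins the max
numerator = [math.exp(v - row_max) for v in row]  # math.exp(-inf) == 0.0
denominator = sum(numerator)
softmax_output = [v / denominator for v in numerator]

masked_store(memory, pointers, softmax_output, mask)
assert memory[418:] == [99.0] * 6                 # nothing past the row was overwritten
assert abs(sum(memory[412:418]) - 1.0) < 1e-12
```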

### Boundary conditions

Here is a closer look at how padding lets a fixed-size block handle rows of any length.

`tl.load` needs pointers to all the data that will go into the block. For softmax, since each program computes one row, the array of pointers has the same shape as one row, and Triton uses the same array syntax for these pointer arrays as Pytorch/numpy.

`BLOCK_SIZE` must be a constant, which helps the compiler generate efficient code. An argument is marked as constant with the type annotation `tl.constexpr`. For this softmax we need `BLOCK_SIZE >= n_cols`; for a row with 240 columns, we could use `BLOCK_SIZE = 256`. If that row starts at address 1504, then `col_offsets = [0, 1, 2, ..., 255]` and `input_pointer_array = [1504, 1505, 1506, ..., 1759]`.

When the row is loaded from device memory into a block in local memory, `BLOCK_SIZE` may be larger than `n_cols`, so Triton must pad the block. The `mask` argument performs the bounds check (so that no invalid memory locations are read), and `other` gives the value to use as padding. In this example, columns 0 to 239 hold valid data and columns 240 to 255 are filled with `-inf`.


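```python
# Compare the implementations on the same input (BLOCK_SIZE = next power of 2 >= n_cols)
print(torch.softmax(x, dim=1))
print(torch_jit_softmax(x))
y = torch.empty_like(x)
softmax_kernel_v3[(x.shape[0],)](y, x, x.stride(0), y.stride(0), x.shape[1], BLOCK_SIZE=triton.next_power_of_2(x.shape[1]))
print(y)
```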

<img src="https://triton-lang.org/main/_images/sphx_glr_02-fused-softmax_001.png"/>